Skip to content

live-enroll.c

This is the source code for the live-enroll command-line tool.

Instructions

See live-enroll.

Code

Available in this TrulyNatural SDK installation at ~/Sensory/TrulyNaturalSDK/7.6.1/sample/c/src/live-enroll.c

live-enroll.c

/* Sensory Confidential
 * Copyright (C)2016-2025 Sensory, Inc. https://sensory.com/
 *
 * TrulyHandsfree SDK keyword spotting command-line enrollment utility,
 * using live audio from the default capture source.
 *------------------------------------------------------------------------------
 */

#include <snsr.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define DEFAULT_OUT  "enrolled-sv.snsr"
#define ENROLL_TASK_VERSION "~0.8.0 || 1.0.0"

#ifdef _MSC_VER
#  define strdup _strdup
#  if _MSC_VER < 1900
#    define snprintf _snprintf
#  endif
#endif


typedef struct {
  const char *enroll;   /* optional enrollment context file name    */
  const char *model;    /* enrolled phrase spotter output file name */
  const char *prefix;   /* audio capture file name prefix           */
  const char **user;    /* pointer to users in argv[]               */
  char *phrase;         /* enrollment phrase                        */
  SnsrStream audio;     /* audio stream handle                      */
  int verbosity;        /* incremented by the -v flag               */
  int userCount;        /* number of users to enroll                */
  int currentUser;      /* current user index                       */
  int failCount;        /* number of failed enrollment attempts     */
} EnrollContext;


static SnsrRC
saveEnrollmentAudio(SnsrSession s, EnrollContext *e, const char *tag, int id)
{
  SnsrStream out, enrollment;
  SnsrRC r;
  const char *dash, *user = NULL;
  const char *format = "%s%s%s-%s-%i.wav";
  char *tmp;
  int len;

  dash = *e->prefix? "-": "";
  snsrGetString(s, SNSR_USER, &user);
  r = snsrGetStream(s, SNSR_AUDIO_STREAM, &enrollment);
  if (r != SNSR_RC_OK) return r;
  if (snsrStreamRC(enrollment) != SNSR_RC_OK) return SNSR_RC_OK;

  len = snprintf(NULL, 0, format, e->prefix, dash, user, tag, id);
  if (len < 0) {
    snsrDescribeError(s, "snprintf() failed, rc = %i\n", len);
    return SNSR_RC_ERROR;
  }
  tmp = malloc(++len);
  if (!tmp) return SNSR_RC_NO_MEMORY;
  snprintf(tmp, len, format, e->prefix, dash, user, tag, id);
  out = snsrStreamFromAudioFile(tmp, "w", SNSR_ST_AF_DEFAULT);
  snsrRetain(out);

  snsrStreamCopy(out, enrollment, SIZE_MAX);
  if ((r = snsrStreamRC(out)) == SNSR_RC_EOF) r = SNSR_RC_OK;
  if (r != SNSR_RC_OK) snsrDescribeError(s, "%s", snsrStreamErrorDetail(out));
  else if (e->verbosity > 0) {
    printf("Saved enrollment audio to %s\n", tmp);
    fflush(stdout);
  }

  snsrRelease(out);
  free(tmp);
  return r;
}


static SnsrRC
nextEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  const char *tag;

  if (e->currentUser >= e->userCount) return SNSR_RC_OK;
  tag = e->user[e->currentUser++] + 1;
  return snsrSetString(s, SNSR_USER, tag);
}


static SnsrRC
passEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  int id = 0;

  if (e->verbosity >= 1) {
    printf("Preliminary enrollment checks passed.\n");
    fflush(stdout);
  }
  if (!e->prefix) return SNSR_RC_OK;
  snsrGetInt(s, SNSR_RES_ENROLLMENT_ID, &id);
  return saveEnrollmentAudio(s, e, "pass", id);
}


static SnsrRC
pauseEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  snsrStreamClose(e->audio);
  printf("\n");
  fflush(stdout);
  return SNSR_RC_OK;
}


static SnsrRC
resumeEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;
  const char *tag;
  int count, target;
  int ctx;

  snsrStreamOpen(e->audio);
  snsrGetInt(s, SNSR_ENROLLMENT_TARGET, &target);
  snsrGetInt(s, SNSR_RES_ENROLLMENT_COUNT, &count);
  snsrGetInt(s, SNSR_ADD_CONTEXT, &ctx);
  r = snsrGetString(s, SNSR_USER, &tag);
  if (r == SNSR_RC_OK) {
    printf("\nSay %s (%i/%i) for \"%s\"", e->phrase, count + 1, target, tag);
    if (ctx) printf(" with context,\n  for example: "
                    "\"<phrase> will it rain tomorrow?\"");
    printf("\n");
    fflush(stdout);
  }
  return r;
}


static SnsrRC
samplesEvent(SnsrSession s, const char *key, void *privateData)
{
  double count;
  snsrGetDouble(s, SNSR_RES_SAMPLES, &count);
  printf("Recording: %6.2f s\r", count / 16000.0);
  fflush(stdout);
  return SNSR_RC_OK;
}


static SnsrRC
doneEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;
  SnsrStream model = NULL, out;
  size_t written;

  r = snsrGetStream(s, SNSR_MODEL_STREAM, &model);
  if (r != SNSR_RC_OK) return r;
  written = snsrStreamGetMeta(model, SNSR_ST_META_BYTES_WRITTEN);
  out = snsrStreamFromFileName(e->model, "w");
  snsrStreamCopy(out, model, written);
  r = snsrStreamRC(out);
  if (r != SNSR_RC_OK) snsrDescribeError(s, "%s", snsrStreamErrorDetail(out));
  snsrRelease(out);
  if (r == SNSR_RC_OK && e->verbosity >= 1) {
    printf("Enrolled model saved to \"%s\"\n", e->model);
    fflush(stdout);
  }
  if (r != SNSR_RC_OK) return r;
  return SNSR_RC_STOP;
}


static SnsrRC
enrolledEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;

  r = snsrSave(s, SNSR_FM_RUNTIME, snsrStreamFromFileName(e->enroll, "w"));
  if (r == SNSR_RC_OK && e->verbosity >= 1) {
    printf("Enrollment context saved to \"%s\"\n", e->enroll);
    fflush(stdout);
  }
  return r;
}


static SnsrRC
printReason(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  const char *guidance, *reason;
  int pass = 0;
  double value = 0.0, threshold = 0.0;

  snsrGetInt(s, SNSR_RES_REASON_PASS, &pass);
  if (pass) return snsrRC(s);
  snsrGetString(s, SNSR_RES_REASON, &reason);
  snsrGetString(s, SNSR_RES_GUIDANCE, &guidance);
  snsrGetDouble(s, SNSR_RES_REASON_VALUE, &value);
  snsrGetDouble(s, SNSR_RES_REASON_THRESHOLD, &threshold);
  if (snsrRC(s) == SNSR_RC_OK) {
    fprintf(stderr, "This enrollment recording is not usable.\n");
    fprintf(stderr, " Reason: %s", reason);
    if (e->verbosity >= 2)
      fprintf(stderr, "   (%.2f, threshold is %.2f)", value, threshold);
    fprintf(stderr, "\n    Fix: %s\n", guidance);
    fflush(stdout);
  }
  return snsrRC(s);
}


static SnsrRC
failEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;

  printReason(s, key, privateData);
  if (e->verbosity >= 3) {
    fprintf(stderr, "\nAll failed checks:\n");
    fflush(stdout);
    snsrForEach(s, SNSR_REASON_LIST, snsrCallback(printReason, NULL, e));
  }
  if (!e->prefix) return SNSR_RC_OK;
  return saveEnrollmentAudio(s, e, "fail", e->failCount++);
}


static SnsrRC
progEvent(SnsrSession s, const char *key, void *privateData)
{
  SnsrRC r = SNSR_RC_OK;
  EnrollContext *e = (EnrollContext *)privateData;

  if (e->verbosity >= 1) {
    double progress;
    r = snsrGetDouble(s, SNSR_RES_PERCENT_DONE, &progress);
    if (r == SNSR_RC_OK) {
      printf("\rAdapting: %3.0f%% complete.", progress);
      if (progress >= 100) printf("\n");
      fflush(stdout);
    }
  }
  return r;
}


static void
fatal(int rc, const char *format, ...)
{
  va_list a;
  fprintf(stderr, "ERROR: ");
  va_start(a, format);
  vfprintf(stderr, format, a);
  va_end(a);
  fprintf(stderr, "\n");
  exit(rc);
}


static const char *usageDetail =
  "Settings are strings used as keys to query or change task behavior.\n"
  "Most frequently used for enrollment is accuracy, which takes a value\n"
  "between 0 and 1.\n"
  "Refer to the " SNSR_NAME " SDK documentation for a complete list and\n"
  "descriptions of all supported settings.\n";

static void
usage(const char *name)
{
  SnsrSession s;
  const char *libInfo;

  fprintf(stderr, "Enrolls " SNSR_NAME " SDK wake words on live audio.\n\n");
  fprintf(stderr,
          "usage: %s -t task [options] +user1 [+user2 ...] [file ...]\n"
          " options:\n"
          "  -e enrollments   : enrollment context output filename\n"
          "  -o out           : enrolled model output filename (default: "
          DEFAULT_OUT ")\n"
          "  -p prefix        : capture each enrollment to file as\n"
          "                     <prefix>-<user>-{pass,fail}-<index>.wav\n"
          "  -s setting=value : override a task setting\n"
          "  -t task          : specify task filename (required)\n"
          "  -v [-v [-v]]     : increase verbosity\n", name);
  fprintf(stderr,
          "\nEnrollment audio is captured from the default microphone, unless\n"
          "the optional [file ...] arguments are supplied.\n");
  fprintf(stderr, "\n%s", usageDetail);

  snsrNew(&s);
  snsrGetString(s, SNSR_LIBRARY_INFO, &libInfo);
  fprintf(stderr, "\n%s\n", libInfo);
  snsrRelease(s);
  exit(199);
}


/* Report model license keys.
 */
static void
reportModelLicense(SnsrSession s, const char *modelfile, int verbose)
{
  const char *msg = NULL;

  if (verbose > 1) {
    snsrGetString(s, SNSR_MODEL_LICENSE_EXPIRES, &msg);
    if (msg)
      fprintf(stderr, "\"%s\": %s.\n", modelfile, msg);
  }
  msg = NULL;
  snsrGetString(s, SNSR_MODEL_LICENSE_WARNING, &msg);
  if (msg)
    fprintf(stderr, "WARNING for model \"%s\": %s.\n", modelfile, msg);
}


/* Store the first enrollment phrase in e.phrase.
 */
static SnsrRC
getVocab(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;
  const char *vocab;

  r = snsrGetString(s, SNSR_RES_TEXT, &vocab);
  if (r != SNSR_RC_OK) return r;
  free(e->phrase);
  e->phrase = strdup(vocab);
  return r;
}


int
main(int argc, char *argv[])
{
  EnrollContext e;
  SnsrRC r;
  SnsrSession s;
  int i, o;
  const char *msg = NULL;
  extern char *optarg;
  extern int optind;
#ifdef SNSR_USE_SECURITY_CHIP
  uint32_t *securityChipComms(uint32_t *in);
  snsrConfig(SNSR_CONFIG_SECURITY_CHIP, securityChipComms);
#endif

  if (argc == 1) usage(argv[0]);
  r = snsrNew(&s);
  if (r != SNSR_RC_OK) fatal(r, s? snsrErrorDetail(s): snsrRCMessage(r));
  e.currentUser = 0;
  e.enroll = NULL;
  e.phrase = strdup("the enrollment phrase");
  e.prefix = NULL;
  e.model  = DEFAULT_OUT;
  e.failCount = 0;
  e.userCount = 0;
  e.verbosity = 0;

  while ((o = getopt(argc, argv, "e:o:p:s:t:v?")) >= 0) {
    switch (o) {
    case 'e':
      e.enroll = optarg;
      break;
    case 'o':
      e.model = optarg;
      break;
    case 'p':
      e.prefix = optarg;
      r = snsrSetInt(s, SNSR_SAVE_ENROLLMENT_AUDIO, 1);
      if (r == SNSR_RC_NO_MODEL)
        fatal(r, "set -t task before -p prefix");
      break;
    case 's':
      r = snsrSet(s, optarg);
      if (r == SNSR_RC_NO_MODEL)
        fatal(r, "set -t task before -s setting=value");
      else if (r != SNSR_RC_OK)
        fatal(r, snsrErrorDetail(s));
      break;
    case 't':
      snsrLoad(s, snsrStreamFromFileName(optarg, "r"));
      snsrRequire(s, SNSR_TASK_TYPE, SNSR_ENROLL);
      r = snsrRequire(s, SNSR_TASK_VERSION, ENROLL_TASK_VERSION);
      if (r != SNSR_RC_OK) fatal(r, snsrErrorDetail(s));
      reportModelLicense(s, optarg, e.verbosity);
      break;
    case 'v': e.verbosity++; break;
    case '?':
    default:  usage(argv[0]);
    }
  }

  if (optind == argc || argv[optind][0] != '+') usage(argv[0]);

  /* Report application license status */
  if (e.verbosity > 1) {
    snsrGetString(s, SNSR_LICENSE_EXPIRES, &msg);
    if (msg) fprintf(stderr, "\"%s\": %s.\n", argv[0], msg);
  }
  msg = NULL;
  snsrGetString(s, SNSR_LICENSE_WARNING, &msg);
  if (msg) fprintf(stderr, "WARNING for \"%s\": %s.\n", argv[0], msg);

  r = snsrSetInt(s, SNSR_INTERACTIVE_MODE, 1);
  if (r  == SNSR_RC_NO_MODEL) usage(argv[0]);
  snsrSetHandler(s, SNSR_NEXT_EVENT, snsrCallback(nextEvent, NULL, &e));
  snsrSetHandler(s, SNSR_DONE_EVENT, snsrCallback(doneEvent, NULL, &e));
  snsrSetHandler(s, SNSR_FAIL_EVENT, snsrCallback(failEvent, NULL, &e));
  snsrSetHandler(s, SNSR_PASS_EVENT, snsrCallback(passEvent, NULL, &e));
  snsrSetHandler(s, SNSR_PROG_EVENT, snsrCallback(progEvent, NULL, &e));
  snsrSetHandler(s, SNSR_PAUSE_EVENT, snsrCallback(pauseEvent, NULL, &e));
  snsrSetHandler(s, SNSR_RESUME_EVENT, snsrCallback(resumeEvent, NULL, &e));
  snsrSetHandler(s, SNSR_SAMPLES_EVENT, snsrCallback(samplesEvent, NULL, &e));
  if (e.enroll) snsrSetHandler(s, SNSR_ENROLLED_EVENT,
                               snsrCallback(enrolledEvent, NULL, &e));

  /* SNSR_VOCAB_LIST is supported for a subset of models only, ignore errors */
  if (snsrRC(s) == SNSR_RC_OK) {
    snsrForEach(s, SNSR_VOCAB_LIST, snsrCallback(getVocab, NULL, &e));
    snsrClearRC(s);
  }

  for (i = optind; i < argc && argv[i][0] == '+'; i++)
    ;
  e.user = (const char **)argv + optind;
  e.userCount = i - optind;
  if (i == argc) {
    e.audio = snsrStreamFromAudioDevice(SNSR_ST_AF_DEFAULT);
  } else {
    SnsrStream tmp;
    e.audio = snsrStreamFromString("");
    for (; i < argc; i++) {
      tmp = snsrStreamFromFileName(argv[i], "r");
      tmp = snsrStreamFromAudioStream(tmp, SNSR_ST_AF_DEFAULT);
      e.audio = snsrStreamFromStreams(e.audio, tmp);
    }
  }
  snsrSetStream(s, SNSR_SOURCE_AUDIO_PCM, e.audio);
  r = snsrRun(s);
  if (r != SNSR_RC_OK && r != SNSR_RC_STOP)
    fatal(snsrRC(s), snsrErrorDetail(s));
  free(e.phrase);
  snsrRelease(s);
  snsrTearDown();
  return 0;
}