Skip to content

spot-enroll.c

This is the source code for the spot-enroll command-line tool.

Instructions

See spot-enroll.

Code

Available in this TrulyNatural SDK installation at ~/Sensory/TrulyNaturalSDK/7.6.1/sample/c/src/spot-enroll.c

spot-enroll.c

/* Sensory Confidential
 * Copyright (C)2016-2025 Sensory, Inc. https://sensory.com/
 *
 * TrulyHandsfree SDK keyword spotting command-line enrollment utility.
 *------------------------------------------------------------------------------
 */

#include <snsr.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define DEFAULT_OUT  "enrolled-sv.snsr"
#define ENROLL_TASK_VERSION "~0.10.0 || 1.0.0"

/* State shared between main() and the session event handlers. A pointer to
 * one instance is passed to every handler as its privateData argument.
 */
typedef struct {
  const char *enrollfile; /* current enrollment filename              */
  const char **filename;  /* enrollment filenames indexed by          */
                          /* enrollment ID, for error messages        */
  const char *enrolled;   /* optional enrollment context file name    */
  const char *adapted;    /* optional adapted context file name       */
  const char *model;      /* enrolled phrase spotter output file name */
  size_t fileCount;       /* number of allocated filenames            */
  int failed;             /* number of failed enrollments             */
  int verbosity;          /* incremented by the -v flag               */
} EnrollContext;


/* SNSR_DONE_EVENT handler: copies the enrolled model held in
 * SNSR_MODEL_STREAM to the output file named by EnrollContext.model.
 *
 * Returns the result of fetching the model stream if that fails, otherwise
 * the status of the output stream after the copy. On a copy failure the
 * stream's error detail is attached to the session.
 */
static SnsrRC
doneEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *ctx = (EnrollContext *)privateData;
  SnsrStream source = NULL;
  SnsrStream sink;
  size_t length;
  SnsrRC status = snsrGetStream(s, SNSR_MODEL_STREAM, &source);

  if (status != SNSR_RC_OK) {
    return status;
  }

  /* Copy exactly as many bytes as were written into the model stream. */
  length = snsrStreamGetMeta(source, SNSR_ST_META_BYTES_WRITTEN);
  sink = snsrStreamFromFileName(ctx->model, "w");
  snsrStreamCopy(sink, source, length);

  status = snsrStreamRC(sink);
  if (status != SNSR_RC_OK) {
    snsrDescribeError(s, "%s", snsrStreamErrorDetail(sink));
  }
  snsrRelease(sink);

  if (status != SNSR_RC_OK || ctx->verbosity < 1) {
    return status;
  }
  printf("Enrolled model saved to \"%s\"\n", ctx->model);
  fflush(stdout);
  return status;
}


/* SNSR_ADAPTED_EVENT handler: saves the session in SNSR_FM_RUNTIME format to
 * the file named by EnrollContext.adapted and, at verbosity 1 or higher,
 * reports the file name on stdout. Returns the snsrSave() result.
 */
static SnsrRC
adaptedEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *ctx = (EnrollContext *)privateData;
  SnsrStream sink = snsrStreamFromFileName(ctx->adapted, "w");
  SnsrRC status = snsrSave(s, SNSR_FM_RUNTIME, sink);

  if (status != SNSR_RC_OK) {
    return status;
  }
  if (ctx->verbosity >= 1) {
    printf("Adapted enrollment context saved to \"%s\"\n", ctx->adapted);
    fflush(stdout);
  }
  return status;
}


/* SNSR_ENROLLED_EVENT handler: saves the session in SNSR_FM_RUNTIME format to
 * the file named by EnrollContext.enrolled and, at verbosity 1 or higher,
 * reports the file name on stdout. Returns the snsrSave() result.
 */
static SnsrRC
enrolledEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *ctx = (EnrollContext *)privateData;
  SnsrStream sink = snsrStreamFromFileName(ctx->enrolled, "w");
  SnsrRC status = snsrSave(s, SNSR_FM_RUNTIME, sink);

  if (status != SNSR_RC_OK) {
    return status;
  }
  if (ctx->verbosity >= 1) {
    printf("Enrollment context saved to \"%s\"\n", ctx->enrolled);
    fflush(stdout);
  }
  return status;
}


/* Prints the reason and suggested fix for a failed enrollment check to
 * stderr. Checks flagged as passed (SNSR_RES_REASON_PASS) print nothing.
 * At verbosity 2 or higher the measured value and its threshold are shown
 * as well. Returns the session's current return code.
 */
static SnsrRC
printReason(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *ctx = (EnrollContext *)privateData;
  const char *why, *remedy;
  double measured = 0.0, limit = 0.0;
  int passed = 0;

  snsrGetInt(s, SNSR_RES_REASON_PASS, &passed);
  if (passed) {
    return snsrRC(s);
  }

  snsrGetString(s, SNSR_RES_REASON, &why);
  snsrGetString(s, SNSR_RES_GUIDANCE, &remedy);
  snsrGetDouble(s, SNSR_RES_REASON_VALUE, &measured);
  snsrGetDouble(s, SNSR_RES_REASON_THRESHOLD, &limit);

  /* Only print when every query above succeeded. */
  if (snsrRC(s) != SNSR_RC_OK) {
    return snsrRC(s);
  }

  fprintf(stderr, " reason: %s", why);
  if (ctx->verbosity >= 2) {
    fprintf(stderr, "   (%.2f, threshold is %.2f)", measured, limit);
  }
  fprintf(stderr, "\n    fix: %s\n", remedy);
  fflush(stdout);
  return snsrRC(s);
}


/* SNSR_FAIL_EVENT handler: reports a failed enrollment on stderr and counts
 * it in EnrollContext.failed.
 *
 * The file name is looked up by enrollment ID in EnrollContext.filename. When
 * the ID is out of range, or no name was recorded for it, the file currently
 * being enrolled is reported instead.
 *
 * Returns the snsrGetInt() error if the enrollment ID cannot be read,
 * SNSR_RC_OK otherwise.
 */
static SnsrRC
failEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  const char *name;
  SnsrRC r;
  int id;

  r = snsrGetInt(s, SNSR_RES_ENROLLMENT_ID, &id);
  if (r != SNSR_RC_OK) return r;

  /* Never hand a NULL pointer to the %s conversion below. */
  name = e->enrollfile;
  if (id >= 0 && (size_t)id < e->fileCount && e->filename[id])
    name = e->filename[id];
  if (!name) name = "(unknown)";

  fprintf(stderr, "Enrollment from file \"%s\" failed:\n", name);
  printReason(s, key, privateData);
  if (e->verbosity >= 3) {
    fprintf(stderr, "\nAll failed checks:\n");
    snsrForEach(s, SNSR_REASON_LIST, snsrCallback(printReason, NULL, e));
  }
  fflush(stdout);
  e->failed++;
  return SNSR_RC_OK;
}


/* SNSR_PASS_EVENT handler: records the name of the file currently being
 * enrolled under its enrollment ID, so that later failure reports can name
 * the file.
 *
 * The filename table is used for error messages only, so this handler is
 * best-effort: a negative ID or an allocation failure leaves the table
 * unchanged and still returns SNSR_RC_OK. Slots that are allocated but never
 * recorded are set to NULL.
 *
 * Returns the snsrGetInt() error if the enrollment ID cannot be read,
 * SNSR_RC_OK otherwise.
 */
static SnsrRC
passEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;
  size_t slot;
  int id;

  r = snsrGetInt(s, SNSR_RES_ENROLLMENT_ID, &id);
  if (r != SNSR_RC_OK) return r;
  if (id < 0) return SNSR_RC_OK;

  slot = (size_t)id;
  if (slot >= e->fileCount) {
    /* Grow far enough to hold this ID, not just by a single element. */
    size_t count = slot + 1, i;
    const char **grown;

    /* Guard the size computation against overflow. */
    if (count > (size_t)-1 / sizeof(*grown)) return SNSR_RC_OK;
    grown = (const char **)realloc((char **)e->filename,
                                   sizeof(*grown) * count);
    /* On failure the old table is still valid; keep it. */
    if (!grown) return SNSR_RC_OK;
    for (i = e->fileCount; i < count; i++) grown[i] = NULL;
    e->filename = grown;
    e->fileCount = count;
  }
  e->filename[slot] = e->enrollfile;
  return SNSR_RC_OK;
}


/* SNSR_PROG_EVENT handler: at verbosity 1 or higher, rewrites a single
 * progress line on stdout from SNSR_RES_PERCENT_DONE and ends the line once
 * progress reaches 100. Returns SNSR_RC_OK when nothing is printed because
 * verbosity is too low, otherwise the snsrGetDouble() result.
 */
static SnsrRC
progEvent(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *ctx = (EnrollContext *)privateData;
  double percent;
  SnsrRC status;

  if (ctx->verbosity < 1) {
    return SNSR_RC_OK;
  }

  status = snsrGetDouble(s, SNSR_RES_PERCENT_DONE, &percent);
  if (status != SNSR_RC_OK) {
    return status;
  }

  /* The leading carriage return overwrites the previous progress line. */
  printf("\rAdapting: %3.0f%% complete.", percent);
  if (percent >= 100) {
    printf("\n");
  }
  fflush(stdout);
  return status;
}


/* snsrForEach() callback used with SNSR_USER_LIST: at verbosity 2 or higher
 * prints the number of enrollments for the current user, and warns on stderr
 * when that number differs from SNSR_ENROLLMENT_TARGET.
 *
 * Returns the result of reading SNSR_RES_ENROLLMENT_COUNT; nothing is printed
 * unless that read succeeds.
 */
static SnsrRC
userIterator(SnsrSession s, const char *key, void *privateData)
{
  EnrollContext *e = (EnrollContext *)privateData;
  SnsrRC r;
  int count, recommended;
  const char *user;

  snsrGetString(s, SNSR_USER, &user);
  snsrGetInt(s, SNSR_ENROLLMENT_TARGET, &recommended);
  r = snsrGetInt(s, SNSR_RES_ENROLLMENT_COUNT, &count);
  if (r == SNSR_RC_OK) {
    /* count is an int, so it is printed with %i rather than %u. */
    if (e->verbosity >= 2)
      printf("%16s: %i enrollment%s.\n", user, count, count == 1? "": "s");
    if (count != recommended)
      fprintf(stderr, "WARNING: \"%s\" has %i enrollment%s, task recommends "
              "%i for optimal performance.\n",
              user, count, count == 1? "": "s", recommended);
    fflush(stdout);
  }
  return r;
}


/* Prints "ERROR: ", the printf-style formatted message and a newline to
 * stderr, then terminates the process with exit code rc. Does not return.
 */
static void
fatal(int rc, const char *format, ...)
{
  va_list args;

  fputs("ERROR: ", stderr);
  va_start(args, format);
  vfprintf(stderr, format, args);
  va_end(args);
  fputc('\n', stderr);
  exit(rc);
}


/* Explanatory text on task settings, printed by usage() after the option
 * summary.
 */
static const char *usageDetail =
  "Settings are strings used as keys to query or change task behavior.\n"
  "Most frequently used for enrollment is accuracy, which takes a value\n"
  "between 0 and 1.\n"
  "Refer to the " SNSR_NAME " SDK documentation for a complete list and\n"
  "descriptions of all supported settings.\n";

/* Prints the command-line synopsis, the settings note and the library
 * information string to stderr, then exits with code 199. Does not return.
 *
 * name is the program name shown in the synopsis (argv[0]).
 */
static void
usage(const char *name)
{
  SnsrSession s = NULL;
  const char *libInfo = NULL;

  fprintf(stderr, "Enrolls " SNSR_NAME " SDK wake words on audio files.\n\n");
  fprintf(stderr,
          "usage: %s -t task [options] "
          "[+user1 file1 [-c] file2 ...] [+user2 ...]\n"
          " options:\n"
          "  -a adaptedfile   : adapted enrollment context output filename\n"
          "  -c file          : recording contains trailing context\n"
          "  -e enrolledfile  : enrollment context output filename\n"
          "  -o out           : enrolled model output filename (default: "
          DEFAULT_OUT ")\n"
          "  -s setting=value : override a task setting\n"
          "  -t task          : specify task filename (required)\n"
          "  -v [-v [-v]]     : increase verbosity\n", name);
  fprintf(stderr, "\n%s", usageDetail);

  /* A temporary session is needed only to query the library information.
   * libInfo stays NULL if the session or the query fails, in which case the
   * line is omitted rather than printing an indeterminate pointer.
   */
  snsrNew(&s);
  snsrGetString(s, SNSR_LIBRARY_INFO, &libInfo);
  if (libInfo) fprintf(stderr, "\n%s\n", libInfo);
  if (s) snsrRelease(s);
  exit(199);
}


/* Report model license status on stderr: the SNSR_MODEL_LICENSE_EXPIRES
 * message when verbose is greater than 1, and the SNSR_MODEL_LICENSE_WARNING
 * message whenever one is available.
 */
static void
reportModelLicense(SnsrSession s, const char *modelfile, int verbose)
{
  const char *expires = NULL;
  const char *warning = NULL;

  if (verbose > 1) {
    snsrGetString(s, SNSR_MODEL_LICENSE_EXPIRES, &expires);
    if (expires != NULL) {
      fprintf(stderr, "\"%s\": %s.\n", modelfile, expires);
    }
  }

  snsrGetString(s, SNSR_MODEL_LICENSE_WARNING, &warning);
  if (warning != NULL) {
    fprintf(stderr, "WARNING for model \"%s\": %s.\n", modelfile, warning);
  }
}


/* snsrForEach() callback that prints one vocabulary entry (ID and phrase
 * text) to stdout. privateData points to an int flag: while it is non-zero
 * a heading is printed first, and the flag is cleared after the first entry.
 * Returns the result of reading SNSR_RES_TEXT.
 */
static SnsrRC
showVocab(SnsrSession s, const char *key, void *privateData)
{
  int *isFirst = (int *)privateData;
  const char *phrase = NULL;
  int phraseId = -1;
  SnsrRC status;

  /* The ID is optional: it stays at -1 when this read fails. */
  snsrGetInt(s, SNSR_RES_ID, &phraseId);
  status = snsrGetString(s, SNSR_RES_TEXT, &phrase);
  if (status == SNSR_RC_OK) {
    if (*isFirst) {
      printf("Available vocabulary:\n");
    }
    printf(" %2i: \"%s\"\n", phraseId, phrase);
    *isFirst = 0;
  }
  return status;
}


int
main(int argc, char *argv[])
{
  EnrollContext e;
  SnsrRC r;
  SnsrSession s;
  int i, o, rejected = 0;
  const char *msg = NULL;
  extern char *optarg;
  extern int optind;
  const char *u = NULL;
#ifdef SNSR_USE_SECURITY_CHIP
  uint32_t *securityChipComms(uint32_t *in);
  snsrConfig(SNSR_CONFIG_SECURITY_CHIP, securityChipComms);
#endif

  if (argc == 1) usage(argv[0]);
  r = snsrNew(&s);
  if (r != SNSR_RC_OK) fatal(r, s? snsrErrorDetail(s): snsrRCMessage(r));
  e.failed = 0;
  e.verbosity = 0;
  e.enrolled = NULL;
  e.adapted = NULL;
  e.model  = DEFAULT_OUT;
  e.fileCount = 0;
  e.filename = NULL;

  while ((o = getopt(argc, argv, "+a:e:o:s:t:v?")) >= 0) {
    switch (o) {
    case 'a':
      e.adapted = optarg;
      break;
    case 'e':
      e.enrolled = optarg;
      break;
    case 'o':
      e.model = optarg;
      break;
    case 's':
      r = snsrSet(s, optarg);
      if (r == SNSR_RC_NO_MODEL)
        fatal(r, "set -t task before -s setting=value");
      else if (r != SNSR_RC_OK)
        fatal(r, snsrErrorDetail(s));
      break;
    case 't':
      snsrLoad(s, snsrStreamFromFileName(optarg, "r"));
      snsrRequire(s, SNSR_TASK_TYPE, SNSR_ENROLL);
      r = snsrRequire(s, SNSR_TASK_VERSION, ENROLL_TASK_VERSION);
      if (r != SNSR_RC_OK) fatal(r, snsrErrorDetail(s));
      reportModelLicense(s, optarg, e.verbosity);
      break;
    case 'v': e.verbosity++; break;
    case '?':
    default: usage(argv[0]);
    }
  }

  r = snsrSetInt(s, SNSR_INTERACTIVE_MODE, 0);
  if (r  == SNSR_RC_NO_MODEL) usage(argv[0]);

  /* Report application license status */
  if (e.verbosity > 1) {
    snsrGetString(s, SNSR_LICENSE_EXPIRES, &msg);
    if (msg) fprintf(stderr, "\"%s\": %s.\n", argv[0], msg);
  }
  msg = NULL;
  snsrGetString(s, SNSR_LICENSE_WARNING, &msg);
  if (msg) fprintf(stderr, "WARNING for \"%s\": %s.\n", argv[0], msg);

  snsrSetHandler(s, SNSR_DONE_EVENT, snsrCallback(doneEvent, NULL, &e));
  snsrSetHandler(s, SNSR_FAIL_EVENT, snsrCallback(failEvent, NULL, &e));
  snsrSetHandler(s, SNSR_PASS_EVENT, snsrCallback(passEvent, NULL, &e));
  snsrSetHandler(s, SNSR_PROG_EVENT, snsrCallback(progEvent, NULL, &e));
  if (e.enrolled)  snsrSetHandler(s, SNSR_ENROLLED_EVENT,
                                  snsrCallback(enrolledEvent, NULL, &e));
  if (e.adapted)  snsrSetHandler(s, SNSR_ADAPTED_EVENT,
                                 snsrCallback(adaptedEvent, NULL, &e));

  /* SNSR_VOCAB_LIST is supported for a subset of models only, ignore errors */
  if (e.verbosity > 2 && snsrRC(s) == SNSR_RC_OK) {
    int first = 1;
    snsrForEach(s, SNSR_VOCAB_LIST, snsrCallback(showVocab, NULL, &first));
    snsrClearRC(s);
  }

  if (optind + 1 < argc) {
    int enrollmentIndex = 0, idx = -1, errors;
    if (argv[optind][0] != '+') usage(argv[0]);
    for (i = optind; i < argc; i++) {
      if (argv[i][0] == '+') {
        u = argv[i] + 1;
        snsrSetString(s, SNSR_USER, u);
      } else {
        SnsrStream a;
        int hasContext;
        hasContext = !strcmp("-c", argv[i]);
        if (hasContext && ++i >= argc) usage(argv[0]);
        a = snsrStreamFromFileName(argv[i], "r");
        e.enrollfile = argv[i];
        a = snsrStreamFromAudioStream(a, SNSR_ST_AF_DEFAULT);
        snsrSetStream(s, SNSR_SOURCE_AUDIO_PCM, a);
        snsrSetInt(s, SNSR_ADD_CONTEXT, hasContext);
        if (e.verbosity >= 2) {
          printf("Enrolling user \"%s\"%s from file \"%s\".\n",
                 u, hasContext? " with context": "", argv[i]);
          fflush(stdout);
        }
        errors = e.failed;
        if (snsrRun(s) == SNSR_RC_STREAM_END) snsrClearRC(s);
        snsrGetInt(s, SNSR_RES_ENROLLMENT_COUNT, &idx);
        if (idx == enrollmentIndex && errors == e.failed) {
          fprintf(stderr, "Enrollment skipped for \"%s\", amplitude too low?\n",
                  e.enrollfile);
          e.failed++;
        }
        if (e.failed > errors) rejected++;
        enrollmentIndex = idx;
      }
      if (snsrRC(s) != SNSR_RC_OK) fatal(snsrRC(s), snsrErrorDetail(s));
    }
  }
  if (rejected) fatal(100,"%u enrollment %s rejected.",
                      rejected, rejected == 1? "file was": "files were");
  snsrForEach(s, SNSR_USER_LIST, snsrCallback(userIterator, NULL, &e));
  snsrSetString(s, SNSR_USER, NULL); /* end-of-enrollment marker */
  snsrSetStream(s, SNSR_SOURCE_AUDIO_PCM, snsrStreamFromString(""));
  if (snsrRun(s) == SNSR_RC_STREAM_END) snsrClearRC(s);
  if (snsrRC(s) != SNSR_RC_OK) fatal(snsrRC(s), snsrErrorDetail(s));
  snsrRelease(s);
  snsrTearDown();
  free((char **)e.filename);

  return e.failed;
}