Skip to content

live_enroll.py

This example runs User-Defined Trigger (UDT) wake-word enrollment from Python. By default it feeds the four pre-recorded armadillo-1 enrollment WAV files from the SDK's data/enrollments/ directory. Pass your own WAV paths as arguments to use them instead, or pass --live to capture enrollment audio from the default microphone.

Instructions

  1. Set up the sample project environment:

    cd ~/Sensory/TrulyNaturalSDK/7.8.0-pre.2/sample/python
    uv venv
    uv sync
    
  2. Run file-based enrollment with the default armadillo-1 recordings:

    uv run src/live_enroll.py
    

    The sample writes the enrolled model to enrolled-sv.snsr in the current directory unless you pass --output.

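  3. To enroll from live microphone input instead (--user names the user being enrolled; the sample then prompts you to say the phrase for each recording):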

    uv run src/live_enroll.py --live --user my-wake-word
    

Code

Available in this TrulyNatural SDK installation at ~/Sensory/TrulyNaturalSDK/7.8.0-pre.2/sample/python/src/live_enroll.py

live_enroll.py

"""Live wake-word enrollment for the TrulyNatural SDK Python binding.

Loads a User-Defined Trigger (UDT) enrollment task and walks
through interactive enrollment: assign users, capture enrollment
utterances, run quality checks, and write an enrolled spotter model.
This mirrors the ``live-enroll`` command-line tool and its C sample
(``live-enroll.c``).

By default the sample feeds pre-recorded enrollment WAV files from
``<sdk-root>/data/enrollments/`` (the same ``armadillo-1`` set used by
``live-enroll`` in CI). Pass ``--live`` to capture from the host's
default microphone instead.

Usage::

    uv run src/live_enroll.py [--sdk-root PATH] [options]

    uv run src/live_enroll.py --live --user my-wake-word
"""

from __future__ import annotations

import argparse
import sys
from dataclasses import dataclass
from pathlib import Path

import snsr


# Enrollment task model, looked up under <sdk-root>/model/ unless --task is given.
TASK_MODEL: str = "udt-universal-3.67.1.0.snsr"
# Task-version requirement checked with Session.require(snsr.TASK_VERSION, ...).
ENROLL_TASK_VERSION: str = "~0.8.0 || 1.0.0"
# Default file name for the enrolled spotter model (relative to the cwd).
DEFAULT_OUTPUT: str = "enrolled-sv.snsr"
# User enrolled when no --user option is given.
DEFAULT_USER: str = "armadillo-1"
# Default value for the session's ACCURACY setting (--accuracy).
DEFAULT_ACCURACY: float = 0.1
# Sample rate assumed when turning the recorded-sample count into seconds.
SAMPLES_PER_SECOND: int = 16000

# armadillo-1 recordings 0..3 -- the same four utterances as
# devkit-live-enroll-1.6 in live-enroll.test, in playback order.
DEFAULT_AUDIO: tuple[str, ...] = tuple(map("armadillo-1-{}.wav".format, range(4)))


def default_sdk_root() -> Path:
    """Return the SDK install root inferred from this script's location.

    The sample ships as ``<sdk-root>/sample/python/src/live_enroll.py``, so
    the root is the fourth ancestor of the script file.

    This runs eagerly as the ``--sdk-root`` default while the argument parser
    is built, so it must never raise. If the script has been copied somewhere
    too shallow to have that many ancestors, return the script's own directory
    instead. Passing ``--sdk-root`` still overrides it, and the missing-task
    check reports a bad guess with a hint.
    """
    script = Path(__file__).resolve()
    ancestors = script.parents
    if len(ancestors) > 3:
        return ancestors[3]
    return script.parent


def enrollment_audio_dir(sdk_root: Path) -> Path:
    """Return ``<sdk_root>/data/enrollments``, home of the bundled enrollment WAVs."""
    return sdk_root.joinpath("data", "enrollments")


@dataclass
class EnrollState:
    """Mutable state shared by the enrollment event handlers.

    One instance is built per enrollment run and captured by the handler
    closures, which read it and update its counters as events arrive.
    """

    # Names to enroll, handed out in list order by the next-user handler.
    users: list[str]
    # Destination file for the enrolled spotter model.
    model_path: Path
    # Optional destination for the enrollment context; None disables saving it.
    enroll_path: Path | None
    # File-name prefix for saved capture WAVs; None (or empty) disables saving.
    prefix: str | None
    # Number of -v flags given; higher values print more status output.
    verbosity: int
    # Audio source stream; the pause/resume handlers close and reopen it.
    audio: snsr.Stream | None = None
    # Index into ``users`` of the next name to assign.
    current_user: int = 0
    # Phrase shown in the "Say ..." prompt; overwritten from the task's
    # vocabulary list when that list can be read.
    phrase: str = "the enrollment phrase"
    # Count of failed captures saved so far; numbers the "fail" WAV files.
    fail_count: int = 0


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for this sample.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``user`` and ``audio`` are possibly-empty lists.
        ``task``, ``enroll`` and ``prefix`` are ``None`` unless given.
    """
    # The module docstring's first line doubles as the description, but
    # __doc__ is None when Python runs with -OO (docstrings stripped).
    doc_lines = (__doc__ or "").splitlines()
    parser = argparse.ArgumentParser(
        description=doc_lines[0] if doc_lines else None,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Without --live, enrollment audio defaults to four WAV files under\n"
            f"  <sdk-root>/data/enrollments/{DEFAULT_AUDIO[0]} ... {DEFAULT_AUDIO[-1]}\n"
            # Positional paths replace the bundled defaults rather than
            # being appended to them.
            "Pass WAV paths after the options to use your own recordings instead."
        ),
    )
    parser.add_argument(
        "--sdk-root",
        type=Path,
        default=default_sdk_root(),
        help="TrulyNatural SDK install root (default: auto-detect)",
    )
    parser.add_argument(
        "--task",
        type=Path,
        help=f"enrollment task model (default: <sdk-root>/model/{TASK_MODEL})",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path(DEFAULT_OUTPUT),
        help=f"enrolled spotter output path (default: {DEFAULT_OUTPUT})",
    )
    parser.add_argument(
        "--enroll",
        "-e",
        type=Path,
        help="optional enrollment-context output path",
    )
    parser.add_argument(
        "--prefix",
        "-p",
        type=str,
        help="save each enrollment capture as <prefix>-<user>-{pass,fail}-<n>.wav",
    )
    parser.add_argument(
        "--user",
        action="append",
        default=[],
        metavar="NAME",
        help=f"user to enroll (repeat for multiple; default: {DEFAULT_USER})",
    )
    parser.add_argument(
        "--accuracy",
        type=float,
        default=DEFAULT_ACCURACY,
        help=f"enrollment accuracy setting (default: {DEFAULT_ACCURACY})",
    )
    parser.add_argument(
        "--live",
        action="store_true",
        help="capture from the default microphone instead of WAV files",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="increase status output (repeat up to three times)",
    )
    parser.add_argument(
        "audio",
        nargs="*",
        type=Path,
        help="enrollment WAV file(s); ignored when --live is set",
    )
    return parser.parse_args(argv)


def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, list[Path]]:
    """Turn parsed arguments into absolute input locations.

    Returns:
        ``(sdk_root, task_path, audio_paths)``.

        - ``task_path`` is ``--task`` when given, else
          ``<sdk_root>/model/TASK_MODEL``.
        - ``audio_paths`` is empty with ``--live``.
        - Otherwise ``audio_paths`` is the resolved positional WAV paths or,
          when none were given, the ``DEFAULT_AUDIO`` files in the SDK's
          enrollment directory.

        Existence is not checked here.
    """
    root = args.sdk_root.resolve()
    task = (args.task if args.task else root / "model" / TASK_MODEL).resolve()

    if args.live:
        # Live capture reads the microphone; positional files are ignored.
        return root, task, []

    if args.audio:
        wavs = [wav.resolve() for wav in args.audio]
    else:
        bundled = enrollment_audio_dir(root)
        wavs = [bundled / name for name in DEFAULT_AUDIO]
    return root, task, wavs


def build_audio_stream(paths: list[Path]) -> snsr.Stream:
    """Join the enrollment WAV files, in order, into a single PCM source stream.

    This mirrors the file-input mode of the ``live-enroll`` tool. The files
    are concatenated back to back. An empty *paths* yields just the empty
    seed stream.
    """
    # Start from an empty string stream so every WAV is appended the same way.
    combined = snsr.Stream.from_string("")
    wavs = (snsr.Stream.from_audio_file(str(p)) for p in paths)
    for wav in wavs:
        combined = snsr.Stream.from_streams(combined, wav)
    return combined


def save_enrollment_audio(
    s: snsr.Session, state: EnrollState, tag: str, enroll_id: int
) -> None:
    """Write the session's latest enrollment capture to a WAV file.

    Does nothing unless ``state.prefix`` is set. The file is named
    ``<prefix>-<user>-<tag>-<enroll_id>.wav``, where ``<user>`` is the
    session's current ``snsr.USER``. The name is used as-is, so a relative
    prefix lands in the current working directory.

    Args:
        s: Enrollment session; its ``AUDIO_STREAM`` supplies the capture.
        state: Shared handler state (``prefix`` and ``verbosity`` are read).
        tag: Outcome label for the file name, e.g. ``"pass"`` or ``"fail"``.
        enroll_id: Number that distinguishes captures with the same tag.

    Raises:
        snsr.Error: if writing the WAV file fails.
    """
    if not state.prefix:
        return
    user = s.get_string(snsr.USER)
    if isinstance(user, bytes):
        user = user.decode()
    # prefix is known to be non-empty here, so the separator is unconditional.
    out_name = f"{state.prefix}-{user}-{tag}-{enroll_id}.wav"
    enrollment = s.get_stream(snsr.AUDIO_STREAM)
    if enrollment.rc != snsr.RC.OK:
        # Best-effort: keep enrolling, but don't drop the requested capture
        # silently.
        print(
            f"warning: enrollment audio unavailable; not saving {out_name}",
            file=sys.stderr,
        )
        return
    with snsr.Stream.from_audio_file(out_name, "w") as out:
        # Ask for an effectively unlimited byte count so the whole capture is
        # copied. RC.EOF (presumably the source running out) counts as success.
        out.copy(enrollment, 2**63 - 1)
        if out.rc not in (snsr.RC.OK, snsr.RC.EOF):
            raise snsr.Error(out.rc, message=out.error_detail)
    if state.verbosity >= 1:
        print(f"Saved enrollment audio to {out_name}")


def print_reason(s: snsr.Session, state: EnrollState) -> None:
    """Explain on stderr why the latest enrollment recording was rejected.

    Prints nothing when the session reports ``RES_REASON_PASS``. Otherwise it
    prints the reason and the suggested fix. At verbosity 2 or higher it also
    prints the measured value and the threshold it was compared with.
    """
    passed = s.get_int(snsr.RES_REASON_PASS)
    if passed:
        return
    raw = (s.get_string(snsr.RES_REASON), s.get_string(snsr.RES_GUIDANCE))
    reason, guidance = (t.decode() if isinstance(t, bytes) else t for t in raw)
    print("This enrollment recording is not usable.", file=sys.stderr)
    print(f" Reason: {reason}", file=sys.stderr)
    if state.verbosity > 1:
        measured = s.get_double(snsr.RES_REASON_VALUE)
        limit = s.get_double(snsr.RES_REASON_THRESHOLD)
        print(f"   ({measured:.2f}, threshold is {limit:.2f})", file=sys.stderr)
    print(f"    Fix: {guidance}", file=sys.stderr)


def install_handlers(s: snsr.Session, state: EnrollState) -> None:
    """Register the enrollment event callbacks on *s*.

    Every callback is a closure over *state*, which carries the user list,
    output paths and counters between events. Each one receives the session
    and the event key and returns an ``snsr.RC``. ``done_event`` returns
    ``RC.STOP``; all the others return ``RC.OK``.

    ``ENROLLED_EVENT`` is registered only when an enrollment-context path was
    requested. Afterwards the task's vocabulary list, if it can be read,
    supplies the phrase used in the recording prompts.
    """

    def next_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Assign the next user name to snsr.USER. Once every user has been
        # handed out, return OK without setting a user -- presumably the task
        # reads that as "no more users"; confirm against the SDK docs.
        if state.current_user >= len(state.users):
            return snsr.RC.OK
        user = state.users[state.current_user]
        state.current_user += 1
        sess.set_string(snsr.USER, user)
        return snsr.RC.OK

    def pass_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # A capture passed the preliminary checks; optionally save it, named
        # by the session's enrollment ID.
        if state.verbosity >= 1:
            print("Preliminary enrollment checks passed.")
        if state.prefix:
            enroll_id = sess.get_int(snsr.RES_ENROLLMENT_ID)
            save_enrollment_audio(sess, state, "pass", enroll_id)
        return snsr.RC.OK

    def fail_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # A capture was rejected: explain why, and optionally save it under a
        # running fail counter (there is no enrollment ID to use here).
        print_reason(sess, state)
        if state.prefix:
            save_enrollment_audio(sess, state, "fail", state.fail_count)
            state.fail_count += 1
        return snsr.RC.OK

    def pause_event(_sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Close the audio source while the task is paused; print() moves off
        # the in-place "Recording:" status line.
        if state.audio is not None:
            state.audio.close()
        print()
        return snsr.RC.OK

    def resume_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Reopen the audio source and prompt for the next capture.
        # count is the number of enrollments so far, so count + 1 is the
        # capture about to be recorded, out of target.
        if state.audio is not None:
            state.audio.open()
        count = sess.get_int(snsr.RES_ENROLLMENT_COUNT)
        target = sess.get_int(snsr.ENROLLMENT_TARGET)
        user = sess.get_string(snsr.USER)
        if isinstance(user, bytes):
            user = user.decode()
        # Non-zero ADD_CONTEXT appears to mean the phrase should be followed
        # by a command (see the example printed below) -- confirm in SDK docs.
        ctx = sess.get_int(snsr.ADD_CONTEXT)
        print(f'\nSay {state.phrase} ({count + 1}/{target}) for "{user}"', end="")
        if ctx:
            print(
                ',\n  for example: "<phrase> will it rain tomorrow?"',
                end="",
            )
        print()
        return snsr.RC.OK

    def samples_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Elapsed-recording display. Assumes RES_SAMPLES is a sample count at
        # SAMPLES_PER_SECOND; the trailing '\r' rewrites the same line.
        seconds = sess.get_double(snsr.RES_SAMPLES) / SAMPLES_PER_SECOND
        print(f"Recording: {seconds:6.2f} s\r", end="", flush=True)
        return snsr.RC.OK

    def prog_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Adaptation progress, shown only when verbose; the newline is printed
        # once the percentage reaches 100.
        if state.verbosity >= 1:
            progress = sess.get_double(snsr.RES_PERCENT_DONE)
            print(f"\rAdapting: {progress:3.0f}% complete.", end="", flush=True)
            if progress >= 100:
                print()
        return snsr.RC.OK

    def done_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Enrollment finished: copy exactly the bytes written to MODEL_STREAM
        # into state.model_path, creating parent directories first.
        model = sess.get_stream(snsr.MODEL_STREAM)
        written = model.get_meta(snsr.StreamMeta.BYTES_WRITTEN)
        state.model_path.parent.mkdir(parents=True, exist_ok=True)
        with snsr.Stream.from_filename(str(state.model_path), "w") as out:
            out.copy(model, written)
            if out.rc != snsr.RC.OK:
                raise snsr.Error(out.rc, message=out.error_detail)
        if state.verbosity >= 1:
            print(f'Enrolled model saved to "{state.model_path}"')
        # STOP ends the run; the caller accepts it as a successful exit code.
        return snsr.RC.STOP

    def enrolled_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Save the enrollment context in RUNTIME format, if a path was given.
        if state.enroll_path is None:
            return snsr.RC.OK
        state.enroll_path.parent.mkdir(parents=True, exist_ok=True)
        sess.save(snsr.DataFormat.RUNTIME, str(state.enroll_path))
        if state.verbosity >= 1:
            print(f'Enrollment context saved to "{state.enroll_path}"')
        return snsr.RC.OK

    def capture_phrase(sess: snsr.Session, _key: str) -> snsr.RC | None:
        # Called once per VOCAB_LIST entry; the last entry seen wins.
        vocab = sess.get_string(snsr.RES_TEXT)
        if isinstance(vocab, bytes):
            vocab = vocab.decode()
        state.phrase = vocab
        return snsr.RC.OK

    s.set_handler(snsr.NEXT_EVENT, next_event)
    s.set_handler(snsr.DONE_EVENT, done_event)
    s.set_handler(snsr.FAIL_EVENT, fail_event)
    s.set_handler(snsr.PASS_EVENT, pass_event)
    s.set_handler(snsr.PROG_EVENT, prog_event)
    s.set_handler(snsr.PAUSE_EVENT, pause_event)
    s.set_handler(snsr.RESUME_EVENT, resume_event)
    s.set_handler(snsr.SAMPLES_EVENT, samples_event)
    if state.enroll_path is not None:
        s.set_handler(snsr.ENROLLED_EVENT, enrolled_event)

    # Best-effort: if the vocabulary list cannot be iterated (presumably the
    # task has none), keep the default phrase in state.phrase.
    try:
        s.for_each(snsr.VOCAB_LIST, capture_phrase)
    except snsr.Error:
        pass


def run_enrollment(
    task_path: Path,
    audio_paths: list[Path],
    *,
    live: bool,
    users: list[str],
    output: Path,
    enroll: Path | None,
    prefix: str | None,
    accuracy: float,
    verbosity: int,
) -> None:
    """Run one interactive enrollment session from start to finish.

    Steps:

    - Print a short summary of the inputs.
    - Open the enrollment task and configure it.
    - Attach the event handlers.
    - Attach the audio source: the default capture device when *live*,
      otherwise the concatenated *audio_paths*.
    - Run the session.

    Raises:
        snsr.Error: when ``run()`` finishes with a code other than ``RC.OK``
            or ``RC.STOP``. Errors from the binding calls presumably
            propagate as ``snsr.Error`` too.
    """
    enroll_path = enroll.resolve() if enroll else None
    state = EnrollState(
        users=users,
        model_path=output.resolve(),
        enroll_path=enroll_path,
        prefix=prefix,
        verbosity=verbosity,
    )

    summary = [f"snsr {snsr.VERSION}", f"  task: {task_path}"]
    summary.append(
        "  audio: default capture device"
        if live
        else f"  audio: {len(audio_paths)} file(s)"
    )
    summary += [f"  users: {', '.join(users)}", f"  output: {state.model_path}"]
    if state.enroll_path:
        summary.append(f"  enroll: {state.enroll_path}")
    for line in summary:
        print(line)
    print()

    with snsr.Session(str(task_path)) as session:
        # Insist on an enrollment task within the supported version range.
        session.require(snsr.TASK_TYPE, snsr.ENROLL)
        session.require(snsr.TASK_VERSION, ENROLL_TASK_VERSION)
        session.set_int(snsr.INTERACTIVE_MODE, 1)
        session.set_double(snsr.ACCURACY, accuracy)
        if prefix:
            # Presumably required for the pass/fail handlers to read captures.
            session.set_int(snsr.SAVE_ENROLLMENT_AUDIO, 1)
        install_handlers(session, state)

        state.audio = (
            snsr.Stream.from_audio_device()
            if live
            else build_audio_stream(audio_paths)
        )
        with state.audio:
            session.set_stream(snsr.SOURCE_AUDIO_PCM, state.audio)
            rc = session.run()

        if rc not in (snsr.RC.OK, snsr.RC.STOP):
            raise snsr.Error(rc, message=snsr.Session.rc_message(rc))


def main(argv: list[str] | None = None) -> int:
    """Parse arguments, validate the inputs, and run enrollment.

    Args:
        argv: Command-line arguments without the program name. ``None`` lets
            argparse read ``sys.argv``.

    Returns:
        Exit status:

        - 0 on success.
        - 2 when the task model or an enrollment WAV file is missing.
        - 1 when enrollment raises ``snsr.Error``.
    """
    args = parse_args(argv)
    # The SDK root only matters for building the other two paths.
    _, task_path, audio_paths = resolve_paths(args)
    users = args.user or [DEFAULT_USER]

    if not task_path.is_file():
        print(f"error: enrollment task not found: {task_path}", file=sys.stderr)
        print(
            "hint: pass --sdk-root or --task pointing at a TrulyNatural SDK install",
            file=sys.stderr,
        )
        return 2

    if not args.live:
        # Check every WAV up front so all missing files are reported at once.
        missing = [p for p in audio_paths if not p.is_file()]
        if missing:
            print("error: enrollment audio not found:", file=sys.stderr)
            for path in missing:
                print(f"  {path}", file=sys.stderr)
            print(
                "hint: pass WAV paths, use --live, or install SDK data/enrollments/",
                file=sys.stderr,
            )
            return 2

    try:
        run_enrollment(
            task_path,
            audio_paths,
            live=args.live,
            users=users,
            output=args.output,
            enroll=args.enroll,
            prefix=args.prefix,
            accuracy=args.accuracy,
            verbosity=args.verbose,
        )
    except snsr.Error as e:
        print(f"error: {e.message}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    # main() returns the process exit status; sys.exit raises SystemExit with it.
    sys.exit(main())