live_enroll.py
This example runs User-Defined Trigger (UDT) wake-word enrollment from Python. By default it feeds four pre-recorded enrollment WAV files from data/enrollments/. Pass --live to capture enrollment audio from the default microphone instead.
Instructions

1. Set up the sample project environment:

       cd ~/Sensory/TrulyNaturalSDK/7.8.0-pre.2/sample/python
       uv venv
       uv sync

2. Run file-based enrollment with the default armadillo-1 recordings:

       uv run src/live_enroll.py

   The sample writes enrolled-sv.snsr unless you pass --output.

3. To enroll from live microphone input instead:

       uv run src/live_enroll.py --live --user my-wake-word
Code
Available in this TrulyNatural SDK installation at ~/Sensory/TrulyNaturalSDK/7.8.0-pre.2/sample/python/src/live_enroll.py
live_enroll.py
"""Live wake-word enrollment for the TrulyNatural SDK Python binding.
Loads a User-Defined Trigger (UDT) enrollment task and walks
through interactive enrollment: assign users, capture enrollment
utterances, run quality checks, and write an enrolled spotter model.
This mirrors the ``live-enroll`` command-line tool and its C sample
(``live-enroll.c``).
By default the sample feeds pre-recorded enrollment WAV files from
``<sdk-root>/data/enrollments/`` (the same ``armadillo-1`` set used by
``live-enroll`` in CI). Pass ``--live`` to capture from the host's
default microphone instead.
Usage::
uv run src/live_enroll.py [--sdk-root PATH] [options]
uv run src/live_enroll.py --live --user my-wake-word
"""
from __future__ import annotations
import argparse
import sys
from dataclasses import dataclass
from pathlib import Path
import snsr
# File name of the enrollment task model, looked up under <sdk-root>/model/
# when no --task option is given.
TASK_MODEL = "udt-universal-3.67.1.0.snsr"
# Version requirement passed to Session.require() for snsr.TASK_VERSION.
ENROLL_TASK_VERSION = "~0.8.0 || 1.0.0"
# Default value of the --output option (enrolled spotter model path).
DEFAULT_OUTPUT = "enrolled-sv.snsr"
# User name enrolled when no --user option is given.
DEFAULT_USER = "armadillo-1"
# Default value of the --accuracy option.
DEFAULT_ACCURACY = 0.1
# Divisor used to convert the RES_SAMPLES count into seconds for the
# "Recording:" progress line.
SAMPLES_PER_SECOND = 16_000
# Same four utterances as devkit-live-enroll-1.6 in live-enroll.test.
# Expands to armadillo-1-0.wav through armadillo-1-3.wav.
DEFAULT_AUDIO = tuple(f"armadillo-1-{i}.wav" for i in range(4))
def default_sdk_root() -> Path:
return Path(__file__).resolve().parents[3]
def enrollment_audio_dir(sdk_root: Path) -> Path:
    """Return the ``data/enrollments`` directory beneath *sdk_root*."""
    return sdk_root.joinpath("data", "enrollments")
@dataclass
class EnrollState:
    """Mutable state shared by enrollment event handlers.

    One instance is created per enrollment run; the handler closures read
    and update it as session events arrive.
    """

    # Names of the users to enroll, handed to the session in list order.
    users: list[str]
    # Destination file for the enrolled spotter model.
    model_path: Path
    # Destination for the enrollment context; None disables saving it.
    enroll_path: Path | None
    # File-name prefix for saved enrollment captures; a falsy value
    # disables saving them.
    prefix: str | None
    # Verbosity level; handlers print more status output at higher values.
    verbosity: int
    # Audio source that the pause/resume handlers close and re-open.
    audio: snsr.Stream | None = None
    # Index into ``users`` of the next user to hand to the session.
    current_user: int = 0
    # Phrase shown in the "Say ..." prompt; replaced when the vocabulary
    # iteration reports a phrase text.
    phrase: str = "the enrollment phrase"
    # Count of failed captures saved so far; numbers the "fail" recordings.
    fail_count: int = 0
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the enrollment sample.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read the
            process arguments.

    Returns:
        Namespace with the attributes ``sdk_root``, ``task``, ``output``,
        ``enroll``, ``prefix``, ``user``, ``accuracy``, ``live``,
        ``verbose`` and ``audio``.
    """
    # The module docstring is None when Python runs with -OO (docstrings
    # stripped). Fall back to no description instead of raising
    # AttributeError before any argument is parsed.
    doc_lines = (__doc__ or "").strip().splitlines()
    description = doc_lines[0] if doc_lines else None

    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Without --live, enrollment audio defaults to four WAV files under\n"
            f" <sdk-root>/data/enrollments/{DEFAULT_AUDIO[0]} ... {DEFAULT_AUDIO[-1]}\n"
            "Pass additional paths after the options to use your own recordings."
        ),
    )
    parser.add_argument(
        "--sdk-root",
        type=Path,
        default=default_sdk_root(),
        help="TrulyNatural SDK install root (default: auto-detect)",
    )
    parser.add_argument(
        "--task",
        type=Path,
        help=f"enrollment task model (default: <sdk-root>/model/{TASK_MODEL})",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path(DEFAULT_OUTPUT),
        help=f"enrolled spotter output path (default: {DEFAULT_OUTPUT})",
    )
    parser.add_argument(
        "--enroll",
        "-e",
        type=Path,
        help="optional enrollment-context output path",
    )
    parser.add_argument(
        "--prefix",
        "-p",
        type=str,
        help="save each enrollment capture as <prefix>-<user>-{pass,fail}-<n>.wav",
    )
    parser.add_argument(
        "--user",
        action="append",
        default=[],
        metavar="NAME",
        # Derive the documented default from the constant, like the other
        # help strings, so the two cannot drift apart.
        help=f"user to enroll (repeat for multiple; default: {DEFAULT_USER})",
    )
    parser.add_argument(
        "--accuracy",
        type=float,
        default=DEFAULT_ACCURACY,
        help=f"enrollment accuracy setting (default: {DEFAULT_ACCURACY})",
    )
    parser.add_argument(
        "--live",
        action="store_true",
        help="capture from the default microphone instead of WAV files",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="increase status output (repeat up to three times)",
    )
    parser.add_argument(
        "audio",
        nargs="*",
        type=Path,
        help="enrollment WAV file(s); ignored when --live is set",
    )
    return parser.parse_args(argv)
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, list[Path]]:
    """Turn parsed options into resolved SDK, task-model and audio paths.

    Returns:
        ``(sdk_root, task_path, audio_paths)``. ``audio_paths`` is empty
        when ``args.live`` is set. Otherwise it holds the WAV files given on
        the command line or, when none were given, the default recordings
        in the SDK's enrollment audio directory.
    """
    sdk_root = args.sdk_root.resolve()
    task_source = args.task if args.task else sdk_root / "model" / TASK_MODEL
    task_path = task_source.resolve()

    if args.live:
        return sdk_root, task_path, []
    if args.audio:
        return sdk_root, task_path, [wav.resolve() for wav in args.audio]

    defaults_dir = enrollment_audio_dir(sdk_root)
    return sdk_root, task_path, [defaults_dir / wav_name for wav_name in DEFAULT_AUDIO]
def build_audio_stream(paths: list[Path]) -> snsr.Stream:
    """Chain the enrollment WAV files into a single stream, in list order.

    Starts from an empty string-backed stream and appends one audio-file
    stream per path, as ``live-enroll`` does in file mode.
    """
    combined = snsr.Stream.from_string("")
    for wav_path in paths:
        next_wav = snsr.Stream.from_audio_file(str(wav_path))
        combined = snsr.Stream.from_streams(combined, next_wav)
    return combined
def save_enrollment_audio(
    s: snsr.Session, state: EnrollState, tag: str, enroll_id: int
) -> None:
    """Write the session's enrollment recording to a WAV file.

    The file is named ``<prefix>-<user>-<tag>-<enroll_id>.wav``. Nothing is
    written when ``state.prefix`` is unset or when the retrieved audio
    stream's status is not OK.

    Raises:
        snsr.Error: if the copy ends with a status other than OK or EOF.
    """
    prefix = state.prefix
    if not prefix:
        return

    user_name = s.get_string(snsr.USER)
    if isinstance(user_name, bytes):
        user_name = user_name.decode()
    # The prefix is non-empty past the guard above, so it is always followed
    # by a "-" separator.
    wav_name = f"{prefix}-{user_name}-{tag}-{enroll_id}.wav"

    recording = s.get_stream(snsr.AUDIO_STREAM)
    if recording.rc != snsr.RC.OK:
        return

    acceptable = (snsr.RC.OK, snsr.RC.EOF)
    with snsr.Stream.from_audio_file(wav_name, "w") as sink:
        # Copy limit is the largest signed 64-bit value.
        sink.copy(recording, 2**63 - 1)
        if sink.rc not in acceptable:
            raise snsr.Error(sink.rc, message=sink.error_detail)
    if state.verbosity >= 1:
        print(f"Saved enrollment audio to {wav_name}")
def print_reason(s: snsr.Session, state: EnrollState) -> None:
    """Report on stderr why an enrollment recording was rejected.

    Returns without output when the session's ``RES_REASON_PASS`` value is
    non-zero. Otherwise prints the reason and the suggested fix; at
    verbosity 2 or higher it also prints the measured value and threshold.
    """
    if s.get_int(snsr.RES_REASON_PASS):
        return

    def as_text(raw):
        # The binding may hand back bytes; normalise to str for printing.
        return raw.decode() if isinstance(raw, bytes) else raw

    reason = as_text(s.get_string(snsr.RES_REASON))
    guidance = as_text(s.get_string(snsr.RES_GUIDANCE))

    err = sys.stderr
    print("This enrollment recording is not usable.", file=err)
    print(f" Reason: {reason}", file=err)
    if state.verbosity >= 2:
        measured = s.get_double(snsr.RES_REASON_VALUE)
        limit = s.get_double(snsr.RES_REASON_THRESHOLD)
        print(f" ({measured:.2f}, threshold is {limit:.2f})", file=err)
    print(f" Fix: {guidance}", file=err)
def install_handlers(s: snsr.Session, state: EnrollState) -> None:
    """Register the enrollment event handlers on session *s*.

    Every handler is a closure over *state*. The enrolled-context handler is
    only registered when ``state.enroll_path`` is set. After registration
    the vocabulary list is iterated once to pick up the prompt phrase.
    """

    def next_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Give the session the next user name from ``state.users``."""
        # Once every user has been handed out, return without setting one.
        if state.current_user >= len(state.users):
            return snsr.RC.OK
        user = state.users[state.current_user]
        state.current_user += 1
        sess.set_string(snsr.USER, user)
        return snsr.RC.OK

    def pass_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Report a passed capture and optionally save its audio."""
        if state.verbosity >= 1:
            print("Preliminary enrollment checks passed.")
        if state.prefix:
            # Passed captures are numbered with the session's enrollment id.
            enroll_id = sess.get_int(snsr.RES_ENROLLMENT_ID)
            save_enrollment_audio(sess, state, "pass", enroll_id)
        return snsr.RC.OK

    def fail_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Report a rejected capture and optionally save its audio."""
        print_reason(sess, state)
        if state.prefix:
            # Failed captures are numbered with a local running counter.
            save_enrollment_audio(sess, state, "fail", state.fail_count)
            state.fail_count += 1
        return snsr.RC.OK

    def pause_event(_sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Close the audio source and end the current status line."""
        if state.audio is not None:
            state.audio.close()
        print()
        return snsr.RC.OK

    def resume_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Re-open the audio source and prompt for the next utterance."""
        if state.audio is not None:
            state.audio.open()
        count = sess.get_int(snsr.RES_ENROLLMENT_COUNT)
        target = sess.get_int(snsr.ENROLLMENT_TARGET)
        user = sess.get_string(snsr.USER)
        if isinstance(user, bytes):
            user = user.decode()
        ctx = sess.get_int(snsr.ADD_CONTEXT)
        # count + 1 turns the enrollment count into a 1-based prompt number.
        print(f'\nSay {state.phrase} ({count + 1}/{target}) for "{user}"', end="")
        if ctx:
            # With ADD_CONTEXT set, show an example of the phrase followed
            # by additional speech.
            print(
                ',\n for example: "<phrase> will it rain tomorrow?"',
                end="",
            )
        print()
        return snsr.RC.OK

    def samples_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Show the recorded duration on a single, rewritten line."""
        seconds = sess.get_double(snsr.RES_SAMPLES) / SAMPLES_PER_SECOND
        # The trailing "\r" with end="" keeps the cursor on the same line so
        # the next update overwrites this one.
        print(f"Recording: {seconds:6.2f} s\r", end="", flush=True)
        return snsr.RC.OK

    def prog_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Show adaptation progress when verbosity is at least 1."""
        if state.verbosity >= 1:
            progress = sess.get_double(snsr.RES_PERCENT_DONE)
            print(f"\rAdapting: {progress:3.0f}% complete.", end="", flush=True)
            # Terminate the rewritten progress line once it reaches 100%.
            if progress >= 100:
                print()
        return snsr.RC.OK

    def done_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Write the enrolled model to ``state.model_path`` and return STOP."""
        model = sess.get_stream(snsr.MODEL_STREAM)
        # Copy exactly as many bytes as the stream metadata reports written.
        written = model.get_meta(snsr.StreamMeta.BYTES_WRITTEN)
        state.model_path.parent.mkdir(parents=True, exist_ok=True)
        with snsr.Stream.from_filename(str(state.model_path), "w") as out:
            out.copy(model, written)
            if out.rc != snsr.RC.OK:
                raise snsr.Error(out.rc, message=out.error_detail)
        if state.verbosity >= 1:
            print(f'Enrolled model saved to "{state.model_path}"')
        # NOTE(review): presumably STOP ends Session.run(); run_enrollment
        # treats a STOP result as success - confirm against the SDK docs.
        return snsr.RC.STOP

    def enrolled_event(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Save the enrollment context when an output path was requested."""
        if state.enroll_path is None:
            return snsr.RC.OK
        state.enroll_path.parent.mkdir(parents=True, exist_ok=True)
        sess.save(snsr.DataFormat.RUNTIME, str(state.enroll_path))
        if state.verbosity >= 1:
            print(f'Enrollment context saved to "{state.enroll_path}"')
        return snsr.RC.OK

    def capture_phrase(sess: snsr.Session, _key: str) -> snsr.RC | None:
        """Store the reported phrase text in ``state.phrase``."""
        # Each call overwrites the previous value, so the last phrase wins.
        vocab = sess.get_string(snsr.RES_TEXT)
        if isinstance(vocab, bytes):
            vocab = vocab.decode()
        state.phrase = vocab
        return snsr.RC.OK

    s.set_handler(snsr.NEXT_EVENT, next_event)
    s.set_handler(snsr.DONE_EVENT, done_event)
    s.set_handler(snsr.FAIL_EVENT, fail_event)
    s.set_handler(snsr.PASS_EVENT, pass_event)
    s.set_handler(snsr.PROG_EVENT, prog_event)
    s.set_handler(snsr.PAUSE_EVENT, pause_event)
    s.set_handler(snsr.RESUME_EVENT, resume_event)
    s.set_handler(snsr.SAMPLES_EVENT, samples_event)
    if state.enroll_path is not None:
        s.set_handler(snsr.ENROLLED_EVENT, enrolled_event)
    # Best effort: if iterating the vocabulary list raises, keep the default
    # prompt phrase already stored in state.
    try:
        s.for_each(snsr.VOCAB_LIST, capture_phrase)
    except snsr.Error:
        pass
def run_enrollment(
    task_path: Path,
    audio_paths: list[Path],
    *,
    live: bool,
    users: list[str],
    output: Path,
    enroll: Path | None,
    prefix: str | None,
    accuracy: float,
    verbosity: int,
) -> None:
    """Run one enrollment session from start to finish.

    Prints a summary banner, loads the task model, configures the session,
    installs the event handlers and runs the session against either the
    default capture device (``live``) or the chained WAV files.

    Raises:
        snsr.Error: if the session run returns a code other than OK or STOP.
    """
    model_target = output.resolve()
    context_target = enroll.resolve() if enroll else None
    state = EnrollState(
        users=users,
        model_path=model_target,
        enroll_path=context_target,
        prefix=prefix,
        verbosity=verbosity,
    )

    audio_line = (
        " audio: default capture device"
        if live
        else f" audio: {len(audio_paths)} file(s)"
    )
    print(f"snsr {snsr.VERSION}")
    print(f" task: {task_path}")
    print(audio_line)
    print(f" users: {', '.join(users)}")
    print(f" output: {state.model_path}")
    if state.enroll_path:
        print(f" enroll: {state.enroll_path}")
    print()

    success_codes = (snsr.RC.OK, snsr.RC.STOP)
    with snsr.Session(str(task_path)) as session:
        # Reject models that are not enrollment tasks of a supported version.
        session.require(snsr.TASK_TYPE, snsr.ENROLL)
        session.require(snsr.TASK_VERSION, ENROLL_TASK_VERSION)
        session.set_int(snsr.INTERACTIVE_MODE, 1)
        session.set_double(snsr.ACCURACY, accuracy)
        if prefix:
            session.set_int(snsr.SAVE_ENROLLMENT_AUDIO, 1)
        install_handlers(session, state)

        # The handlers reach the audio source through state.audio.
        if live:
            source = snsr.Stream.from_audio_device()
        else:
            source = build_audio_stream(audio_paths)
        state.audio = source

        with state.audio:
            session.set_stream(snsr.SOURCE_AUDIO_PCM, state.audio)
            rc = session.run()
            if rc not in success_codes:
                raise snsr.Error(rc, message=snsr.Session.rc_message(rc))
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point for the enrollment sample.

    Args:
        argv: Argument list forwarded to ``parse_args``.

    Returns:
        0 on success, 1 when enrollment raises ``snsr.Error``, and 2 when
        the task model or any enrollment audio file is missing.
    """
    args = parse_args(argv)
    _sdk_root, task_path, audio_paths = resolve_paths(args)
    users = args.user or [DEFAULT_USER]
    err = sys.stderr

    if not task_path.is_file():
        print(f"error: enrollment task not found: {task_path}", file=err)
        print(
            "hint: pass --sdk-root or --task pointing at a TrulyNatural SDK install",
            file=err,
        )
        return 2

    # Audio files only matter in file mode; live capture skips this check.
    missing = [] if args.live else [wav for wav in audio_paths if not wav.is_file()]
    if missing:
        print("error: enrollment audio not found:", file=err)
        for wav in missing:
            print(f" {wav}", file=err)
        print(
            "hint: pass WAV paths, use --live, or install SDK data/enrollments/",
            file=err,
        )
        return 2

    try:
        run_enrollment(
            task_path,
            audio_paths,
            live=args.live,
            users=users,
            output=args.output,
            enroll=args.enroll,
            prefix=args.prefix,
            accuracy=args.accuracy,
            verbosity=args.verbose,
        )
    except snsr.Error as exc:
        print(f"error: {exc.message}", file=err)
        return 1
    else:
        return 0
# Run the sample only when executed as a script; the return value of main()
# becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())