Streaming ww dev #168

Merged
merged 183 commits from streaming_ww_dev into master
Jan 20, 2025
Changes from 1 commit

Commits
b16b8d9
moved training_torch to experimental and added a README
jeremy-syn Oct 9, 2023
31dd4be
starting to move code into here
jeremy-syn Jan 22, 2024
0626c97
some updates to streaming wakeword
jeremy-syn Jan 27, 2024
03c9d8b
updated streaming wakeword model to be the actual candidate DS-TCN mo…
jeremy-syn Jan 27, 2024
725cb69
set default features to 40-D LFBEs
jeremy-syn Jan 27, 2024
8e533a3
changed num_classes to 3 in train.py
jeremy-syn Jan 29, 2024
9c15b56
demo notebook (in progress) added
jeremy-syn Jan 29, 2024
5475c7b
demo notebook runs through small training run
jeremy-syn Jan 29, 2024
c7138c4
added count_labels and is_batched()
jeremy-syn Jan 29, 2024
b089ce0
demo now adds some silence waveforms (which then have noise added) to…
jeremy-syn Jan 29, 2024
6c8f06d
updated get_dataset (mostly copied from demo.ipynb) and removed use_s…
jeremy-syn Jan 29, 2024
e7922f7
fixed default model architecture flag
jeremy-syn Jan 29, 2024
5a9a235
fixed some issues with building model
jeremy-syn Jan 29, 2024
7193979
added from_logits argument to model.compile
jeremy-syn Jan 31, 2024
5e74320
cleanup changes to get_dataset and demo notebook
jeremy-syn Feb 2, 2024
4e02106
catching up on edits
jeremy-syn Feb 2, 2024
5d9932c
keras_model does not need tf datasets module
jeremy-syn Feb 2, 2024
4edc86e
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Feb 2, 2024
b72d850
cleaning up demo notebook
jeremy-syn Feb 5, 2024
fc7a4d8
cleaning up demo notebook
jeremy-syn Feb 5, 2024
e9de194
made path to speech commands dataset easier to config per location/us…
jeremy-syn Feb 5, 2024
73e3d63
beginning of code to test long waveform in python
jeremy-syn Feb 5, 2024
2d69501
some updates
jeremy-syn Feb 7, 2024
6916250
added option to read in model config file
jeremy-syn Feb 7, 2024
e91a3ed
set validation set to incorporate background noise. also fixed issue…
jeremy-syn Feb 7, 2024
eb4d120
moved code to add silent (or white noise) frames to dataset into its …
jeremy-syn Feb 7, 2024
29da207
fixed argument error
jeremy-syn Feb 7, 2024
c6f23f1
added post-training quantization
jeremy-syn Feb 10, 2024
f478259
several changes in order to use QAT and evaluate on long waveforms:
jeremy-syn Feb 11, 2024
b46d797
notebook updated to work with last commits on get_data, keras_model
jeremy-syn Feb 11, 2024
c6185a7
changed default LR schedule to reduce_on_plateau so it scales better …
jeremy-syn Feb 11, 2024
9dd2f1b
some more edits to get QAT working
jeremy-syn Feb 11, 2024
19470b0
changed labels to one-hot to work with precision/recall metrics. Also…
jeremy-syn Feb 13, 2024
772e6cc
added notebook to develop tflite model for feature extraction
jeremy-syn Feb 13, 2024
a6ac856
removed some prints from get_dataset. added an evaluation to train
jeremy-syn Feb 13, 2024
7302a8d
adjusted reduce lr on plateau settings
jeremy-syn Feb 13, 2024
20abafc
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Feb 13, 2024
8ef587c
fixed plotting error
jeremy-syn Feb 13, 2024
deec7c9
working on different options to run the feature extractor on MCU
jeremy-syn Feb 14, 2024
468a344
small changes to notebook
jeremy-syn Feb 17, 2024
2964369
removed old commented-out code that loaded pre-built dataset
jeremy-syn Feb 17, 2024
ee12fe4
tflite_feature_extractor.ipynb very much a work in progress
jeremy-syn Feb 17, 2024
ac40098
added setup instructions and a to the streaming wakeword benchmark (…
AlexMontgomerie Mar 6, 2024
380f117
Merge branch 'master' of gh-syn:mlcommons/tiny
jeremy-syn Apr 15, 2024
9798093
cache datasets after spectrogram computation to avoid recomputing the…
jeremy-syn May 21, 2024
c584f23
fixed data_dir default to point to speech_commands_v0.02
jeremy-syn May 21, 2024
12d727f
fixed data_dir default to point to speech_commands_v0.02
jeremy-syn May 21, 2024
0e122cf
added BooleanOptionalAction to correctly parse boolean Flags
jeremy-syn May 21, 2024
7bfefc4
fixed parsing of bool args (use_qat, run_test_set) to work with pytho…
jeremy-syn May 21, 2024
49330a9
changed so parse_command raises exception on unrecognized flags
jeremy-syn May 26, 2024
6dc9314
changed so parse_command raises exception on unrecognized flags
jeremy-syn May 26, 2024
252cafc
added foreground scaling args foreground_volume_min, _max to train on…
jeremy-syn May 28, 2024
2d80a04
set is_training true for ds_val so it gets noise added
jeremy-syn Jun 3, 2024
7270a9a
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jun 3, 2024
5cf7657
edits to str ww model
jeremy-syn Jun 3, 2024
eca617c
edits to data set building
jeremy-syn Jun 3, 2024
0e7ac0f
saved training history along with plot
jeremy-syn Jun 3, 2024
8bedf74
Merge branch 'streaming_ww_dev' of gh-syn:mlcommons/tiny into streami…
jeremy-syn Jun 3, 2024
7ac2852
removed average pooling, increased initial feature stride
jeremy-syn Jun 3, 2024
01470d7
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jun 3, 2024
3b0fb65
Fixed bug where np.random is only evaluated at graph creation, so all…
jeremy-syn Jun 4, 2024
3f3dabf
fixed several places where np.random was used in a tf graph, resultin…
jeremy-syn Jun 5, 2024
4fd29c1
widened filters in 2nd,3rd layers to 128
jeremy-syn Jun 5, 2024
6f8a272
changed back from 32 LFBEs to 40
jeremy-syn Jun 5, 2024
ac85b48
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jun 5, 2024
d7c607c
minor cleanup -- whitespace, removing old commented out lines, etc.
jeremy-syn Jun 6, 2024
870ac09
fixed error - val set was using target words from training set
jeremy-syn Jun 6, 2024
8e9a0ee
minor cleanup -- whitespace, removing old commented out lines, etc.
jeremy-syn Jun 6, 2024
264eae7
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jun 6, 2024
6a3b15e
changed ordering in data prep, now shuffle before batching
jeremy-syn Jun 6, 2024
76413de
adding current version of trained and quantized streaming ww model
jeremy-syn Jun 6, 2024
dec0abb
minor edits/cleanup
jeremy-syn Jun 7, 2024
b11d123
changed Flags.num_train_samples to num_samples_training. same for tes…
jeremy-syn Jun 7, 2024
b1b1204
added 1st pass at get_data_config(), refactoring dataset build
jeremy-syn Jun 8, 2024
d313d65
refactored dataset building. train.py runs now, have not tested perfo…
jeremy-syn Jun 9, 2024
c4c654b
setup_example is work in progress, just capturing progress
jeremy-syn Jun 10, 2024
a6028cf
train.py runs but gives random-level validation accuracy. demo noteb…
jeremy-syn Jun 10, 2024
d6e23bb
flag parsing used 'train' instead of 'training' and therefore was not…
jeremy-syn Jun 14, 2024
de7a198
updated demo to match changes in data
jeremy-syn Jun 15, 2024
af713d8
minor updates
jeremy-syn Jun 22, 2024
b50d8fc
dumps options as json into plot_dir
jeremy-syn Jun 22, 2024
7168f50
fixed demo to work with new get_data code. moved take after shuffle …
jeremy-syn Jun 22, 2024
859b3b3
moved softmax to inside the model; adjusted loss function accordingly
jeremy-syn Jun 24, 2024
36cd5b8
moved softmax calculation into the model
jeremy-syn Jun 29, 2024
ea3eb2e
working on true pos/false pos computation
jeremy-syn Jun 29, 2024
c7d2f1c
resolved merge
jeremy-syn Jun 29, 2024
35d4394
fixed error, post-wakeword extension was being added twice
jeremy-syn Jun 30, 2024
9184f66
fixing notebook counting of true/false positives
jeremy-syn Jul 1, 2024
8e671d0
removed commented-out code; added zero2nan()
jeremy-syn Jul 1, 2024
a1b5673
added multiple background noise paths, can split long bg files into s…
jeremy-syn Jul 1, 2024
37ca39d
change QAT initial LR to Flags.lr, LR is too small after float pre-tr…
jeremy-syn Jul 1, 2024
7f14c47
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jul 1, 2024
811db67
fixed cmd line arg processing to accomodate multile bg noise paths
jeremy-syn Jul 1, 2024
53bb590
removed commented out code from demo notebook
jeremy-syn Jul 2, 2024
e1d75f9
convert only-target dataset to numpy array and back so cardinality() …
jeremy-syn Jul 7, 2024
3eb0329
refactored num_silent, num_repeats in to fraction_silent and fraction…
jeremy-syn Jul 9, 2024
ec0d024
fixed cmd line arg processing to accomodate multile bg noise paths
jeremy-syn Jul 9, 2024
5ecfeec
fixing code for smaller datasets
jeremy-syn Jul 16, 2024
6a2d715
catching up on demo edits
jeremy-syn Jul 16, 2024
a2a5f8b
added code to run quantized model on long waveform
jeremy-syn Jul 16, 2024
45d0b72
working on long wav file creation; added poisson process to place wak…
jeremy-syn Jul 16, 2024
609f901
updated long wave creation, need to move it to a separate file soon. …
jeremy-syn Jul 16, 2024
d059f4a
increased number of background files from 50 to 100
jeremy-syn Jul 16, 2024
9d11002
added code to illustrate false detects/rejects
jeremy-syn Jul 18, 2024
95dda3e
updating background noise creation to avoid train/val duplicates
jeremy-syn Jul 18, 2024
3473dbb
added exclude_background_files.txt
jeremy-syn Jul 18, 2024
5114246
put code to build the long test wav into its own (two) files
jeremy-syn Jul 19, 2024
ccec9f4
added eval_long_wav.py to test fpr, fnr on a long wav
jeremy-syn Jul 20, 2024
6910b1c
made build_long_wav work with the musan_path from streaming_config.json
jeremy-syn Jul 20, 2024
04b0608
made build_long_wav work with the musan_path from streaming_config.json
jeremy-syn Jul 20, 2024
c599598
fixed a typo
jeremy-syn Jul 20, 2024
0dced76
fixed issue with path construction in long wav spec
jeremy-syn Jul 20, 2024
223f315
added l2 reg to conv layers
jeremy-syn Jul 21, 2024
76de629
added L2 reg to conv layers
jeremy-syn Jul 21, 2024
2c77d4e
removed some old commented-out code
jeremy-syn Jul 21, 2024
1d1a112
eval_long_wav can now test either h5 models or tflite models
jeremy-syn Jul 21, 2024
c6332f2
Merge branch 'streaming_ww_dev' of gh-syn:mlcommons/tiny into streami…
jeremy-syn Jul 21, 2024
7724e98
added script to create indices into the val set for calibration
jeremy-syn Jul 21, 2024
05cf6d3
code to create calibration set is working
jeremy-syn Jul 22, 2024
7cbcd9f
fixed quantize.py to work with extracted npz calibration set
jeremy-syn Jul 22, 2024
df1d3e8
adjusted volume of foreground and background for testing
jeremy-syn Jul 22, 2024
d341993
added code to save spectrogram in build_long_wav.py
jeremy-syn Jul 22, 2024
e30cbd3
demo notebook should now work with current code
jeremy-syn Jul 23, 2024
0fba129
separated augmentation (built by get_augment_wavs_func()) and feature…
jeremy-syn Jul 25, 2024
40cbecf
made l2 reg parameter a commmand line flag
jeremy-syn Jul 25, 2024
6c6c330
fixed eval_long_wav to work with feature extractor changes
jeremy-syn Jul 25, 2024
f480533
added validation set measurments to eval_long_wav.py
jeremy-syn Jul 25, 2024
8cfc991
moved eval_long_wav to evaluate.py
jeremy-syn Jul 25, 2024
7801dd2
added threshold=0.95 to precision/recall metrics to match evaluate.py
jeremy-syn Jul 25, 2024
23e4434
added a list of 'bad' marvin wav files. modified build_long_wav_spec …
jeremy-syn Jul 26, 2024
2672489
edited comment on saved_model_path to reflect evaluate.py
jeremy-syn Jul 26, 2024
d0f30f8
added bad_marvin_files.txt
jeremy-syn Jul 26, 2024
6d8cb29
fixed error in number of unknown samples for reduced runs
jeremy-syn Jul 27, 2024
ba80cc1
renamed build_long_wav_spec.py -> build_long_wav_def.py to avoid ambi…
jeremy-syn Jul 27, 2024
d626e61
renamed features back to audio to allow easy skipping of feature extr…
jeremy-syn Jul 28, 2024
8eb1f6c
minor edits
jeremy-syn Jul 29, 2024
5960e11
removed debug print statement
jeremy-syn Jul 29, 2024
b0c424a
fixed code for tflite models
jeremy-syn Jul 29, 2024
0ba2b18
adjusted some default training params
jeremy-syn Jul 29, 2024
547569b
catching notebook up to other code
jeremy-syn Jul 29, 2024
6341ee8
clearing out some debug prints
jeremy-syn Jul 29, 2024
5a16bb9
added trained model
jeremy-syn Jul 29, 2024
fad10d4
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Jul 29, 2024
0c0c89f
moved label_count out of model_settings into a flag
jeremy-syn Jul 29, 2024
a348974
Merge branch 'streaming_ww_dev' of gh-syn:mlcommons/tiny into streami…
jeremy-syn Jul 29, 2024
a707678
minor edits
jeremy-syn Jul 29, 2024
6836643
added random timeshift
jeremy-syn Aug 4, 2024
143dd8e
added a couple more bad marvins to exclude
jeremy-syn Aug 4, 2024
2461453
added random timeshifting to augmentation function
jeremy-syn Aug 4, 2024
9c89eb7
added flag to enforce a minimum SNR
jeremy-syn Aug 5, 2024
c91b792
centralized data paths in streaming_config.json (no command line argu…
jeremy-syn Aug 5, 2024
c02b502
removed some obsolete cmd line args and modified get_dataset to respe…
jeremy-syn Aug 5, 2024
0a88769
fixed evaulate.py to work with changes on speech_commands_path
jeremy-syn Aug 5, 2024
f0cd836
changed evaluate and quantize to use model_init_path, so by default t…
jeremy-syn Aug 5, 2024
61153a4
adjusted trainign params
jeremy-syn Aug 5, 2024
42d51a0
adjusted trainign params
jeremy-syn Aug 5, 2024
9e9e2ad
updated long wav info
jeremy-syn Aug 5, 2024
2258d55
updated README
jeremy-syn Aug 5, 2024
15a2d89
updated reference model
jeremy-syn Aug 5, 2024
b49eb7d
removed some info messages
jeremy-syn Aug 6, 2024
9c655d2
add line to create plot_dir if it does not exist
jeremy-syn Aug 9, 2024
b9cd142
reduced noise level in long wav
jeremy-syn Aug 9, 2024
54de3e9
refactored command line argument parsing
jeremy-syn Aug 11, 2024
c354ac6
refactored command line argument parsing
jeremy-syn Aug 12, 2024
7e69344
Merge branch 'streaming_ww_dev' of github.com:mlcommons/tiny into str…
jeremy-syn Aug 12, 2024
478f6ae
fixed some errors in README
jeremy-syn Aug 12, 2024
7d97168
fixed quantize to use saved_model_path instead of model_init_path
jeremy-syn Aug 12, 2024
bd96237
added calibration_samples.npz
jeremy-syn Aug 12, 2024
2bc1175
fixing argument processing for evaluate.py to work with either tflite…
jeremy-syn Aug 12, 2024
8e7b718
fixed typo in evaluate.py
jeremy-syn Aug 12, 2024
8c63415
fixed typo in evaluate
jeremy-syn Aug 12, 2024
7608316
fixing merge
jeremy-syn Aug 12, 2024
30308c9
updated tflite model
jeremy-syn Aug 12, 2024
1fea3a0
fixed issue with plot_dir
jeremy-syn Aug 12, 2024
0c73cf8
Merge branch 'streaming_ww_dev' of gh-syn:mlcommons/tiny into streami…
jeremy-syn Aug 12, 2024
fb4e6a2
ignoring trained models other than reference model
jeremy-syn Aug 12, 2024
68f53b8
updated readme
jeremy-syn Aug 12, 2024
02953c9
updated demo notebook
jeremy-syn Aug 12, 2024
c280e5d
added note about the demo notebook to readme
jeremy-syn Aug 12, 2024
900e126
merged gitignore from master branch
jeremy-syn Aug 12, 2024
f9ec337
Merge branch 'master' into streaming_ww_dev
Peter-Chang Sep 9, 2024
5cfabb6
merged work from runner_dev_jeremy
jeremy-syn Jan 19, 2025
50104d0
Merge branch 'master' into streaming_ww_dev
jeremy-syn Jan 19, 2025
updated demo notebook
jeremy-syn committed Aug 12, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 02953c91573892abc45e2068a4475cfd805cd398
189 changes: 56 additions & 133 deletions benchmark/training/streaming_wakeword/demo.ipynb
@@ -65,7 +65,7 @@
"# unrecognized argument error\n",
"sys.argv = sys.argv[0:1] \n",
"\n",
"Flags = util.parse_command()"
"Flags = util.parse_command(\"train\")"
]
},
{
@@ -79,6 +79,7 @@
"\n",
"if notebook_mode == \"inference\": \n",
" load_pretrained_model = True\n",
" Flags.num_samples_training = 2000 # we don't need the full set for inference\n",
" save_model = False\n",
"elif notebook_mode == \"short_training\":\n",
" ## Set these for an extra short test just to validate that the code runs\n",
@@ -97,8 +98,7 @@
" pass\n",
"\n",
"# 'trained_models/str_ww_model.h5' is the default save path for train.py\n",
"# pretrained_model_path = 'trained_models/str_ww_ref_model.h5' # path to load from if load_pretrained_model is True\n",
"pretrained_model_path = 'trained_models/str_ww_model.h5' # path to load from if load_pretrained_model is True\n",
"pretrained_model_path = 'trained_models/str_ww_ref_model.h5' # path to load from if load_pretrained_model is True\n",
"\n",
"samp_freq = Flags.sample_rate"
]
@@ -184,26 +184,25 @@
"metadata": {},
"outputs": [],
"source": [
"# max_target_examples = 3\n",
"# target_count = 0\n",
"max_target_examples = 3\n",
"target_count = 0\n",
"\n",
"# plt.Figure(figsize=(10,4))\n",
"# for dat in ds_train.unbatch():\n",
"# # label_string = dat[1].numpy().decode('utf8')\n",
"# if np.argmax(dat[1]) == 0:\n",
"# target_count += 1\n",
"# ax = plt.subplot(max_target_examples, 1, target_count)\n",
"# # display.display(display.Audio(dat[0].numpy(), rate=16000))\n",
"plt.Figure(figsize=(10,4))\n",
"for dat in ds_train.unbatch():\n",
" if np.argmax(dat[1]) == 0:\n",
" target_count += 1\n",
" ax = plt.subplot(max_target_examples, 1, target_count)\n",
" # display.display(display.Audio(dat[0].numpy(), rate=16000))\n",
"\n",
"# log_spec = dat[0].numpy().squeeze()\n",
"# height = log_spec.shape[0]\n",
"# width = log_spec.shape[1]\n",
"# X = np.linspace(0, 1.0, num=width, dtype=float)\n",
"# Y = range(height)\n",
"# ax.pcolormesh(X, Y, np.squeeze(log_spec))\n",
"# if target_count >= max_target_examples:\n",
"# break\n",
"# plt.tight_layout()"
" log_spec = dat[0].numpy().squeeze()\n",
" height = log_spec.shape[0]\n",
" width = log_spec.shape[1]\n",
" X = np.linspace(0, 1.0, num=width, dtype=float)\n",
" Y = range(height)\n",
" ax.pcolormesh(X, Y, np.squeeze(log_spec))\n",
" if target_count >= max_target_examples:\n",
" break\n",
"plt.tight_layout()"
]
},
{
@@ -214,7 +213,7 @@
"outputs": [],
"source": [
"## look at the label breakdown in the training set\n",
"# print(get_dataset.count_labels(ds_train))\n",
"print(get_dataset.count_labels(ds_train))\n",
"\n"
]
},
@@ -347,8 +346,6 @@
"metadata": {},
"outputs": [],
"source": [
"label_list = ['marvin', 'silent', 'other']\n",
"\n",
"build_and_plot_confusion_matrix(model, ds_train)"
]
},
@@ -363,38 +360,11 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c0a3d5e5-27fc-4059-be36-6132fe073155",
"id": "45495640-7e98-474e-910f-6070f157ad60",
"metadata": {},
"outputs": [],
"source": [
"num_calibration_steps = 5\n",
"tfl_file_name = \"strm_ww_int8.tflite\"\n",
"\n",
"# converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"if True: \n",
" # If we omit this block, we'll get a floating-point TFLite model,\n",
" # with this block, the weights and activations should be quantized to 8b integers, \n",
" converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"\n",
" ds_calibration = ds_val.unbatch().batch(1).take(num_calibration_steps)\n",
" def representative_dataset_gen():\n",
" for next_spec, label in ds_calibration:\n",
" yield [next_spec] \n",
" \n",
" converter.representative_dataset = representative_dataset_gen\n",
" converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] # use this one\n",
" # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]\n",
"\n",
" converter.inference_input_type = tf.int8 # or tf.uint8; should match dat_q in eval_quantized_model.py\n",
" converter.inference_output_type = tf.int8 # or tf.uint8\n",
"\n",
"tflite_quant_model = converter.convert()\n",
"\n",
"with open(tfl_file_name, \"wb\") as fpo:\n",
" fpo.write(tflite_quant_model)\n",
"print(f\"Wrote to {tfl_file_name}\")\n",
"!ls -l $tfl_file_name"
"!python quantize.py --saved_model_path=trained_models/str_ww_ref_model.h5\n"
]
},
{
@@ -433,7 +403,6 @@
"spec, label = next(ds_val.unbatch().batch(1).take(1).as_numpy_iterator())\n",
"\n",
"spec_q = np.array(spec/input_scale + input_zero_point, dtype=np.int8)\n",
"print(f\"min = {np.min(spec_q)}, max = {np.max(spec_q)}\")\n",
"\n",
"interpreter.set_tensor(input_details[0]['index'], spec_q)\n",
"interpreter.invoke()\n",
@@ -491,7 +460,7 @@
"id": "7100ecee-9ed6-42b4-984f-8b75ad642d11",
"metadata": {},
"source": [
"As of 10 Feb 2024, the quantized accuracy on the training set is 83% and 83% on the validation set."
"As of 12 Aug 2024, the quantized accuracy on the validation set is 95.5%. Now we can plot the confusion matrix of the quantized model."
]
},
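The diff above feeds the quantized interpreter an int8 input computed as `spec/input_scale + input_zero_point`. That affine int8 mapping can be checked in isolation; the sketch below is a minimal, hedged illustration (the `scale`/`zero_point` values are made up for the example, and the explicit rounding and clipping are assumptions — the notebook relies on the cast instead; real parameters come from the interpreter's input details).

```python
import numpy as np

def quantize_int8(x, scale, zero_point):
    """Affine int8 quantization: q = round(x / scale) + zero_point,
    clipped to the int8 range [-128, 127]."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_int8(q, scale, zero_point):
    """Inverse mapping back to float: x ≈ (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale

# Illustrative parameters only; in the notebook they would come from
# interpreter.get_input_details()[0]['quantization'].
scale, zero_point = 0.5, -10
x = np.array([-3.0, 0.0, 2.5, 40.0], dtype=np.float32)
q = quantize_int8(x, scale, zero_point)
x_hat = dequantize_int8(q, scale, zero_point)
print(q)      # [-16 -10  -5  70]
print(x_hat)  # [-3.   0.   2.5 40. ]
```

Values that are exact multiples of `scale` round-trip exactly; anything else incurs at most half a quantization step of error.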
{
@@ -501,6 +470,7 @@
"metadata": {},
"outputs": [],
"source": [
"label_list = ['marvin', 'silent', 'other']\n",
"confusion_mtx = tf.math.confusion_matrix(labels, predictions)\n",
"plt.figure(figsize=(6, 6))\n",
"sns.heatmap(confusion_mtx, xticklabels=label_list, yticklabels=label_list, \n",
@@ -516,7 +486,9 @@
"id": "6b2750b5-4aa6-4613-bbc9-ddca78fe6cdd",
"metadata": {},
"source": [
"## Run Model on Long Waveform"
"## Run Model on Long Waveform\n",
"\n",
"The use case this benchmark is meant to model is one of detecting a \"wakeword\" (similar to \"Hey Siri\", \"Alexa\", or \"OK Google\") in a continuous stream of sound, including background noise. So to mimic that use case, we will run the model on a longer waveform that includes several instances of the wakeword (\"Marvin\") and some background noise."
]
},
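The detection logic the notebook applies to the long waveform (counting a detection when the wakeword softmax output exceeds `det_thresh = 0.95`) can be sketched on its own with numpy. This is a minimal illustration, not the benchmark's scoring code: the per-frame probabilities are synthetic, and counting rising edges rather than individual frames is an assumption.

```python
import numpy as np

def count_detections(probs, ww_index=0, det_thresh=0.95):
    """Count detection events in per-frame class probabilities.
    An event starts at each frame where the wakeword probability
    crosses from below to above det_thresh (rising edge)."""
    above = probs[:, ww_index] > det_thresh
    # Rising edge: above now, but not in the previous frame.
    rising = above & ~np.concatenate(([False], above[:-1]))
    return int(rising.sum())

# Synthetic softmax outputs for 3 classes: [marvin, silent, other]
probs = np.array([
    [0.10, 0.10, 0.80],
    [0.97, 0.01, 0.02],  # first detection starts here
    [0.98, 0.01, 0.01],  # still above threshold: same event
    [0.20, 0.10, 0.70],
    [0.96, 0.02, 0.02],  # second detection
])
print(count_detections(probs))  # 2
```

Edge counting avoids reporting one detection per frame while the wakeword is being spoken, which would inflate the false-positive count on multi-frame activations.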
{
@@ -552,14 +524,11 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "857bba03-be91-4f66-b161-539a15b911e9",
"cell_type": "markdown",
"id": "e0e9c14e-48b6-4c1c-97c4-ee24016f64d9",
"metadata": {},
"outputs": [],
"source": [
"ww = model.get_weights()\n",
"ww_tv = model_tv.get_weights()\n"
"For the keras model, we can build an alternate version of the model that accepts inputs of arbitrary length."
]
},
{
@@ -569,8 +538,9 @@
"metadata": {},
"outputs": [],
"source": [
"pretrained_model_uses_qat = hasattr(model.layers[1], \"quantizer\")\n",
"Flags.variable_length=True\n",
"model_tv = models.get_model(args=Flags, use_qat=Flags.use_qat)\n",
"model_tv = models.get_model(args=Flags, use_qat=pretrained_model_uses_qat)\n",
"Flags.variable_length=False\n",
"# transfer weights from trained model into variable-length model\n",
"model_tv.set_weights(model.get_weights())"
@@ -581,7 +551,9 @@
"id": "d31d4877-e789-49ab-ac73-92065a747446",
"metadata": {},
"source": [
"## Run Streaming Test on Long Waveform"
"## Run Streaming Test on Long Waveform\n",
"\n",
"A pre-constructed test wav is included in the repo (`long_wav.wav`) along with a json file that indicates the beginning and end of every instance of the wakeword, `long_wav_ww_windows.json`."
]
},
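Given the ground-truth windows from `long_wav_ww_windows.json`, detections can be scored against them. A minimal sketch of one plausible scoring rule (the exact rule used by `evaluate.py` may differ): a window is a true positive if any detection lands inside it, a detection outside every window is a false positive, and an unhit window is a false negative.

```python
import numpy as np

def score_detections(det_times, ww_windows):
    """Score detection timestamps (seconds) against ground-truth
    wakeword windows given as (start, end) pairs in seconds."""
    hit = np.zeros(len(ww_windows), dtype=bool)
    false_pos = 0
    for t in det_times:
        inside = [i for i, (s, e) in enumerate(ww_windows) if s <= t <= e]
        if inside:
            hit[inside] = True   # multiple detections in one window count once
        else:
            false_pos += 1
    true_pos = int(hit.sum())
    false_neg = len(ww_windows) - true_pos
    return true_pos, false_pos, false_neg

# Synthetic example: three wakeword windows, four detections.
windows = [(1.0, 2.0), (5.0, 6.0), (9.0, 10.0)]
dets = [1.5, 5.2, 5.8, 12.0]  # two windows hit (one twice), one spurious
print(score_detections(dets, windows))  # (2, 1, 1)
```

From these counts, the false-positive rate follows from the waveform duration and the false-negative rate from the number of windows.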
{
@@ -619,11 +591,6 @@
"## build a feature extractor that can operate on longer waveforms.\n",
"## this one can operate on waveforms up to len(long_wav)\n",
"data_config_long = get_dataset.get_data_config(Flags, 'training')\n",
"data_config_long['foreground_volume_max'] = data_config_long['foreground_volume_min'] = 1.0 # scale to [-1.0,1.0]\n",
"data_config_long['background_frequency'] = 0.0 # do not add background noise or time-shift the input\n",
"data_config_long['time_shift_ms'] = 0.0\n",
"data_config_long['desired_samples'] = len(long_wav)\n",
"data_config_long['num_samples'] = -1\n",
"\n",
"with open(\"data_config_nb.json\", \"w\") as fpo:\n",
" json.dump(data_config_long, fpo, indent=4)\n",
@@ -652,24 +619,14 @@
"print(f\"Does spectrogram loaded from file match the one we created?: {specgrams_match}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a46d67e-418c-45a9-b529-52087dc17be5",
"metadata": {},
"outputs": [],
"source": [
"# plt.plot(long_spec.reshape(-1), long_spec_from_file.reshape(-1), '.')\n",
"# plt.grid(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97becd93-f5c3-4998-af50-f43bd6fd5f38",
"metadata": {},
"outputs": [],
"source": [
"# We'll count a detection when the softmax output for the wakeword exceeds the detection threshold det_thresh\n",
"det_thresh = 0.95\n",
"\n",
"yy = model_tv(np.expand_dims(long_spec, 0))[0].numpy()\n",
@@ -688,6 +645,14 @@
"plt.grid(True)"
]
},
{
"cell_type": "markdown",
"id": "dcbb6fe0-1c4e-4321-bb62-0b5ff05d2efa",
"metadata": {},
"source": [
"Take a look at some of the false positives here, and then in the next cell, some of the false negatives."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -699,7 +664,7 @@
"for i in range(num_fp_clips_to_show):\n",
" fp_start = np.nonzero(ww_false_detects)[0][i] # sample number where the false pos starts\n",
" print(f\"False positive at {fp_start/samp_freq:3.2f}s (sample {fp_start})\")\n",
" fp_clip = slice(fp_start-32000,fp_start+32000) # add 2s before and after\n",
" fp_clip = slice(fp_start-16000,fp_start+16000) # add 2s before and after\n",
" display.display(display.Audio(long_wav[fp_clip], rate=16000))\n"
]
},
@@ -754,6 +719,14 @@
"\n"
]
},
{
"cell_type": "markdown",
"id": "50c9adcc-7a75-4297-8e29-755f494cd0a4",
"metadata": {},
"source": [
"Now we can take a closer look at one of the errors, showing the waveform plot, listening to the audio, and showing the spectrogram, along with the model outputs. "
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -770,56 +743,6 @@
"examine_clip(wav_clip, model_tv, feature_extractor_long)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c80ccf31-8240-4c23-accd-e9eee62265c2",
"metadata": {},
"outputs": [],
"source": [
"t_start = 258.85 - 2.0\n",
"t_stop = 258.85 + 2.0\n",
"i_start_wav = int(t_start*Flags.sample_rate)\n",
"i_stop_wav = int(t_stop*Flags.sample_rate)\n",
"i_start_spec = int(t_start/(Flags.window_stride_ms/1000))\n",
"i_stop_spec = int(t_stop/(Flags.window_stride_ms/1000))\n",
"\n",
"wav_slice = slice(i_start_wav, i_stop_wav)\n",
"spec_slice = slice(i_start_spec, i_stop_spec)\n",
"\n",
"t_spec= np.arange(long_spec.shape[0])*(Flags.window_stride_ms/1000)\n",
"\n",
"ww_detected = np.repeat(ww_detected_spec_scale, Flags.window_stride_ms*Flags.sample_rate/1000)\n",
"extra_zeros = np.zeros(len(long_wav)-len(ww_detected))\n",
"print(f\"added {len(extra_zeros)} extra zeros\")\n",
"ww_detected = np.concatenate((extra_zeros, ww_detected), axis=0)\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"\n",
"plt.subplot(3,1,1)\n",
"# plt.imshow(np.squeeze(long_spec).T, origin=\"lower\", aspect='auto')\n",
"plt.pcolormesh(t_spec[spec_slice], np.arange(long_spec.shape[-1]), long_spec[spec_slice].squeeze().T)\n",
"\n",
"plt.subplot(3,1,2)\n",
"plt.plot(t[wav_slice], long_wav[wav_slice], \n",
" t[wav_slice], ww_present[wav_slice],\n",
" t[wav_slice], 1.1*ww_detected[wav_slice])\n",
"plt.xlim([t_start, t_stop])\n",
"plt.grid(True)\n",
"plt.legend(['Waveform', 'Wakeword Present', 'Wakeword Detected'], loc='lower right', fontsize=8)\n",
"\n",
"# The model output yy loses some length because of valid-padded convolutions. \n",
"# Add that length back to time-align input and output\n",
"yy_ext = np.concatenate((np.zeros((len(long_spec)-len(yy), yy.shape[1])), yy))\n",
"plt.subplot(3,1,3)\n",
"plt.plot(t_spec[spec_slice], yy_ext[spec_slice])\n",
"plt.plot(t_spec[spec_slice], det_thresh*np.ones(t_spec[spec_slice].shape), 'k-', linewidth=0.5)\n",
"plt.legend(label_list+[\"Threshold\"], loc='lower right', fontsize=8);\n",
"plt.xlim([t_start, t_stop])\n",
"plt.tight_layout()\n",
"# display.display(display.Audio(long_wav, rate=16000))"
]
},
{
"cell_type": "markdown",
"id": "19666bc5-4b43-41ef-b93d-050ff7f87dd0",
@@ -865,7 +788,7 @@
"metadata": {},
"outputs": [],
"source": [
"det_thresh = 0.85\n",
"det_thresh = 0.95\n",
"## shows detection when wakeword activation is strongest output\n",
"# ww_detected_spec_scale = (np.argmax(yy, axis=1)==0) # detections on the time scale of spectrograms\n",
"\n",
@@ -904,7 +827,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0ff4c8cc-80ce-4ecb-b30f-2f545ce4f14c",
"id": "5fffca80-705b-46ff-b259-12624d10be62",
"metadata": {},
"outputs": [],
"source": []