Skip to content

Commit

Permalink
Fix indice remapping to get the correct label (#201)
Browse files Browse the repository at this point in the history
Co-authored-by: fr.branchaud-charron <[email protected]>
  • Loading branch information
Dref360 and fr.branchaud-charron authored Apr 11, 2022
1 parent ff77819 commit 93ba0aa
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions notebooks/baal_prod_cls.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001b[32minfo ] Starting training dataset=100 epoch=5\n"
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001B[32minfo ] Starting training dataset=100 epoch=5\n"
]
},
{
Expand All @@ -306,9 +306,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001b[32minfo ] Training complete train_loss=2.058176279067993\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001b[32minfo ] Evaluation complete test_loss=2.0671451091766357\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001B[32minfo ] Training complete train_loss=2.058176279067993\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001B[32minfo ] Evaluation complete test_loss=2.0671451091766357\n",
"Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}\n"
]
}
Expand All @@ -334,7 +334,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001b[32minfo ] Start Predict dataset=5074\n"
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001B[32minfo ] Start Predict dataset=5074\n"
]
}
],
Expand Down Expand Up @@ -369,8 +369,9 @@
],
"source": [
"# 4. Label those samples.\n",
"labels = [get_label(train_dataset.files[idx]) for idx in top_uncertainty]\n",
"print(list(zip(labels, top_uncertainty)))\n",
"oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n",
"labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]\n",
"print(list(zip(labels, oracle_indices)))\n",
"active_learning_ds.label(top_uncertainty, labels)\n",
"\n"
]
Expand All @@ -389,40 +390,40 @@
"output_type": "stream",
"text": [
"Training on 110 items!\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:50:02.089160Z [\u001b[32minfo ] Starting training dataset=110 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:50:19.678241Z [\u001b[32minfo ] Training complete train_loss=1.9793428182601929\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:50:19.681509Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:50:33.777658Z [\u001b[32minfo ] Evaluation complete test_loss=2.013453960418701\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:50:02.089160Z [\u001B[32minfo ] Starting training dataset=110 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:50:19.678241Z [\u001B[32minfo ] Training complete train_loss=1.9793428182601929\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:50:19.681509Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:50:33.777658Z [\u001B[32minfo ] Evaluation complete test_loss=2.013453960418701\n",
"Metrics: {'test_loss': 2.013453960418701, 'train_loss': 1.9793428182601929}\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:50:33.784990Z [\u001b[32minfo ] Start Predict dataset=5064\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:50:33.784990Z [\u001B[32minfo ] Start Predict dataset=5064\n",
"Training on 120 items!\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:52:14.295969Z [\u001b[32minfo ] Starting training dataset=120 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:52:32.482238Z [\u001b[32minfo ] Training complete train_loss=1.8900309801101685\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:52:32.484473Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:52:46.287436Z [\u001b[32minfo ] Evaluation complete test_loss=1.8315811157226562\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:52:14.295969Z [\u001B[32minfo ] Starting training dataset=120 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:52:32.482238Z [\u001B[32minfo ] Training complete train_loss=1.8900309801101685\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:52:32.484473Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:52:46.287436Z [\u001B[32minfo ] Evaluation complete test_loss=1.8315811157226562\n",
"Metrics: {'test_loss': 1.8315811157226562, 'train_loss': 1.8900309801101685}\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:52:46.367016Z [\u001b[32minfo ] Start Predict dataset=5054\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:52:46.367016Z [\u001B[32minfo ] Start Predict dataset=5054\n",
"Training on 130 items!\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:54:26.794349Z [\u001b[32minfo ] Starting training dataset=130 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:54:44.481490Z [\u001b[32minfo ] Training complete train_loss=1.961772084236145\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:54:44.483477Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:54:58.268424Z [\u001b[32minfo ] Evaluation complete test_loss=1.859472393989563\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:54:26.794349Z [\u001B[32minfo ] Starting training dataset=130 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:54:44.481490Z [\u001B[32minfo ] Training complete train_loss=1.961772084236145\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:54:44.483477Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:54:58.268424Z [\u001B[32minfo ] Evaluation complete test_loss=1.859472393989563\n",
"Metrics: {'test_loss': 1.859472393989563, 'train_loss': 1.961772084236145}\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:54:58.276565Z [\u001b[32minfo ] Start Predict dataset=5044\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:54:58.276565Z [\u001B[32minfo ] Start Predict dataset=5044\n",
"Training on 140 items!\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:56:38.406344Z [\u001b[32minfo ] Starting training dataset=140 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:56:57.088064Z [\u001b[32minfo ] Training complete train_loss=1.8688158988952637\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:56:57.091358Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:57:10.968456Z [\u001b[32minfo ] Evaluation complete test_loss=1.7242822647094727\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:56:38.406344Z [\u001B[32minfo ] Starting training dataset=140 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:56:57.088064Z [\u001B[32minfo ] Training complete train_loss=1.8688158988952637\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:56:57.091358Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:57:10.968456Z [\u001B[32minfo ] Evaluation complete test_loss=1.7242822647094727\n",
"Metrics: {'test_loss': 1.7242822647094727, 'train_loss': 1.8688158988952637}\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:57:10.977104Z [\u001b[32minfo ] Start Predict dataset=5034\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:57:10.977104Z [\u001B[32minfo ] Start Predict dataset=5034\n",
"Training on 150 items!\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:58:51.197386Z [\u001b[32minfo ] Starting training dataset=150 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:59:09.779341Z [\u001b[32minfo ] Training complete train_loss=1.8381125926971436\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:59:09.782580Z [\u001b[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:59:23.176680Z [\u001b[32minfo ] Evaluation complete test_loss=1.7318601608276367\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:58:51.197386Z [\u001B[32minfo ] Starting training dataset=150 epoch=5\n",
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:59:09.779341Z [\u001B[32minfo ] Training complete train_loss=1.8381125926971436\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:59:09.782580Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:59:23.176680Z [\u001B[32minfo ] Evaluation complete test_loss=1.7318601608276367\n",
"Metrics: {'test_loss': 1.7318601608276367, 'train_loss': 1.8381125926971436}\n",
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:59:23.184444Z [\u001b[32minfo ] Start Predict dataset=5024\n"
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:59:23.184444Z [\u001B[32minfo ] Start Predict dataset=5024\n"
]
}
],
Expand All @@ -444,7 +445,8 @@
" predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)\n",
" top_uncertainty = heuristic(predictions)[:10]\n",
" # 4. Label those samples.\n",
" labels = [get_label(train_dataset.files[idx]) for idx in top_uncertainty]\n",
" oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n",
" labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]\n",
" active_learning_ds.label(top_uncertainty, labels)\n",
" \n",
" "
Expand Down Expand Up @@ -513,4 +515,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}

0 comments on commit 93ba0aa

Please sign in to comment.