Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yanncalec committed May 19, 2024
1 parent 7dc15aa commit 5e9cce5
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
33 changes: 26 additions & 7 deletions dpmhm/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,28 @@


def get_mapping_supervised(labels:list, *, feature_field:str='feature', label_field:str='label') -> callable:
"""Get a preprocessing mapping to transform a dataset to the format `(data,label)` for supervised training.
"""Get a preprocessing mapping for supervised training.
This processing model performs the following transformations on a dataset:
- conversion of the label from string to integer
- conversion of the data layout from channel-first to channel-last
After transformation, the dataset is in the format `(data,label)`.
Usage
-----
```python
func = get_mapping_supervised(labels)
ds = ds.map(func, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda x, y: (tf.ensure_shape(x, dimx), y))
```
Note
----
The `ensure_shape()` step seems necessary in Keras 3 to recover the dimension information lost after `.map()`. Otherwise the training will fail.
See also:
https://github.com/tensorflow/tensorflow/issues/64177
"""
label_layer = layers.StringLookup(
# num_oov_indices=0, # force zero-based integer
Expand All @@ -39,6 +60,8 @@ def nested_type_spec(sp:dict) -> dict:
return tp


# The following method doesn't work in Keras 3. Use `get_mapping_supervised()` instead.

def keras_model_supervised(ds:Dataset, labels:list=None, normalize:bool=False, *, shape:tuple=None, feature_field:str='feature', label_field:str='label') -> models.Model:
"""Initialize a Keras preprocessing model for supervised training.
Expand Down Expand Up @@ -125,13 +148,9 @@ def keras_model_supervised(ds:Dataset, labels:list=None, normalize:bool=False, *
feature_output = ops.transpose(feature_output, [0,2,3,1])

# Restore the shape information
# if shape is not None:
# feature_output = tf.reshape(feature_output, shape)
# label_output = tf.reshape(label_output, ())
if shape is not None:
feature_output = ops.reshape(feature_output, shape)
# label_output = ops.reshape(label_output, ())
# print(label_output)
label_output = ops.reshape(label_output, ())

outputs = (feature_output, label_output)

Expand All @@ -150,4 +169,4 @@ def keras_model_supervised(ds:Dataset, labels:list=None, normalize:bool=False, *
return model


__all__ = ['nested_type_spec', 'keras_model_supervised']
__all__ = ['nested_type_spec', 'get_mapping_supervised']
50 changes: 36 additions & 14 deletions notebooks/models/supervised_vggish.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "2f8354ee-7bf5-49c6-90bc-b5ee6ba6b604",
"metadata": {},
"outputs": [
Expand All @@ -611,7 +611,13 @@
"Epoch 1/5\n",
"\u001b[1m357/357\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m127s\u001b[0m 350ms/step - accuracy: 0.5214 - loss: 3.7552 - val_accuracy: 0.9101 - val_loss: 0.2769\n",
"Epoch 2/5\n",
"\u001b[1m 72/357\u001b[0m \u001b[32m━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1:26\u001b[0m 305ms/step - accuracy: 0.9204 - loss: 0.2379"
"\u001b[1m357/357\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m123s\u001b[0m 346ms/step - accuracy: 0.9255 - loss: 0.2217 - val_accuracy: 0.9515 - val_loss: 0.1628\n",
"Epoch 3/5\n",
"\u001b[1m357/357\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m123s\u001b[0m 345ms/step - accuracy: 0.9483 - loss: 0.1590 - val_accuracy: 0.9727 - val_loss: 0.0853\n",
"Epoch 4/5\n",
"\u001b[1m357/357\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m122s\u001b[0m 341ms/step - accuracy: 0.9588 - loss: 0.1418 - val_accuracy: 0.9693 - val_loss: 0.0822\n",
"Epoch 5/5\n",
"\u001b[1m357/357\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m122s\u001b[0m 342ms/step - accuracy: 0.9665 - loss: 0.1127 - val_accuracy: 0.9834 - val_loss: 0.0552\n"
]
}
],
Expand All @@ -626,22 +632,40 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc947ba1",
"execution_count": 15,
"id": "06f79296-3a90-47e5-ba61-815507251d89",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1630/1630\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 13ms/step - accuracy: 0.9777 - loss: 0.0619\n"
]
},
{
"data": {
"text/plain": [
"[0.05482785403728485, 0.9822086095809937]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.save(str(outdir / 'saved_model'))"
"model.evaluate(ds_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06f79296-3a90-47e5-ba61-815507251d89",
"execution_count": 17,
"id": "fc947ba1",
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(ds_test)"
"model.save(str(outdir / 'saved_model.keras'))"
]
},
{
Expand All @@ -656,7 +680,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"id": "10b97f80-0af9-4ee6-a943-1ae9658be30e",
"metadata": {},
"outputs": [],
Expand All @@ -672,7 +696,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "4fc06497-9ce1-4ca9-aa9d-9f68500eea2a",
"metadata": {},
"outputs": [
Expand All @@ -681,9 +705,7 @@
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"713/713 [==============================] - 454s 633ms/step - loss: 0.1665 - accuracy: 0.9604 - val_loss: 0.2006 - val_accuracy: 0.9521\n",
"Epoch 2/2\n",
"713/713 [==============================] - 459s 643ms/step - loss: 0.0477 - accuracy: 0.9876 - val_loss: 0.0283 - val_accuracy: 0.9914\n"
"\u001b[1m174/357\u001b[0m \u001b[32m━━━━━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━\u001b[0m \u001b[1m2:46\u001b[0m 911ms/step - accuracy: 0.9695 - loss: 0.1080"
]
}
],
Expand Down

0 comments on commit 5e9cce5

Please sign in to comment.