diff --git a/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_batch_predict.py b/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_batch_predict.py index 83d953293..f3d85fdad 100644 --- a/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_batch_predict.py +++ b/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_batch_predict.py @@ -16,6 +16,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf import numpy as np +import keras strategy = tf.distribute.MirroredStrategy() print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) @@ -28,7 +29,7 @@ def scale(image): images_dir = "/data/mnist_predict/" -img_dataset = tf.keras.utils.image_dataset_from_directory( +img_dataset = keras.utils.image_dataset_from_directory( images_dir, image_size=(28, 28), color_mode="grayscale", @@ -41,13 +42,13 @@ def scale(image): img_prediction_dataset = img_dataset.map(scale) -model_path = '/data/mnist_saved_model/' +model_path = '/data/mnist_saved_model/mnist.keras' with strategy.scope(): - replicated_model = tf.keras.models.load_model(model_path) + replicated_model = keras.models.load_model(model_path) replicated_model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=tf.keras.optimizers.Adam(), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(), metrics=['accuracy']) predictions = replicated_model.predict(img_prediction_dataset) diff --git a/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_train_distributed.py b/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_train_distributed.py index e9b77a656..a9ea0f5e1 100644 --- a/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_train_distributed.py +++ b/tutorials-and-examples/gpu-examples/training-single-gpu/src/tensorflow-mnist-example/tensorflow_mnist_train_distributed.py @@ -16,6 +16,8 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow_datasets as tfds import tensorflow as tf +import keras +import glob datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True) @@ -45,16 +47,17 @@ def scale(image, label): eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE) with strategy.scope(): - model = tf.keras.Sequential([ - tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), - tf.keras.layers.MaxPooling2D(), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(64, activation='relu'), - tf.keras.layers.Dense(10) + model = keras.Sequential([ + keras.Input(shape=(28, 28, 1)), + keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + keras.layers.MaxPooling2D(), + keras.layers.Flatten(), + keras.layers.Dense(64, activation='relu'), + keras.layers.Dense(10) ]) - model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=tf.keras.optimizers.Adam(), + model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(), metrics=['accuracy']) # Define the checkpoint directory to store the checkpoints. @@ -71,7 +74,7 @@ def decay(epoch): return 1e-5 # Define a callback for printing the learning rate at the end of each epoch. -class PrintLR(tf.keras.callbacks.Callback): +class PrintLR(keras.callbacks.Callback): def on_epoch_end(self, epoch, logs=None): print('\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.learning_rate.numpy())) @@ -87,15 +90,25 @@ def on_epoch_end(self, epoch, logs=None): EPOCHS = 12 model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks) -model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)) + +# Function to find the latest .h5 file +def find_latest_h5_checkpoint(checkpoint_dir): + list_of_files = glob.glob(f'{checkpoint_dir}/*.h5') + if list_of_files: + latest_file = max(list_of_files, key=os.path.getctime) + return latest_file + else: + return None + +model.load_weights(find_latest_h5_checkpoint(checkpoint_dir)) eval_loss, eval_acc = model.evaluate(eval_dataset) print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc)) -path = '/data/mnist_saved_model/' +path = '/data/mnist_saved_model/mnist.keras' -model.save(path, save_format='tf') +model.save(path) print('Training finished. Model saved')