From 78a5ad88002392b7e0031cda44de0d1609744f26 Mon Sep 17 00:00:00 2001 From: kyle-woodward Date: Mon, 8 Apr 2024 15:31:14 -0400 Subject: [PATCH 1/6] add mobilenet_v3small and vgg16 models --- fao_models/graveyard/model_graveyard.py | 17 +++++ fao_models/models.py | 92 +++++++++++++++++++------ 2 files changed, 88 insertions(+), 21 deletions(-) diff --git a/fao_models/graveyard/model_graveyard.py b/fao_models/graveyard/model_graveyard.py index 23d7adf..74c0ee4 100644 --- a/fao_models/graveyard/model_graveyard.py +++ b/fao_models/graveyard/model_graveyard.py @@ -1,3 +1,20 @@ +def model1(optimizer, loss_fn, metrics=[]): + + model = models.Sequential( + [ + layers.Input(shape=(32, 32, 4)), + layers.Flatten(), + layers.Dense(64, activation="relu"), + layers.Dense(1, activation="softmax"), + ] + ) + + model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) + + return model + + + # Define your TensorFlow models here def cnn_v1_softmax_onehot(optimizer,loss_fn,metrics=['accuracy']): def conv_block(input_tensor, num_filters): diff --git a/fao_models/models.py b/fao_models/models.py index 8171652..2a77682 100644 --- a/fao_models/models.py +++ b/fao_models/models.py @@ -128,23 +128,6 @@ def dice_loss(y_true, y_pred, smooth=1): evaluation_metrics = [categorical_accuracy, f1_m, precision_m, recall_m] - -def model1(optimizer, loss_fn, metrics=[]): - - model = models.Sequential( - [ - layers.Input(shape=(32, 32, 4)), - layers.Flatten(), - layers.Dense(64, activation="relu"), - layers.Dense(1, activation="softmax"), - ] - ) - - model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) - - return model - - def resnet( optimizer, loss_fn, @@ -188,13 +171,80 @@ def create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1): return create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1) +def mobilenet_v3small(optimizer,loss_fn,metrics=[ + dice_coef, + "binary_accuracy", + ], +): + input_shape = (32,32,4) + + base_model = tf.keras.applications.MobileNetV3Small( + input_shape=input_shape, + include_top=False, + weights=None, + classes=2, + include_preprocessing=False +) + # Custom model + inputs = layers.Input(shape=input_shape) + + # Pass the input to the base model + x = base_model( + inputs, training=True + ) # Set training=True to enable BatchNormalization layers + + # Add custom top layers + x = layers.Flatten()(x) + x = layers.Dense(256, activation="relu")(x) + outputs = layers.Dense(1, activation="sigmoid")( + x + ) + # Create the final model + model = models.Model(inputs=inputs, outputs=outputs) + + model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) + + return model + +def vgg16(optimizer,loss_fn,metrics=[ + dice_coef, + "binary_accuracy", + ], +): + input_shape = (32,32,4) + base_model = tf.keras.applications.VGG16( + include_top=False, + weights=None, + input_shape=input_shape, + classes=2, + ) + + # Custom model + inputs = layers.Input(shape=input_shape) + + # Pass the input to the base model + x = base_model( + inputs, training=True + ) # Set training=True to enable BatchNormalization layers + + # Add custom top layers + x = layers.Flatten()(x) + x = layers.Dense(256, activation="relu")(x) + outputs = layers.Dense(1, activation="sigmoid")( + x + ) + # Create the final model + model = models.Model(inputs=inputs, outputs=outputs) + + model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) + + return model + # Create a dictionary of keyword-function pairs model_dict = { - "model1": model1, "resnet": resnet, - # in graveyard - # 
'cnn_v1_softmax_onehot': cnn_v1_softmax_onehot, - # 'simple_model_sigmoid_onehot': simple_model_sigmoid_onehot, + "mobilenet_v3small": mobilenet_v3small, + "vgg16": vgg16, } From ad66a543c0387eafe9acbedfc161d38389628bfb Mon Sep 17 00:00:00 2001 From: kyle-woodward Date: Mon, 8 Apr 2024 15:31:27 -0400 Subject: [PATCH 2/6] ignore tfrecords --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a61a696..32a5c37 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ logs/ /**/*.png *.tif *.tiff +*.tfrecord.gz # Byte-compiled / optimized / DLL files __pycache__/ From f778554b28d2984458a5b353803201233e292cb0 Mon Sep 17 00:00:00 2001 From: kyle-woodward Date: Mon, 8 Apr 2024 15:31:50 -0400 Subject: [PATCH 3/6] my ymls, plus model_fit to CLI --- fao_models/model_fit.py | 188 ++++++++++++++++++++----------------- runc1.yml => runc1_kdw.yml | 0 runc2.yml => runc2_kdw.yml | 0 runc3.yml => runc3_kdw.yml | 0 runc4_kdw.yml | 16 ++++ runc5_kdw.yml | 16 ++++ 6 files changed, 136 insertions(+), 84 deletions(-) rename runc1.yml => runc1_kdw.yml (100%) rename runc2.yml => runc2_kdw.yml (100%) rename runc3.yml => runc3_kdw.yml (100%) create mode 100644 runc4_kdw.yml create mode 100644 runc5_kdw.yml diff --git a/fao_models/model_fit.py b/fao_models/model_fit.py index 973afb7..ed29bde 100644 --- a/fao_models/model_fit.py +++ b/fao_models/model_fit.py @@ -9,9 +9,10 @@ import yaml from pprint import pformat from functools import partial +import argparse -# TODO: make this single CLI arg input -config_file = r"runc3.yml" +# # TODO: make this single CLI arg input +# config_file = r"runc3.yml" # setup logging logging.basicConfig( @@ -26,94 +27,113 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -with open(config_file, "r") as file: - config_data = yaml.safe_load(file) - -# retrieve parameters -experiment_name = config_data["experiment_name"] -model_name = config_data["model_name"] -total_examples = config_data["total_examples"] -data_dir = config_data["data_dir"] -data_split = config_data["data_split"] -epochs = config_data["epochs"] -learning_rate = config_data["learning_rate"] -batch_size = config_data["batch_size"] -buffer_size = config_data["buffer_size"] -optimizer = config_data["optimizer"] -optimizer_use_lr_schedular = config_data["optimizer_use_lr_schedular"] -loss_function = config_data["loss_function"] -early_stopping_patience = config_data["early_stopping_patience"] - -# hyperbolically decrease the learning rate to 1/2 of the base rate at 1,000 epochs, 1/3 at 2,000 epochs, and so on. 
-if optimizer == "adam": - if optimizer_use_lr_schedular: - steps_per_epoch = total_examples * data_split // batch_size - lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay( - initial_learning_rate=learning_rate, - decay_steps=steps_per_epoch * epochs, - decay_rate=1, - staircase=False, - ) - logger.info( - f"Using a learning rate schedule of InverseTimeDecay, decay_steps={steps_per_epoch*epochs}" - ) - optimizer = tf.keras.optimizers.Adam(lr_schedule) - else: - optimizer = tf.keras.optimizers.Adam() - -# pull model from config -model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function) -print(model.summary()) +def main(): -logger.info("Config file: %s", config_file) -logger.info("Parameters:") -logger.info(pformat(config_data)) + # initalize new cli parser + parser = argparse.ArgumentParser(description="Train a model with a .yml file.") -# Load the dataset without batching -dataset = dl.load_dataset_from_tfrecords(data_dir) + parser.add_argument( + "-c", + "--config", + type=str, + help="path to .yml file", + ) -# Split the dataset into training and testing -train_dataset, test_dataset = dl.split_dataset( - dataset, total_examples, test_split=data_split, batch_size=batch_size -) -train_dataset = train_dataset.shuffle(buffer_size, reshuffle_each_iteration=True) + args = parser.parse_args() + + config_file = args.config + + with open(config_file, "r") as file: + config_data = yaml.safe_load(file) + + # retrieve parameters + experiment_name = config_data["experiment_name"] + model_name = config_data["model_name"] + total_examples = config_data["total_examples"] + data_dir = config_data["data_dir"] + data_split = config_data["data_split"] + epochs = config_data["epochs"] + learning_rate = config_data["learning_rate"] + batch_size = config_data["batch_size"] + buffer_size = config_data["buffer_size"] + optimizer = config_data["optimizer"] + optimizer_use_lr_schedular = config_data["optimizer_use_lr_schedular"] + loss_function = config_data["loss_function"] + early_stopping_patience = config_data["early_stopping_patience"] + + # hyperbolically decrease the learning rate to 1/2 of the base rate at 1,000 epochs, 1/3 at 2,000 epochs, and so on. 
+ if optimizer == "adam": + if optimizer_use_lr_schedular: + steps_per_epoch = total_examples * data_split // batch_size + lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay( + initial_learning_rate=learning_rate, + decay_steps=steps_per_epoch * epochs, + decay_rate=1, + staircase=False, + ) + logger.info( + f"Using a learning rate schedule of InverseTimeDecay, decay_steps={steps_per_epoch*epochs}" + ) + optimizer = tf.keras.optimizers.Adam(lr_schedule) + else: + optimizer = tf.keras.optimizers.Adam() + + # pull model from config + model = get_model(model_name, optimizer=optimizer, loss_fn=loss_function) + print(model.summary()) + + logger.info("Config file: %s", config_file) + logger.info("Parameters:") + logger.info(pformat(config_data)) + + # Load the dataset without batching + dataset = dl.load_dataset_from_tfrecords(data_dir) + + # Split the dataset into training and testing + train_dataset, test_dataset = dl.split_dataset( + dataset, total_examples, test_split=data_split, batch_size=batch_size + ) + train_dataset = train_dataset.shuffle(buffer_size, reshuffle_each_iteration=True) -logger.info("Starting model training...") -LOGS_DIR = os.path.join( - os.path.dirname(os.path.dirname(__file__)), "logs", experiment_name -) -if not os.path.exists(LOGS_DIR): - os.makedirs(LOGS_DIR) - -# setup for confusion matrix callback -tb_samples = train_dataset.take(1) -x = list(map(lambda x: x[0], tb_samples))[0] -y = list(map(lambda x: x[1], tb_samples))[0] -class_names = ["nonforest", "forest"] - -# initialize and add tb callbacks -callbacks = [] -file_writer = tf.summary.create_file_writer(LOGS_DIR) -cm_callback = CmCallback(y, x, class_names, file_writer) - -if early_stopping_patience is not None: - logger.info(f"Using early stopping. Patience: {early_stopping_patience}") - early_stop = tf.keras.callbacks.EarlyStopping( - monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True + logger.info("Starting model training...") + LOGS_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "logs", experiment_name ) - callbacks.append(early_stop) -callbacks.append(cm_callback) -callbacks.append(tf.keras.callbacks.TensorBoard(LOGS_DIR)) + if not os.path.exists(LOGS_DIR): + os.makedirs(LOGS_DIR) + + # setup for confusion matrix callback + tb_samples = train_dataset.take(1) + x = list(map(lambda x: x[0], tb_samples))[0] + y = list(map(lambda x: x[1], tb_samples))[0] + class_names = ["nonforest", "forest"] + + # initialize and add tb callbacks + callbacks = [] + file_writer = tf.summary.create_file_writer(LOGS_DIR) + cm_callback = CmCallback(y, x, class_names, file_writer) + + if early_stopping_patience is not None: + logger.info(f"Using early stopping. 
Patience: {early_stopping_patience}") + early_stop = tf.keras.callbacks.EarlyStopping( + monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True + ) + callbacks.append(early_stop) + callbacks.append(cm_callback) + callbacks.append(tf.keras.callbacks.TensorBoard(LOGS_DIR)) -history = model.fit( - train_dataset, - epochs=epochs, - validation_data=test_dataset, - callbacks=callbacks, -) + history = model.fit( + train_dataset, + epochs=epochs, + validation_data=test_dataset, + callbacks=callbacks, + ) + + logger.info("Model training complete") + logger.info("Training history:") + logger.info(pformat(history.history)) -logger.info("Model training complete") -logger.info("Training history:") -logger.info(pformat(history.history)) +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/runc1.yml b/runc1_kdw.yml similarity index 100% rename from runc1.yml rename to runc1_kdw.yml diff --git a/runc2.yml b/runc2_kdw.yml similarity index 100% rename from runc2.yml rename to runc2_kdw.yml diff --git a/runc3.yml b/runc3_kdw.yml similarity index 100% rename from runc3.yml rename to runc3_kdw.yml diff --git a/runc4_kdw.yml b/runc4_kdw.yml new file mode 100644 index 0000000..0bddc99 --- /dev/null +++ b/runc4_kdw.yml @@ -0,0 +1,16 @@ +experiment_name: "mobilenet_v3small_try1" +model_name: "mobilenet_v3small" +data_dir: "data_balanced" +total_examples: 77046 # number of geotiffs not tfrecords +data_split: 0.2 + +optimizer: "adam" +optimizer_use_lr_schedular: true +loss_function: "binary_crossentropy" + +epochs: 50 +learning_rate: 0.001 +batch_size: 512 +buffer_size: 77046 + +early_stopping_patience: null # null or int diff --git a/runc5_kdw.yml b/runc5_kdw.yml new file mode 100644 index 0000000..6e0b34f --- /dev/null +++ b/runc5_kdw.yml @@ -0,0 +1,16 @@ +experiment_name: "mobilenet_v3small_try2" +model_name: "mobilenet_v3small" +data_dir: "data_balanced" +total_examples: 77046 # number of geotiffs not tfrecords +data_split: 0.2 + +optimizer: "adam" +optimizer_use_lr_schedular: true +loss_function: "binary_crossentropy" + +epochs: 100 +learning_rate: 0.01 +batch_size: 128 +buffer_size: 77046 + +early_stopping_patience: null # null or int From 513d0369a34632f038c41ac22a434a56ab44e024 Mon Sep 17 00:00:00 2001 From: kyle-woodward Date: Mon, 8 Apr 2024 16:33:23 -0400 Subject: [PATCH 4/6] vgg16 try1 yml --- runc6_kdw.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 runc6_kdw.yml diff --git a/runc6_kdw.yml b/runc6_kdw.yml new file mode 100644 index 0000000..3bb6f26 --- /dev/null +++ b/runc6_kdw.yml @@ -0,0 +1,16 @@ +experiment_name: "vgg16_try1" +model_name: "vgg16" +data_dir: "data_balanced" +total_examples: 77046 # number of geotiffs not tfrecords +data_split: 0.2 + +optimizer: "adam" +optimizer_use_lr_schedular: true +loss_function: "binary_crossentropy" + +epochs: 100 +learning_rate: 0.01 +batch_size: 128 +buffer_size: 77046 + +early_stopping_patience: null # null or int From 3184ec863a56d5567aeb9478396eea3b22937e91 Mon Sep 17 00:00:00 2001 From: John Dilger Date: Tue, 9 Apr 2024 13:33:58 -0500 Subject: [PATCH 5/6] fix batch, shuffle data once on load --- fao_models/dataloader.py | 27 +++++++++++++++++---------- fao_models/model_fit.py | 14 ++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/fao_models/dataloader.py b/fao_models/dataloader.py index dd87109..19274b0 100644 --- a/fao_models/dataloader.py +++ b/fao_models/dataloader.py @@ -1,20 +1,22 @@ import tensorflow as tf import os + def 
_parse_function(proto): # Define the parsing schema feature_description = { - 'image': tf.io.FixedLenFeature([], tf.string), - 'label': tf.io.FixedLenFeature([], tf.string), + "image": tf.io.FixedLenFeature([], tf.string), + "label": tf.io.FixedLenFeature([], tf.string), } # Parse the input `tf.train.Example` proto using the schema example = tf.io.parse_single_example(proto, feature_description) - image = tf.io.parse_tensor(example['image'], out_type=tf.float32) - label = tf.io.parse_tensor(example['label'], out_type=tf.int64) + image = tf.io.parse_tensor(example["image"], out_type=tf.float32) + label = tf.io.parse_tensor(example["label"], out_type=tf.int64) image.set_shape([32, 32, 4]) # Set the shape explicitly if not already defined label.set_shape([]) # For scalar labels return image, label + def load_dataset_from_tfrecords(tfrecord_dir, batch_size=32): pattern = tfrecord_dir + "/*.tfrecord.gz" @@ -22,12 +24,15 @@ def load_dataset_from_tfrecords(tfrecord_dir, batch_size=32): dataset = files.interleave( lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"), cycle_length=tf.data.AUTOTUNE, - block_length=1 + block_length=1, + ) + dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE).batch( + batch_size, drop_remainder=True ) - dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE) - dataset = dataset.shuffle(buffer_size=1000) + dataset = dataset.shuffle(buffer_size=100_000, seed=42) return dataset + def split_dataset(dataset, total_examples, test_split=0.2, batch_size=32): test_size = int(total_examples * test_split) train_size = total_examples - test_size @@ -36,7 +41,9 @@ def split_dataset(dataset, total_examples, test_split=0.2, batch_size=32): train_batches = train_size // batch_size test_batches = test_size // batch_size - train_dataset = dataset.take(train_batches).batch(batch_size).prefetch(tf.data.AUTOTUNE) - test_dataset = dataset.skip(train_batches).take(test_batches).batch(batch_size).prefetch(tf.data.AUTOTUNE) + train_dataset = dataset.take(train_batches).prefetch(tf.data.AUTOTUNE) + test_dataset = ( + dataset.skip(train_batches).take(test_batches).prefetch(tf.data.AUTOTUNE) + ) - return train_dataset, test_dataset \ No newline at end of file + return train_dataset, test_dataset diff --git a/fao_models/model_fit.py b/fao_models/model_fit.py index ed29bde..0c2d092 100644 --- a/fao_models/model_fit.py +++ b/fao_models/model_fit.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + def main(): # initalize new cli parser @@ -42,7 +43,7 @@ def main(): args = parser.parse_args() config_file = args.config - + with open(config_file, "r") as file: config_data = yaml.safe_load(file) @@ -87,7 +88,7 @@ def main(): logger.info(pformat(config_data)) # Load the dataset without batching - dataset = dl.load_dataset_from_tfrecords(data_dir) + dataset = dl.load_dataset_from_tfrecords(data_dir, batch_size=batch_size) # Split the dataset into training and testing train_dataset, test_dataset = dl.split_dataset( @@ -95,7 +96,6 @@ def main(): ) train_dataset = train_dataset.shuffle(buffer_size, reshuffle_each_iteration=True) - logger.info("Starting model training...") LOGS_DIR = os.path.join( os.path.dirname(os.path.dirname(__file__)), "logs", experiment_name @@ -117,13 +117,14 @@ def main(): if early_stopping_patience is not None: logger.info(f"Using early stopping. 
Patience: {early_stopping_patience}") early_stop = tf.keras.callbacks.EarlyStopping( - monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True + monitor="val_loss", + patience=early_stopping_patience, + restore_best_weights=True, ) callbacks.append(early_stop) callbacks.append(cm_callback) callbacks.append(tf.keras.callbacks.TensorBoard(LOGS_DIR)) - history = model.fit( train_dataset, epochs=epochs, @@ -135,5 +136,6 @@ def main(): logger.info("Training history:") logger.info(pformat(history.history)) + if __name__ == "__main__": - main() \ No newline at end of file + main() From a42ee460243b8a282273025ef39ca426cedca1fb Mon Sep 17 00:00:00 2001 From: John Dilger Date: Tue, 9 Apr 2024 13:34:36 -0500 Subject: [PATCH 6/6] add forest/non forest count to cm title --- fao_models/models.py | 62 ++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/fao_models/models.py b/fao_models/models.py index 2a77682..1af15f0 100644 --- a/fao_models/models.py +++ b/fao_models/models.py @@ -28,7 +28,7 @@ def plot_to_image(figure): return image -def plot_confusion_matrix(cm, class_names): +def plot_confusion_matrix(cm, class_names, count_forest, count_nonforest): """ Returns a matplotlib figure containing the plotted confusion matrix. @@ -38,7 +38,7 @@ def plot_confusion_matrix(cm, class_names): """ figure = plt.figure(figsize=(8, 8)) plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues) - plt.title("Confusion matrix") + plt.title(f"Confusion matrix: F-{count_forest} NF-{count_nonforest}") plt.colorbar() tick_marks = np.arange(len(class_names)) plt.xticks(tick_marks, class_names, rotation=45) @@ -74,11 +74,18 @@ def log_confusion_matrix(self, epoch): # Use the model to predict the values from the validation dataset. test_pred_raw = self.model.predict(self.test_images) test_pred = np.rint(test_pred_raw) + count_nf = sum(self.test_labels.numpy() == 0) + count_f = sum(self.test_labels.numpy() == 1) # Calculate the confusion matrix. cm = confusion_matrix(self.test_labels, test_pred) # Log the confusion matrix as an image summary. - figure = plot_confusion_matrix(cm, class_names=self.class_names) + figure = plot_confusion_matrix( + cm, + class_names=self.class_names, + count_forest=count_f, + count_nonforest=count_nf, + ) cm_image = plot_to_image(figure) # Log the confusion matrix as an image summary. 
@@ -128,6 +135,7 @@ def dice_loss(y_true, y_pred, smooth=1): evaluation_metrics = [categorical_accuracy, f1_m, precision_m, recall_m] + def resnet( optimizer, loss_fn, @@ -171,20 +179,23 @@ def create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1): return create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1) -def mobilenet_v3small(optimizer,loss_fn,metrics=[ +def mobilenet_v3small( + optimizer, + loss_fn, + metrics=[ dice_coef, "binary_accuracy", ], ): - input_shape = (32,32,4) - + input_shape = (32, 32, 4) + base_model = tf.keras.applications.MobileNetV3Small( - input_shape=input_shape, - include_top=False, - weights=None, - classes=2, - include_preprocessing=False -) + input_shape=input_shape, + include_top=False, + weights=None, + classes=2, + include_preprocessing=False, + ) # Custom model inputs = layers.Input(shape=input_shape) @@ -196,9 +207,7 @@ def mobilenet_v3small(optimizer,loss_fn,metrics=[ # Add custom top layers x = layers.Flatten()(x) x = layers.Dense(256, activation="relu")(x) - outputs = layers.Dense(1, activation="sigmoid")( - x - ) + outputs = layers.Dense(1, activation="sigmoid")(x) # Create the final model model = models.Model(inputs=inputs, outputs=outputs) @@ -206,19 +215,23 @@ def mobilenet_v3small(optimizer,loss_fn,metrics=[ return model -def vgg16(optimizer,loss_fn,metrics=[ + +def vgg16( + optimizer, + loss_fn, + metrics=[ dice_coef, "binary_accuracy", ], ): - input_shape = (32,32,4) + input_shape = (32, 32, 4) base_model = tf.keras.applications.VGG16( - include_top=False, - weights=None, - input_shape=input_shape, - classes=2, + include_top=False, + weights=None, + input_shape=input_shape, + classes=2, ) - + # Custom model inputs = layers.Input(shape=input_shape) @@ -230,9 +243,7 @@ def vgg16(optimizer,loss_fn,metrics=[ # Add custom top layers x = layers.Flatten()(x) x = layers.Dense(256, activation="relu")(x) - outputs = layers.Dense(1, activation="sigmoid")( - x - ) + outputs = layers.Dense(1, activation="sigmoid")(x) # Create the final model model = models.Model(inputs=inputs, outputs=outputs) @@ -240,6 +251,7 @@ def vgg16(optimizer,loss_fn,metrics=[ return model + # Create a dictionary of keyword-function pairs model_dict = { "resnet": resnet,
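
A minimal usage sketch for the config-driven CLI introduced in PATCH 3/6, assuming training is launched from the repository root with one of the run configs added in this series (runc4_kdw.yml here):

    python fao_models/model_fit.py --config runc4_kdw.yml

The -c/--config path is the only command-line argument; the experiment name, model (resnet, mobilenet_v3small, or vgg16), optimizer, learning-rate schedule, batch size, and early-stopping patience are all read from the YAML file.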