From a42ee460243b8a282273025ef39ca426cedca1fb Mon Sep 17 00:00:00 2001
From: John Dilger
Date: Tue, 9 Apr 2024 13:34:36 -0500
Subject: [PATCH] add forest/non forest count to cm title

---
 fao_models/models.py | 62 ++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 25 deletions(-)

diff --git a/fao_models/models.py b/fao_models/models.py
index 2a77682..1af15f0 100644
--- a/fao_models/models.py
+++ b/fao_models/models.py
@@ -28,7 +28,7 @@ def plot_to_image(figure):
     return image
 
 
-def plot_confusion_matrix(cm, class_names):
+def plot_confusion_matrix(cm, class_names, count_forest, count_nonforest):
     """
     Returns a matplotlib figure containing the plotted confusion matrix.
 
@@ -38,7 +38,7 @@ def plot_confusion_matrix(cm, class_names):
     """
     figure = plt.figure(figsize=(8, 8))
     plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
-    plt.title("Confusion matrix")
+    plt.title(f"Confusion matrix: F-{count_forest} NF-{count_nonforest}")
     plt.colorbar()
     tick_marks = np.arange(len(class_names))
     plt.xticks(tick_marks, class_names, rotation=45)
@@ -74,11 +74,18 @@ def log_confusion_matrix(self, epoch):
         # Use the model to predict the values from the validation dataset.
         test_pred_raw = self.model.predict(self.test_images)
         test_pred = np.rint(test_pred_raw)
+        count_nf = sum(self.test_labels.numpy() == 0)
+        count_f = sum(self.test_labels.numpy() == 1)
 
         # Calculate the confusion matrix.
         cm = confusion_matrix(self.test_labels, test_pred)
         # Log the confusion matrix as an image summary.
-        figure = plot_confusion_matrix(cm, class_names=self.class_names)
+        figure = plot_confusion_matrix(
+            cm,
+            class_names=self.class_names,
+            count_forest=count_f,
+            count_nonforest=count_nf,
+        )
         cm_image = plot_to_image(figure)
 
         # Log the confusion matrix as an image summary.
@@ -128,6 +135,7 @@ def dice_loss(y_true, y_pred, smooth=1):
 
 evaluation_metrics = [categorical_accuracy, f1_m, precision_m, recall_m]
 
+
 def resnet(
     optimizer,
     loss_fn,
@@ -171,20 +179,23 @@ def create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1):
     return create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1)
 
 
-def mobilenet_v3small(optimizer,loss_fn,metrics=[
+def mobilenet_v3small(
+    optimizer,
+    loss_fn,
+    metrics=[
         dice_coef,
         "binary_accuracy",
     ],
 ):
-    input_shape = (32,32,4)
-    
+    input_shape = (32, 32, 4)
+
     base_model = tf.keras.applications.MobileNetV3Small(
-    input_shape=input_shape,
-    include_top=False,
-    weights=None,
-    classes=2,
-    include_preprocessing=False
-)
+        input_shape=input_shape,
+        include_top=False,
+        weights=None,
+        classes=2,
+        include_preprocessing=False,
+    )
 
     # Custom model
     inputs = layers.Input(shape=input_shape)
@@ -196,9 +207,7 @@ def mobilenet_v3small(optimizer,loss_fn,metrics=[
     # Add custom top layers
     x = layers.Flatten()(x)
     x = layers.Dense(256, activation="relu")(x)
-    outputs = layers.Dense(1, activation="sigmoid")(
-        x
-    )
+    outputs = layers.Dense(1, activation="sigmoid")(x)
 
     # Create the final model
     model = models.Model(inputs=inputs, outputs=outputs)
@@ -206,19 +215,23 @@ def mobilenet_v3small(optimizer,loss_fn,metrics=[
 
     return model
 
-def vgg16(optimizer,loss_fn,metrics=[
+
+def vgg16(
+    optimizer,
+    loss_fn,
+    metrics=[
         dice_coef,
         "binary_accuracy",
     ],
 ):
-    input_shape = (32,32,4)
+    input_shape = (32, 32, 4)
     base_model = tf.keras.applications.VGG16(
-    include_top=False,
-    weights=None,
-    input_shape=input_shape,
-    classes=2,
+        include_top=False,
+        weights=None,
+        input_shape=input_shape,
+        classes=2,
     )
-    
+
     # Custom model
     inputs = layers.Input(shape=input_shape)
 
@@ -230,9 +243,7 @@ def vgg16(optimizer,loss_fn,metrics=[
     # Add custom top layers
     x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)
-    outputs = layers.Dense(1, activation="sigmoid")(
-        x
-    )
+    outputs = layers.Dense(1, activation="sigmoid")(x)
 
     # Create the final model
     model = models.Model(inputs=inputs, outputs=outputs)
@@ -240,6 +251,7 @@ def vgg16(optimizer,loss_fn,metrics=[
 
     return model
 
+
 # Create a dictionary of keyword-function pairs
 model_dict = {
     "resnet": resnet,
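Note (not part of the patch): below is a minimal, standalone Python sketch of how the patched plot_confusion_matrix signature is exercised. The label encoding (1 = forest, 0 = non-forest), the class-name order, and the dummy labels/predictions are illustrative assumptions; the real callback takes self.test_labels and self.class_names from the training run, and the repo's plot_confusion_matrix does more styling than the hunks above show.

# Standalone sketch, assuming matplotlib, numpy, and scikit-learn are installed.
# It mirrors only the parts of plot_confusion_matrix touched by this patch.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, class_names, count_forest, count_nonforest):
    # Same title format as the patch: per-class example counts in the header.
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title(f"Confusion matrix: F-{count_forest} NF-{count_nonforest}")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    return figure


# Dummy binary labels/predictions standing in for self.test_labels / test_pred.
labels = np.array([0, 1, 1, 0, 1, 0, 1, 1])
preds = np.array([0, 1, 0, 0, 1, 1, 1, 1])

count_nf = int(np.sum(labels == 0))  # non-forest examples, as in the callback
count_f = int(np.sum(labels == 1))   # forest examples

cm = confusion_matrix(labels, preds)
fig = plot_confusion_matrix(
    cm,
    class_names=["non-forest", "forest"],  # assumed order for illustration
    count_forest=count_f,
    count_nonforest=count_nf,
)
fig.savefig("confusion_matrix.png")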