Skip to content

Commit

Permalink
add forest/non forest count to cm title
Browse files Browse the repository at this point in the history
  • Loading branch information
jdilger committed Apr 9, 2024
1 parent 3184ec8 commit a42ee46
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions fao_models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot_to_image(figure):
return image


def plot_confusion_matrix(cm, class_names):
def plot_confusion_matrix(cm, class_names, count_forest, count_nonforest):
"""
Returns a matplotlib figure containing the plotted confusion matrix.
Expand All @@ -38,7 +38,7 @@ def plot_confusion_matrix(cm, class_names):
"""
figure = plt.figure(figsize=(8, 8))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
plt.title("Confusion matrix")
plt.title(f"Confusion matrix: F-{count_forest} NF-{count_nonforest}")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
Expand Down Expand Up @@ -74,11 +74,18 @@ def log_confusion_matrix(self, epoch):
# Use the model to predict the values from the validation dataset.
test_pred_raw = self.model.predict(self.test_images)
test_pred = np.rint(test_pred_raw)
count_nf = sum(self.test_labels.numpy() == 0)
count_f = sum(self.test_labels.numpy() == 1)

# Calculate the confusion matrix.
cm = confusion_matrix(self.test_labels, test_pred)
# Log the confusion matrix as an image summary.
figure = plot_confusion_matrix(cm, class_names=self.class_names)
figure = plot_confusion_matrix(
cm,
class_names=self.class_names,
count_forest=count_f,
count_nonforest=count_nf,
)
cm_image = plot_to_image(figure)

# Log the confusion matrix as an image summary.
Expand Down Expand Up @@ -128,6 +135,7 @@ def dice_loss(y_true, y_pred, smooth=1):

evaluation_metrics = [categorical_accuracy, f1_m, precision_m, recall_m]


def resnet(
optimizer,
loss_fn,
Expand Down Expand Up @@ -171,20 +179,23 @@ def create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1):
return create_resnet_with_4_channels(input_shape=(32, 32, 4), num_classes=1)


def mobilenet_v3small(optimizer,loss_fn,metrics=[
def mobilenet_v3small(
optimizer,
loss_fn,
metrics=[
dice_coef,
"binary_accuracy",
],
):
input_shape = (32,32,4)
input_shape = (32, 32, 4)

base_model = tf.keras.applications.MobileNetV3Small(
input_shape=input_shape,
include_top=False,
weights=None,
classes=2,
include_preprocessing=False
)
input_shape=input_shape,
include_top=False,
weights=None,
classes=2,
include_preprocessing=False,
)
# Custom model
inputs = layers.Input(shape=input_shape)

Expand All @@ -196,29 +207,31 @@ def mobilenet_v3small(optimizer,loss_fn,metrics=[
# Add custom top layers
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
outputs = layers.Dense(1, activation="sigmoid")(
x
)
outputs = layers.Dense(1, activation="sigmoid")(x)
# Create the final model
model = models.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

return model

def vgg16(optimizer,loss_fn,metrics=[

def vgg16(
optimizer,
loss_fn,
metrics=[
dice_coef,
"binary_accuracy",
],
):
input_shape = (32,32,4)
input_shape = (32, 32, 4)
base_model = tf.keras.applications.VGG16(
include_top=False,
weights=None,
input_shape=input_shape,
classes=2,
include_top=False,
weights=None,
input_shape=input_shape,
classes=2,
)

# Custom model
inputs = layers.Input(shape=input_shape)

Expand All @@ -230,16 +243,15 @@ def vgg16(optimizer,loss_fn,metrics=[
# Add custom top layers
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
outputs = layers.Dense(1, activation="sigmoid")(
x
)
outputs = layers.Dense(1, activation="sigmoid")(x)
# Create the final model
model = models.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

return model


# Create a dictionary of keyword-function pairs
model_dict = {
"resnet": resnet,
Expand Down

0 comments on commit a42ee46

Please sign in to comment.