diff --git a/examples/demo_mnist_convnet.py b/examples/demo_mnist_convnet.py index f5f4e3f4d..d1d45a2eb 100644 --- a/examples/demo_mnist_convnet.py +++ b/examples/demo_mnist_convnet.py @@ -35,6 +35,8 @@ layers.MaxPooling2D(pool_size=(2, 2)), layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(128, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), layers.Flatten(), layers.Dropout(0.5), layers.Dense(num_classes, activation="softmax"), @@ -44,7 +46,7 @@ model.summary() model.compile( - loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] + loss="categorical_crossentropy", optimizer=keras_core.optimizers.SGD(learning_rate=0.001), metrics=["accuracy"] ) model.fit(