Skip to content

Commit

Permalink
Update to Keras 2
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed May 23, 2017
1 parent fa600c6 commit c37cde2
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 34 deletions.
67 changes: 67 additions & 0 deletions .idea/markdown-navigator.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions .idea/markdown-navigator/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions cifar10_wrn_16_8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import keras.callbacks as callbacks
import keras.utils.np_utils as kutils
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.visualize_util import plot
from keras.utils import plot_model

from keras import backend as K

Expand Down Expand Up @@ -39,7 +39,7 @@
model = wrn.create_wide_residual_network(init_shape, nb_classes=10, N=2, k=8, dropout=0.00)

model.summary()
#plot(model, "WRN-16-8.png", show_shapes=False)
#plot_model(model, "WRN-16-8.png", show_shapes=False)

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["acc"])
print("Finished compiling")
Expand All @@ -48,10 +48,10 @@
print("Model loaded.")
print("Allocating GPU memory")

#model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size), samples_per_epoch=len(trainX), nb_epoch=nb_epoch,
# model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size), steps_per_epoch=len(trainX) // batch_size + 1, nb_epoch=nb_epoch,
# callbacks=[callbacks.ModelCheckpoint("WRN-16-8 Weights.h5", monitor="val_acc", save_best_only=True)],
# validation_data=(testX, testY),
# nb_val_samples=testX.shape[0],)
# validation_steps=testX.shape[0] // batch_size + 1,)

yPreds = model.predict(testX)
yPred = np.argmax(yPreds, axis=1)
Expand Down
14 changes: 7 additions & 7 deletions cifar10_wrn_28_8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import keras.callbacks as callbacks
import keras.utils.np_utils as kutils
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.visualize_util import plot
from keras.utils import plot_model

from keras import backend as K

Expand Down Expand Up @@ -47,7 +47,7 @@
model = wrn.create_wide_residual_network(init_shape, nb_classes=10, N=4, k=8, dropout=0.0)

model.summary()
#plot(model, "WRN-28-8.png", show_shapes=False)
#plot_model(model, "WRN-28-8.png", show_shapes=False)

model.compile(loss="categorical_crossentropy", optimizer="adadelta", metrics=["acc"])
print("Finished compiling")
Expand All @@ -56,10 +56,10 @@
model.load_weights("weights/WRN-28-8 Weights.h5")
print("Model loaded.")

#model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size), samples_per_epoch=len(trainX), nb_epoch=nb_epoch,
# callbacks=[callbacks.ModelCheckpoint("WRN-28-8 Weights.h5", monitor="val_acc", save_best_only=True)],
# validation_data=test_generator.flow(testX, testY, batch_size=batch_size),
# nb_val_samples=testX.shape[0],)
model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size), steps_per_epoch=len(trainX) // batch_size + 1, nb_epoch=nb_epoch,
callbacks=[callbacks.ModelCheckpoint("WRN-28-8 Weights.h5", monitor="val_acc", save_best_only=True)],
validation_data=test_generator.flow(testX, testY, batch_size=batch_size),
validation_steps=testX.shape[0] // batch_size + 1,)

scores = model.evaluate_generator(test_generator.flow(testX, testY, nb_epoch), testX.shape[0])
scores = model.evaluate_generator(test_generator.flow(testX, testY, nb_epoch), testX.shape[0] // batch_size + 1)
print("Accuracy = %f" % (100 * scores[1]))
3 changes: 3 additions & 0 deletions weights/ADD WEIGHT FILES HERE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Please download the approproate Weight files and extract them here.

Weights for WRN-16-8 and WRN-28-8 are available [in the release tab](https://github.com/titu1994/Wide-Residual-Networks/releases)
Loading

0 comments on commit c37cde2

Please sign in to comment.