Skip to content

Commit

Permalink
Use state_dict for saving, update LICENSE, vgg-download link
Browse files Browse the repository at this point in the history
  • Loading branch information
abhiskk committed Apr 3, 2017
1 parent bf4aecf commit 2c2443c
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
The MIT License (MIT)

Copyright (c) 2016 Abhishek Kadian
Copyright (c) 2017 Abhishek Kadian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
5 changes: 3 additions & 2 deletions neural_style/neural_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def train(args):
save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
args.content_weight) + "_" + str(args.style_weight) + ".model"
save_model_path = os.path.join(args.save_model_dir, save_model_filename)
torch.save(transformer, save_model_path)
torch.save(transformer.state_dict(), save_model_path)

print("\nDone, trained model saved at", save_model_path)

Expand All @@ -133,7 +133,8 @@ def stylize(args):
if args.cuda:
content_image = content_image.cuda()
content_image = Variable(utils.preprocess_batch(content_image))
style_model = torch.load(args.model)
style_model = TransformerNet()
style_model.load_state_dict(torch.load(args.model))

if args.cuda:
style_model.cuda()
Expand Down
4 changes: 1 addition & 3 deletions neural_style/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def gram_matrix(y):


def subtract_imagenet_mean_batch(batch):
"""Subtract ImageNet mean pixel-wise from a BGR image."""
tensortype = type(batch.data)
mean = tensortype(batch.data.size())
mean[:, 0, :, :] = 103.939
Expand All @@ -63,11 +62,10 @@ def preprocess_batch(batch):


def init_vgg16(model_folder):
"""load the vgg16 model feature"""
if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')):
if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')):
os.system(
'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7'))
'wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_folder, 'vgg16.t7'))
vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7'))
vgg = Vgg16()
for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
Expand Down
Binary file renamed saved-models/candy.model → saved-models/candy.pth
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file renamed saved-models/udnie.model → saved-models/udnie.pth
Binary file not shown.

0 comments on commit 2c2443c

Please sign in to comment.