Skip to content

Commit

Permalink
Backend TensorFlow: Fix model saving (#1782)
Browse files Browse the repository at this point in the history
  • Loading branch information
agniv-the-marker authored Jun 20, 2024
1 parent 737c2f8 commit 51581b2
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,6 @@ def save(self, save_path, protocol="backend", verbose=0):
Returns:
string: Path where model is saved.
"""
# TODO: backend tensorflow
save_path = f"{save_path}-{self.train_state.epoch}"
if protocol == "pickle":
save_path += ".pkl"
Expand All @@ -1042,7 +1041,7 @@ def save(self, save_path, protocol="backend", verbose=0):
save_path += ".ckpt"
self.saver.save(self.sess, save_path)
elif backend_name == "tensorflow":
save_path += ".ckpt"
save_path += ".weights.h5"
self.net.save_weights(save_path)
elif backend_name == "pytorch":
save_path += ".pt"
Expand Down

0 comments on commit 51581b2

Please sign in to comment.