Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Model state should support PyTorch API (#727)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #727

Classy Models should work like regular PyTorch models. The `{get, set}_classy_state` functions for state are the only blockers which this diff fixes by moving over to `state_dict` and `load_state_dict`.

`{get, set}_classy_state` will still work for backwards compatibility, but will call the PyTorch functions directly.

Differential Revision: D25213283

fbshipit-source-id: 3cd64f530de83574174884d3b8848a1f6003f854
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Mar 24, 2021
1 parent c4d9725 commit b73871e
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions classy_vision/models/classy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def from_checkpoint(cls, checkpoint):
model.set_classy_state(checkpoint["classy_state_dict"]["base_model"])
return model

def get_classy_state(self, deep_copy=False):
def state_dict(self, deep_copy=False):
"""Get the state of the ClassyModel.
The returned state is used for checkpointing.
Expand All @@ -222,7 +222,7 @@ def get_classy_state(self, deep_copy=False):
# as the trunk state. If the model doesn't have heads attached, all of the
# model's state lives in the trunk.
self.clear_heads()
trunk_state_dict = self.state_dict()
trunk_state_dict = super().state_dict()
self.set_heads(attached_heads)

head_state_dict = {}
Expand Down Expand Up @@ -252,7 +252,7 @@ def load_head_states(self, state, strict=True):
for head_name, head_state in head_states.items():
self._heads[block_name][head_name].load_state_dict(head_state, strict)

def set_classy_state(self, state, strict=True):
def load_state_dict(self, state, strict=True):
"""Set the state of the ClassyModel.
Args:
Expand All @@ -270,11 +270,17 @@ def set_classy_state(self, state, strict=True):
# fetched / set when there are no blocks attached.
attached_heads = self.get_heads()
self.clear_heads()
self.load_state_dict(state["model"]["trunk"], strict)
super().load_state_dict(state["model"]["trunk"], strict)

# set the heads back again
self.set_heads(attached_heads)

def get_classy_state(self, deep_copy=False):
return self.state_dict(deep_copy=deep_copy)

def set_classy_state(self, state, strict=True):
self.load_state_dict(state, strict=strict)

def forward(self, x):
"""
Perform computation of blocks in the order define in get_blocks.
Expand Down

0 comments on commit b73871e

Please sign in to comment.