Skip to content

Commit

Permalink
fix non-persistent buffers
Browse files · Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jan 3, 2025
1 parent 3591fe9 commit 644ce53
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions timm/models/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,19 @@ def build_model_with_cfg(
else:
model = model_cls(cfg=model_cfg, **kwargs)
if pretrained:
# .to_empty() will also move cpu params/buffers to uninitialized storage.
# this is problematic for non-persistent buffers, since they don't get loaded
# from pretrained weights later (not part of state_dict). hence, we have
# to save them before calling .to_empty() and fill them back after.
buffers = {k: v for k, v in model.named_buffers() if not v.is_meta}
model.to_empty(device="cpu")
for k, v in model.named_buffers():
if k in buffers:
v.data = buffers[k]

# alternative, rely on internal method ._apply()
# model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)

model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat

Expand Down

0 comments on commit 644ce53

Please sign in to comment.