Changing max_history_size causes checkpoint loading to fail. #35

Closed
scottcha opened this issue Sep 16, 2024 · 1 comment

@scottcha (Contributor) commented:

The current logic that adjusts the weight dimensions when loading a checkpoint seems to assume a max_history_size of 2 (the default). Increasing it causes size-mismatch errors when calling load_checkpoint. Since there is already logic to adjust the shared weights, it should also account for a different max_history_size.

Repro steps:

from aurora import AuroraSmall
model = AuroraSmall(use_lora=False, max_history_size=4)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt", strict=False)

Expected result:
The checkpoint should load. I am also assuming that the correct logic to fill in the new weights would be to copy the existing weights for each variable; a sketch of what I mean follows.
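
A minimal sketch of the fill-in logic I am assuming (hypothetical, not the library's actual implementation): judging from the shapes in the error below ([256, 1, 2, 4, 4] in the checkpoint vs. [256, 1, 4, 4, 4] in the model), dim 2 of the patch-embedding weights holds the history steps, so the existing history slices could be repeated to match the new max_history_size.

import torch

# A checkpoint surface patch-embedding weight: dim 2 holds max_history_size=2 steps.
ckpt_weight = torch.randn(256, 1, 2, 4, 4)

new_history = 4  # the max_history_size passed to AuroraSmall above
reps = new_history // ckpt_weight.shape[2]

# Copy the existing weights to fill the new history positions.
expanded = ckpt_weight.repeat(1, 1, reps, 1, 1)
# Rescaling (e.g. dividing by reps) might additionally be needed so that summing
# over the extra history steps does not change the embedding's output magnitude.

assert expanded.shape == (256, 1, 4, 4, 4)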

Actual result:
The checkpoint fails to load with the following error:

{
    "name": "RuntimeError",
    "message": "Error(s) in loading state_dict for Aurora:
        size mismatch for encoder.surf_token_embeds.weights.10u: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.10v: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.2t: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.lsm: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.msl: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.slt: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.z: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.q: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.t: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.u: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.v: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.z: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).",
    "stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[319], line 1
----> 1 model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt", strict=False)

File c:\Users\scott\miniconda3\envs\aurora\lib\site-packages\aurora\model\aurora.py:231, in Aurora.load_checkpoint(self, repo, name, strict)
    221 \"\"\"Load a checkpoint from HuggingFace.
    222 
    223 Args:
   (...)
    228         parameters in the checkpoint. Defaults to `True`.
    229 \"\"\"
    230 path = hf_hub_download(repo_id=repo, filename=name)
--> 231 self.load_checkpoint_local(path, strict=strict)

File c:\Users\scott\miniconda3\envs\aurora\lib\site-packages\aurora\model\aurora.py:304, in Aurora.load_checkpoint_local(self, path, strict)
    301         d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
    302         d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]
--> 304 self.load_state_dict(d, strict=strict)

File c:\Users\scott\miniconda3\envs\aurora\lib\site-packages\torch\nn\modules\module.py:2215, in Module.load_state_dict(self, state_dict, strict, assign)
   2210         error_msgs.insert(
   2211             0, 'Missing key(s) in state_dict: {}. '.format(
   2212                 ', '.join(f'\"{k}\"' for k in missing_keys)))
   2214 if len(error_msgs) > 0:
-> 2215     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2216                        self.__class__.__name__, \"\n\t\".join(error_msgs)))
   2217 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Aurora:
        size mismatch for encoder.surf_token_embeds.weights.10u: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.10v: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.2t: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.lsm: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.msl: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.slt: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.surf_token_embeds.weights.z: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.q: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.t: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.u: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.v: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4]).
        size mismatch for encoder.atmos_token_embeds.weights.z: copying a param with shape torch.Size([256, 1, 2, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 1, 4, 4, 4])."
}
@wesselb (Contributor) commented Oct 11, 2024:

Closing this as #36 has been merged.

wesselb closed this as completed Oct 11, 2024