Merge pull request #77 from Achazwl/tempfix_opendelta
Temporary fix of BMTrain + OpenDelta load_state_dict
a710128 authored Apr 19, 2023
2 parents d05a519 + 672878b commit b0c4b3c
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions bmtrain/block_layer.py
@@ -529,6 +529,39 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
        elif strict:
            missing_keys.append(key)

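        # DistributedParameters registered outside the block's managed storage
        # (e.g. injected by OpenDelta after the CheckpointBlock was built) are
        # skipped by the loop above, so load them explicitly here.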
        for name, param in self.named_parameters():
            if isinstance(param, DistributedParameter) and not param._in_checkpoint_block:
                key = prefix + name
                all_keys.append(key)
                if key in state_dict:
                    input_param = state_dict[key]
                    is_param_lazy = torch.nn.parameter.is_lazy(param)
                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                        input_param = input_param[0]

                    if not is_param_lazy and not isinstance(param, DistributedParameter) and input_param.shape != param.shape:
                        # local shape should match the one in checkpoint
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'
                                          .format(key, input_param.shape, param.shape))
                        continue
                    if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape:
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'
                                          .format(key, input_param.shape, param._original_shape))
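                    # _copy_data copies this rank's slice of the full checkpoint
                    # tensor into the parameter's local (partitioned) storage.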
                    try:
                        with torch.no_grad():
                            param._copy_data(input_param)
                    except Exception as ex:
                        error_msgs.append('While copying the parameter named "{}", '
                                          'whose dimensions in the model are {} and '
                                          'whose dimensions in the checkpoint are {}, '
                                          'an exception occurred : {}.'
                                          .format(key, param.size(), input_param.size(), ex.args))
                elif strict:
                    missing_keys.append(key)

        if strict:
            all_keys = set(all_keys)
            for key in state_dict.keys():
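
For context: OpenDelta injects delta parameters into a CheckpointBlock after it has been built, so those DistributedParameters are not part of the block's managed storage (param._in_checkpoint_block is False), and the previous _load_from_state_dict silently skipped them. Below is a minimal sketch of the round-trip this patch repairs; the Feedforward module and the commented-out OpenDelta step are illustrative assumptions, not code from this PR.

    import torch
    import bmtrain as bmt

    class Feedforward(bmt.DistributedModule):
        def __init__(self, dim: int):
            super().__init__()
            self.weight = bmt.DistributedParameter(torch.empty(dim, dim))

        def forward(self, x):
            return x @ self.weight

    bmt.init_distributed(seed=0)  # launch with torchrun for multiple ranks
    block = bmt.CheckpointBlock(Feedforward(128))

    # Hypothetical OpenDelta step: it would register extra DistributedParameters
    # (e.g. LoRA matrices) on the module *after* the block was built, leaving
    # their _in_checkpoint_block flag False:
    #   from opendelta import LoraModel
    #   LoraModel(backbone_model=block, modified_modules=["weight"])

    state = block.state_dict()    # gathers the full (unpartitioned) tensors
    block.load_state_dict(state)  # with this fix, out-of-block params load too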