Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Mar 18, 2024
1 parent 169b7b8 commit 28719d8
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions olmo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,17 @@ def load_state_dict(
:raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
"""
path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)

if path.suffix == ".pt":
safetensors_path = path.with_suffix(".safetensors")
if safetensors_path.is_file():
return safetensors_file_to_state_dict(safetensors_path, map_location=map_location)
if fname.endswith(".pt"):
# Try safetensors version first.
try:
path = resource_path(
str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
)
return safetensors_file_to_state_dict(path, map_location=map_location)
except FileNotFoundError:
pass

path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
return torch.load(path, map_location=map_location)


Expand Down Expand Up @@ -1270,12 +1274,6 @@ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
return [fsdp_model._handle] # type: ignore
else:
return []
# elif version.parse(torch.__version__) < version.parse("2.3.0"):
# # Could be None if the FSDP wrapper doesn't manage any parameters.
# if hasattr(fsdp_model, "_all_handles") and fsdp_model._all_handles is not None:
# return fsdp_model._all_handles
# else:
# return []
else:
# Need to verify FSDP internals with newer versions.
raise NotImplementedError
Expand Down

0 comments on commit 28719d8

Please sign in to comment.