Adds support for converting from safetensors #740

Open · wants to merge 11 commits into main
Changes from 2 commits
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added support for safetensors in `hf_olmo` conversion script.

## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17

### Added
@@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Swapped in correct flan data mix.
- Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
- Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout


## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
23 changes: 16 additions & 7 deletions hf_olmo/convert_olmo_to_hf.py
@@ -9,15 +9,16 @@
from urllib.parse import urlparse

import torch
from olmo import ModelConfig, Tokenizer, TrainConfig
from olmo.checkpoint import build_sharded_checkpointer
from olmo.util import _get_s3_client
from omegaconf import OmegaConf as om
from safetensors.torch import load_file
from tqdm import tqdm

from hf_olmo.configuration_olmo import OLMoConfig
from hf_olmo.modeling_olmo import OLMoForCausalLM
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
from olmo import ModelConfig, Tokenizer, TrainConfig
from olmo.checkpoint import build_sharded_checkpointer
from olmo.util import _get_s3_client

logger = logging.getLogger(__name__)

@@ -67,10 +68,16 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
# For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
# So, we explicitly store the model with the expected prefix.

old_model_path = os.path.join(checkpoint_dir, "model.pt")
new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
old_model_path = os.path.join(checkpoint_dir, "model.pt")
state_dict = torch.load(old_model_path, map_location="cpu")
elif os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
old_model_path = os.path.join(checkpoint_dir, "model.safetensors")
state_dict = load_file(old_model_path, device="cpu")
Collaborator commented:
Have you tested this? If you are getting the safetensors file using `unshard.py`, then you should probably use `safetensors_file_to_state_dict` instead. See https://github.com/allenai/OLMo/blob/main/docs/Safetensors.md for context.
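For context, here is a minimal sketch of the loading path the comment suggests. It is not the PR's code: the import location of `safetensors_file_to_state_dict` (assumed here to be `olmo.safetensors_util`), its exact signature, and the helper name `load_checkpoint_state_dict` are assumptions to be checked against docs/Safetensors.md.

import os

import torch

from olmo.safetensors_util import safetensors_file_to_state_dict  # assumed import path


def load_checkpoint_state_dict(checkpoint_dir: str) -> dict:
    pt_path = os.path.join(checkpoint_dir, "model.pt")
    st_path = os.path.join(checkpoint_dir, "model.safetensors")
    if os.path.exists(pt_path):
        return torch.load(pt_path, map_location="cpu")
    if os.path.exists(st_path):
        # A plain safetensors file could be read with safetensors.torch.load_file,
        # but files written by unshard.py use OLMo's own safetensors layout, so
        # the project helper is the safer choice for those checkpoints.
        return safetensors_file_to_state_dict(st_path)
    raise ValueError(f"No model found in {checkpoint_dir}")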

else:
raise ValueError(f"No model found in {checkpoint_dir}")

state_dict = torch.load(old_model_path, map_location="cpu")
new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")

# this takes care of the case where the model was saved with a different prefix,
# typically due to unsharding.
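The prefix handling mentioned in the comment above amounts to renaming state-dict keys before saving. A purely illustrative sketch follows; the default prefix value and the helper name are assumptions, not the script's actual values.

def add_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    # Prepend the expected prefix to any key that does not already carry it,
    # e.g. "transformer.wte.weight" -> "model.transformer.wte.weight".
    return {
        (key if key.startswith(prefix) else f"{prefix}{key}"): value
        for key, value in state_dict.items()
    }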
@@ -233,7 +240,9 @@ def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str):


def maybe_unshard(checkpoint_dir: str):
if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
if os.path.exists(os.path.join(checkpoint_dir, "model.pt")) or os.path.exists(
os.path.join(checkpoint_dir, "model.safetensors")
):
return

print(f"Unsharding {checkpoint_dir}...")