Skip to content

Commit

Permalink
Utility improvements (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Aug 12, 2024
1 parent 3c137c3 commit fce3f56
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions open_instruct/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from rich.text import Text
from torch.nn.parallel.distributed import DistributedDataParallel
from transformers import PreTrainedModel, PreTrainedTokenizer
from huggingface_hub import HfApi


@dataclass
Expand Down Expand Up @@ -306,9 +307,9 @@ def generate(
# https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
output_logits=True,
)
logits = torch.stack(output.scores, 1)
logits = torch.stack(output.logits, 1)
return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits


Expand Down Expand Up @@ -344,6 +345,7 @@ def save_with_accelerate(
push_to_hub: bool = False,
hf_repo_id: Optional[str] = None,
hf_repo_revision: Optional[str] = None,
private: bool = True,
) -> None:
unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model)
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
Expand All @@ -366,12 +368,15 @@ def save_with_accelerate(
safe_serialization=False,
)
if accelerator.is_main_process and push_to_hub:
unwrapped_model.push_to_hub(
hf_repo_url = f"https://huggingface.co/{hf_repo_id}/tree/{hf_repo_revision}"
api = HfApi()
api.create_repo(hf_repo_id, exist_ok=True, private=private)
api.upload_folder(
repo_id=hf_repo_id,
revision=hf_repo_revision,
safe_serialization=False,
folder_path=output_dir,
commit_message="upload checkpoint",
run_as_future=False,
)
hf_repo_url = f"https://huggingface.co/{hf_repo_id}/tree/{hf_repo_revision}"
print(f"🔥 pushed to {hf_repo_url}")

if accelerator.is_main_process:
Expand Down Expand Up @@ -441,9 +446,7 @@ def unwrap_model_for_generation(
yield unwrapped_model


def prepare_deepspeed(
model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False
):
def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int, mixed_precision: str):
"""
Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and
batch size.
Expand All @@ -452,6 +455,8 @@ def prepare_deepspeed(
The model to be prepared for DeepSpeed training.
per_device_train_batch_size (`int`):
The training batch size per device.
mixed_precision (`str`):
The mixed precision setting to use.
Returns:
`torch.nn.Module`:
The model initialized and configured with DeepSpeed for training.
Expand All @@ -467,10 +472,8 @@ def prepare_deepspeed(
"prescale_gradients": False,
"wall_clock_breakdown": False,
}
if bf16:
config_kwargs["bf16"] = {"enabled": True}
elif fp16:
config_kwargs["fp16"] = {"enabled": True}
if mixed_precision in ["bf16", "fp16"]:
config_kwargs[mixed_precision] = {"enabled": True}
else:
if hasattr(model, "config"):
hidden_size = (
Expand Down

0 comments on commit fce3f56

Please sign in to comment.