
Commit

Patch generation config when invalid
chiragjn committed Jan 18, 2024
1 parent a314310 commit 95bf6bf
Showing 1 changed file with 45 additions and 6 deletions.
train.py (45 additions, 6 deletions)
@@ -7,6 +7,7 @@
 import sys
 import tempfile
 import time
+import warnings
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional
@@ -29,6 +30,7 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     EarlyStoppingCallback,
+    GenerationConfig,
     HfArgumentParser,
     IntervalStrategy,
     Trainer,
@@ -200,7 +202,15 @@ class OtherArguments:
     )
     pad_to_max_length: bool = field(
         default=False,
-        metadata={"help": "If to always pad to the given/found max sequence length (default: False)"},
+        metadata={
+            "help": "If to always pad to the given/found max sequence length. Mostly meant for testing (default: False)"
+        },
+    )
+    truncate_to_max_length: bool = field(
+        default=False,
+        metadata={
+            "help": "If to truncate to max sequence length. When set to False, error is raised for any sequence longer than max sequence length (default: False)"
+        },
     )
     max_num_samples: _NullableInt = field(
         default=None,
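As background (not something this commit adds): `HfArgumentParser` turns boolean dataclass fields like the new `truncate_to_max_length` into optional CLI flags. A minimal sketch with a hypothetical stand-in dataclass, not the real OtherArguments class:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class _Args:  # hypothetical stand-in for OtherArguments
    truncate_to_max_length: bool = field(default=False, metadata={"help": "If to truncate to max sequence length"})


# Bool fields with default False become optional flags; both `--truncate_to_max_length`
# and `--truncate_to_max_length True` enable them.
(args,) = HfArgumentParser(_Args).parse_args_into_dataclasses(["--truncate_to_max_length", "True"])
print(args.truncate_to_max_length)  # True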
Expand Down Expand Up @@ -304,6 +314,26 @@ def _cleanup_gpus():
logger.info(get_gpu_metrics())


def _fix_generation_config(model, tokenizer):
if hasattr(model, "generation_config"):
with warnings.catch_warnings(record=True) as caught_warnings:
model.generation_config.validate()
if len(caught_warnings) > 0:
messages = [str(w) for w in caught_warnings]
logger.info(
f"`generation_config` is invalid: {messages}. "
f"Because transformers refuses to save invalid config since 4.37 we are resetting it to default values. "
f"This may cause loss of some generation settings. It is recommended to set the correct generation config "
f"during inference."
)
model.generation_config = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
return model


def merge_adapters_if_any(
model_id: str,
revision: Optional[str],
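The `_fix_generation_config` helper added above resets the config only when `GenerationConfig.validate()` emits warnings. A minimal standalone sketch of that mechanism, not part of the commit, with deliberately inconsistent sampling values chosen for illustration:

import warnings

from transformers import GenerationConfig

# Deliberately inconsistent: sampling knobs are set while do_sample=False.
config = GenerationConfig(do_sample=False, temperature=0.7, top_p=0.9)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeated ones
    config.validate()

# A non-empty list here is the same signal _fix_generation_config keys off before
# swapping in a default GenerationConfig built from the tokenizer's special token ids.
print([str(w.message) for w in caught])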
@@ -314,12 +344,14 @@ def merge_adapters_if_any(
     logger.info("Loading model and lora layers for merging ...")
     model = AutoPeftModelForCausalLM.from_pretrained(
         output_dir,
-        revision=revision,
+        revision=None,
         trust_remote_code=True,
         low_cpu_mem_usage=True,
         torch_dtype=torch_dtype,
         device_map="balanced",
     )
+    tokenizer = get_tokenizer(model_source=output_dir, revision=None)
+    model = _fix_generation_config(model=model, tokenizer=tokenizer)
     logger.info("Merging lora adapter into main model. This can take a while ...")
     model = model.merge_and_unload()
     model.save_pretrained(output_dir, safe_serialization=True)
@@ -465,6 +497,7 @@ def find_all_linear_names(model, other_arguments: OtherArguments, exclude_lm_hea
 
 def get_model(
     model_source: str,
+    revision: Optional[str],
     model_config,
     training_arguments: HFTrainingArguments,
     other_arguments: OtherArguments,
@@ -500,7 +533,7 @@
     )
     model = AutoModelForCausalLM.from_pretrained(
         model_source,
-        revision=other_arguments.revision,
+        revision=revision,
         trust_remote_code=True,
         torch_dtype=torch_dtype,
         quantization_config=bnb_config,
@@ -521,7 +554,7 @@
     model_load_kwargs.pop("low_cpu_mem_usage", None)
     model = AutoModelForCausalLM.from_pretrained(
         model_source,
-        revision=other_arguments.revision,
+        revision=revision,
         trust_remote_code=True,
         torch_dtype=get_torch_dtype(training_arguments),
         device_map=device_map,
@@ -742,6 +775,7 @@ def dist_build_dataset(
     tokenizer,
     max_length: int,
     pad_to_max_length: bool,
+    truncate_to_max_length: bool,
     train_on_prompt: bool,
     training_arguments: HFTrainingArguments,
 ):
@@ -759,7 +793,7 @@
     )
     dataset_dict.save_to_disk(dataset_cache_path)
     logger.info(f"Dataset max sequence lengths: {dataset_info}")
-    if any(
+    if not truncate_to_max_length and any(
         length > max_length
         for length in (
             dataset_info.train_max_prompt_length,
@@ -836,6 +870,7 @@ def _train(
         tokenizer=tokenizer,
         max_length=max_length,
         pad_to_max_length=other_arguments.pad_to_max_length,
+        truncate_to_max_length=other_arguments.truncate_to_max_length,
         train_on_prompt=other_arguments.train_on_prompt,
         training_arguments=training_arguments,
     )
Expand Down Expand Up @@ -863,9 +898,11 @@ def _train(
# However the layer updating code is broken that it does not move the layers to correct device.
# So you'll get base layers on gpu and lora layers on cpu crashing the code.
# There is a massive refactor in peft 0.7.0 which has mostly solved this but will need some time to migrate correctly
# So for now, we always load the base model from pretrained version, resize embeddings is tokenizer from checkpoint has more tokens, and re-apply the peft config from scratch
# So for now, we always load the base model from pretrained version, resize embeddings is tokenizer from checkpoint has more tokens,
# and re-apply the peft config from scratch
model = get_model(
model_source=other_arguments.model_id, # This is not a bug
revision=other_arguments.revision,
model_config=model_config,
device_map=device_map,
training_arguments=training_arguments,
@@ -889,6 +926,8 @@ def _train(
         other_arguments=other_arguments,
         modules_to_save=lora_modules_to_save,
     )
+
+    model = _fix_generation_config(model=model, tokenizer=tokenizer)
     logger.info("Training...")
     # TODO (chiragjn): Add text generation metrics to `compute_metrics
     callbacks = [ExtraMetricsCallback(), TensorBoardCallback()]
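The log message in `_fix_generation_config` recommends setting the correct generation config at inference time rather than relying on whatever was saved with the checkpoint. A hedged sketch of what that could look like for a checkpoint produced by this script; the output path, prompt, and sampling values below are assumptions, not part of the commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_dir = "./output"  # hypothetical path to the merged checkpoint saved by train.py
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

# Supply an explicit, valid GenerationConfig instead of the defaults the trainer fell back to.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    max_new_tokens=256,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))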
