Catch none-valued rope scaling configs
j-frei committed Nov 2, 2023
1 parent 39866af · commit ad9e683
Showing 3 changed files with 15 additions and 8 deletions.
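The change addresses a getattr pitfall: the default value only applies when the attribute is missing entirely, not when it is present but set to None, which is how Hugging Face Llama configs typically represent "no RoPE scaling". A minimal sketch of the failure mode and the fix, using a stand-in namespace object rather than the real transformers config class:

# Sketch of the pitfall this commit fixes (stand-in object, not transformers.LlamaConfig).
from types import SimpleNamespace

config = SimpleNamespace(rope_scaling=None)  # attribute exists, but its value is None

old_style = getattr(config, "rope_scaling", {"factor": 1})
print(old_style)  # None -> the later `"factor" in old_style.keys()` check would raise AttributeError

new_style = getattr(config, "rope_scaling", None)
if new_style is None:        # catches both a missing attribute and an explicit None
    new_style = {"factor": 1}
print(new_style["factor"])   # 1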
fine-tune.py — 9 changes: 6 additions & 3 deletions
@@ -107,7 +107,7 @@ def train():

    # NOTE: May expand supported model types in the future
    if model_args.model_type == "gpt-neox":
-       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
+       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
    else:
        assert model_args.model_type == "llama", "Only support llama and gpt-neox for now"
        replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
@@ -118,7 +118,10 @@ def train():
        cache_dir=training_args.cache_dir,
    )

-   orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
+   orig_rope_scaling = getattr(config, "rope_scaling", None)
+   if orig_rope_scaling is None:
+       orig_rope_scaling = {"factor": 1}
+
    orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len:
@@ -195,7 +198,7 @@ def train():
    model.enable_input_require_grads()     # required for gradient checkpointing
    model.gradient_checkpointing_enable()  # enable gradient checkpointing
    trainer = Trainer(
-       model=model, tokenizer=tokenizer, args=training_args,
+       model=model, tokenizer=tokenizer, args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=None,
        data_collator=data_collator)
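For orientation, here is a hedged sketch of how a factor recovered this way typically feeds into linear RoPE scaling when the training length exceeds the pretrained context window. The hunk above is truncated at the line "if orig_ctx_len:", so the model_max_length argument, the ceil-based rounding, and the config.rope_scaling assignment below are assumptions about the surrounding code, not the repository's verbatim continuation:

import math

def maybe_extend_rope(config, model_max_length):
    # Illustrative only; mirrors the None-safe parsing from the diff, then
    # applies an assumed linear-scaling rule when more context is needed.
    orig_rope_scaling = getattr(config, "rope_scaling", None)
    if orig_rope_scaling is None:
        orig_rope_scaling = {"factor": 1}
    orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)

    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len:
        # Account for any scaling already recorded in the checkpoint's config.
        orig_ctx_len *= orig_rope_scaling_factor
        if model_max_length > orig_ctx_len:
            scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
            config.rope_scaling = {"type": "linear", "factor": scaling_factor}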
supervised-fine-tune-qlora.py — 6 changes: 4 additions & 2 deletions
@@ -237,7 +237,7 @@ def train():

    # NOTE: May expand supported model types in the future
    if model_args.model_type == "gpt-neox":
-       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
+       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
    else:
        replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)

@@ -247,7 +247,9 @@ def train():
        cache_dir=training_args.cache_dir,
    )

-   orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
+   orig_rope_scaling = getattr(config, "rope_scaling", None)
+   if orig_rope_scaling is None:
+       orig_rope_scaling = {"factor": 1}
    orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len:
supervised-fine-tune.py — 8 changes: 5 additions & 3 deletions
@@ -188,7 +188,7 @@ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]

        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
@@ -236,7 +236,7 @@ def train():

    # NOTE: May expand supported model types in the future
    if model_args.model_type == "gpt-neox":
-       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
+       replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
    else:
        replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)

@@ -246,7 +246,9 @@ def train():
        cache_dir=training_args.cache_dir,
    )

-   orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
+   orig_rope_scaling = getattr(config, "rope_scaling", None)
+   if orig_rope_scaling is None:
+       orig_rope_scaling = {"factor": 1}
    orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len:
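The first hunk of supervised-fine-tune.py also shows, as unchanged context, how supervised targets are assembled: each target is the example's output text with the tokenizer's EOS token appended. A tiny illustration, assuming a Llama-style "</s>" EOS token (the real value comes from tokenizer.eos_token in the script):

# Illustration of the target construction visible in the context lines above.
list_data_dict = [
    {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"},
    {"instruction": "Name a color.", "input": "", "output": "Blue"},
]
eos_token = "</s>"  # assumption: Llama-style EOS token

targets = [f"{example['output']}{eos_token}" for example in list_data_dict]
print(targets)  # ['5</s>', 'Blue</s>']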
