fix(abstractions): support finetuning for auto templates
TianyiQ committed Dec 21, 2024
1 parent 4de1544 commit b3d4a09
Showing 2 changed files with 15 additions and 21 deletions.
13 changes: 4 additions & 9 deletions src/abstractions/configs/templates_configs.py
@@ -42,16 +42,14 @@ def __exit__(self, type, value, traceback):
def register_destroyer(destroyer: Callable[[], None]):
GlobalState.__active_backend_destroyers.append(destroyer)


bash_command_template = f"""PYTHONNOUSERSITE=1 MASTER_PORT=9902 conda run --no-capture-output -n %s deepspeed %s --master_port=9902 {root}/libs/llama_factory/src/train_bash.py \\
--deepspeed %s \\
--ddp_timeout 180000000 \\
--stage %s \\
--do_%s \\%s
--model_name_or_path %s \\
--dataset %s \\
- --dataset_dir {root}/libs/llama_factory/data \\
- --template %s \\
+ --dataset_dir {root}/libs/llama_factory/data \\%s
--finetuning_type %s \\
--lora_target q_proj,v_proj \\
--output_dir %s \\
@@ -89,8 +87,7 @@ def register_destroyer(destroyer: Callable[[], None]):
--finetuning_type %s \\
--lora_target q_proj,v_proj \\
--dataset %s \\
- --dataset_dir {root}/libs/llama_factory/data \\
- --template %s \\
+ --dataset_dir {root}/libs/llama_factory/data \\%s
--cutoff_len 1024 \\
--overwrite_cache \\
--preprocessing_num_workers 16 \\
@@ -122,8 +119,7 @@ def register_destroyer(destroyer: Callable[[], None]):
--model_name_or_path %s \\
--finetuning_type full \\
--dataset %s \\
- --dataset_dir {root}/libs/llama_factory/data \\
- --template %s \\
+ --dataset_dir {root}/libs/llama_factory/data \\%s
--lora_target q_proj,v_proj \\
--output_dir %s \\
--overwrite_cache \\
@@ -147,8 +143,7 @@ def register_destroyer(destroyer: Callable[[], None]):

bash_command_for_lora_merging = f"""PYTHONNOUSERSITE=1 conda run --no-capture-output -n %s python {root}/libs/llama_factory/src/export_model.py \\
--model_name_or_path %s \\
- --adapter_name_or_path %s \\
- --template %s \\
+ --adapter_name_or_path %s \\%s
--finetuning_type lora \\
--export_dir %s \\
--export_size 2 \\
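The mechanism behind the templates_configs.py change: the fixed "--template %s" line is dropped from each command template, and a "%s" slot is appended to the preceding line, so a --template flag is injected only when a concrete template type is known. A minimal standalone sketch of how that slot gets filled (illustrative code, not from the repository; the command stub is shortened):

template_type = "alpaca"  # e.g. taken from Model.template_type

# Same conditional used at the call sites in model.py: emit a --template
# line only when the template type is not "auto".
template_arg = (
    f"\n    --template {template_type} \\"
    if template_type != "auto"
    else ""
)

command_stub = "--dataset %s \\%s\n--finetuning_type %s \\"
print(command_stub % ("my_dataset", template_arg, "lora"))
# --dataset my_dataset \
#     --template alpaca \
# --finetuning_type lora \

With template_type == "auto", template_arg is empty and the --template flag is omitted from the command entirely, which is what makes finetuning work for auto templates.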
23 changes: 11 additions & 12 deletions src/abstractions/model.py
@@ -124,7 +124,7 @@ def __init__(
is_instruct_finetuned: bool = True,
model_path_or_repoid: Optional[str] = None,
num_gpus: int = None,
- template_type: Literal["auto", "alpaca", "mistral", "llama3"] = "auto",
+ template_type: Literal["auto", "alpaca", "mistral", "llama3"] = None,
):
"""
Initialize.
@@ -141,7 +141,7 @@ def __init__(
:param num_gpus: Number of GPUs to use for parallel finetuning/inference. Default to the total number of gpus on the machine.
:type num_gpus: Optional[int] = None
- :param template_type: The type of template to use, which can be "auto", "alpaca", "mistral", or "llama3". If "auto", the template type is inferred from the model's config file.
+ :param template_type: The type of template to use, which can be "auto", "alpaca", "mistral", or "llama3". If "auto", the template type is inferred from the model's config file. Set the environment variable DEFAULT_TEMPLATE to specify the default template type if a value other than "auto" is desired.
:type template_type: Literal["auto", "alpaca", "mistral", "llama3"] = "auto"
Examples:
@@ -151,6 +151,10 @@ def __init__(
Model(model_name = 'Gemma-2B_sft', is_instruct_finetuned = True)
"""
+ if os.environ.get("DEFAULT_TEMPLATE") and not template_type:
+     template_type = os.environ["DEFAULT_TEMPLATE"].lower()
+     assert template_type in ["auto", "alpaca", "mistral", "llama3"]
+
if not num_gpus:
num_gpus = torch.cuda.device_count()
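
Because the default template_type is now None rather than "auto", the new DEFAULT_TEMPLATE environment variable supplies a template type whenever the caller omits one. A hedged usage sketch, reusing the Gemma-2B_sft example from the docstring above (the import path is assumed):

import os

from src.abstractions.model import Model  # assumed import path

# Opt into a process-wide default; "llama3" here is illustrative.
os.environ["DEFAULT_TEMPLATE"] = "llama3"

# template_type is left unset (None), so __init__ reads DEFAULT_TEMPLATE,
# lowercases it, and asserts it is one of the allowed template names.
model = Model(model_name="Gemma-2B_sft", is_instruct_finetuned=True)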

@@ -357,11 +361,6 @@ def finetune(
:return: Returns a Model instance with name {result_model_name}, which is the result of the finetuning.
:rtype: Model.
"""
- if self.template_type == "auto":
-     raise ValueError(
-         "Finetuning is not supported for models with auto template type."
-     )
-
if stage == "pretrain":
assert (
data.data_type == "pretrain"
@@ -454,7 +453,7 @@ def finetune(
"", # do sample; ignored here
self.model_path, # where to find the original model
data.data_name, # dataset (automatically registered in llama-factory)
- self.template_type, # template type
+ (f"\n --template {self.template_type} \\" if self.template_type != "auto" else ""), # template type
("lora" if algo == "lora" else "full"), # type - full_param or lora
f"{root}/output/training_results/{escape(result_model_name)}/", # where to save the training results (and checkpoints etc.)
2
@@ -529,7 +528,7 @@ def finetune(
"pa38-lf",
self.model_path,
result.model_path,
- self.template_type,
+ (f"\n --template {self.template_type} \\" if self.template_type != "auto" else ""), # template type
merged_model_path,
)
print(cmd)
@@ -595,7 +594,7 @@ def __rlhf(
f"{root}/src/abstractions/configs/LF_examples/full_multi_gpu/ds_z3_config.json",
rw_path,
rw_data.data_name,
- self.template_type,
+ (f"\n --template {self.template_type} \\" if self.template_type != "auto" else ""), # template type
rw_results,
2 ** max(0, 3 + batch_size_multiplier_log2), # per_device_train_batch_size
2 ** max(0, 4 + batch_size_multiplier_log2), # per_device_eval_batch_size
@@ -644,7 +643,7 @@ def __rlhf(
rw_results,
"lora" if use_lora else "full",
ppo_data.data_name,
- self.template_type,
+ (f"\n --template {self.template_type} \\" if self.template_type != "auto" else ""), # template type
the_path,
2 ** max(0, 1 + batch_size_multiplier_log2), # per_device_train_batch_size
2 ** max(0, 2 + batch_size_multiplier_log2), # per_device_eval_batch_size
@@ -905,7 +904,7 @@ def __inference_parallel_deepspeed(
"\n--do_sample \\", # do sample
self.model_path, # where to save the resulting model
data.data_name, # dataset (automatically registered in llama-factory)
- self.template_type, # template type
+ (f"\n --template {self.template_type} \\" if self.template_type != "auto" else ""), # template type
"full", # type - full_param or lora; useless here
result_data_path, # where to save the inference results
2 ** max(0, 3 + batch_size_multiplier_log2), # per_device_train_batch_size
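The conditional template expression is now repeated verbatim at five call sites (finetune, LoRA merging, the reward-model and PPO stages of __rlhf, and deepspeed inference). A hypothetical helper, not part of this commit, expressing that logic once:

def template_arg(template_type: str) -> str:
    # Return a "--template <type>" line for the llama-factory bash
    # command, or an empty string when the type is "auto" so the flag
    # is omitted and the template is resolved automatically.
    return (
        f"\n    --template {template_type} \\"
        if template_type != "auto"
        else ""
    )

Each call site would then pass template_arg(self.template_type) into the bash command template instead of inlining the conditional.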
