
Commit

[AutoParallel]:nlp support run lora model with intermediate
blacksheep-Aristotle authored and zhangyuqin1998 committed Jan 3, 2025
1 parent 6b2425c commit 93041e1
Showing 4 changed files with 173 additions and 16 deletions.
9 changes: 5 additions & 4 deletions llm/auto_parallel/run_finetune_auto.py
@@ -55,15 +55,15 @@
Llama3Tokenizer,
LlamaConfig,
LlamaForCausalLM3DAuto,
LlamaForCausalLM3DNet,
LlamaForCausalLMNet,
LlamaPretrainingCriterion3DAuto,
LlamaPretrainingCriterion3DNet,
LlamaPretrainingCriterionNet,
LlamaTokenizer,
)

MODEL_CLASSES = {
"llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
"llama_network": (LlamaConfig, LlamaForCausalLM3DNet, LlamaPretrainingCriterion3DNet),
"llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
}

from paddlenlp.trl import DataConfig, ModelConfig, SFTAutoTrainer, SFTConfig
@@ -80,7 +80,7 @@
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"

flash_mask_support_list = [LlamaForCausalLM3DAuto, LlamaForCausalLM3DNet]
flash_mask_support_list = [LlamaForCausalLM3DAuto, LlamaForCausalLMNet]


def paddlenlp_verison_check():
@@ -489,6 +489,7 @@ def compute_metrics_do_generation(eval_preds):
# layer.register_forward_pre_hook(forward_pre_hook)
# layer.register_forward_post_hook(forward_post_hook)
# Train
print(trainer.model)
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
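The MODEL_CLASSES change above reroutes the "llama_network" entry to the renamed intermediate-API classes. As a rough illustration — this snippet is not part of the commit, and the resolve helper plus the string placeholders are invented for the example — the registry is consumed roughly like this:

# Minimal sketch of looking up the classes registered for a model_type.
MODEL_CLASSES = {
    "llama": ("LlamaConfig", "LlamaForCausalLM3DAuto", "LlamaPretrainingCriterion3DAuto"),
    "llama_network": ("LlamaConfig", "LlamaForCausalLMNet", "LlamaPretrainingCriterionNet"),
}

def resolve_model_classes(model_type: str):
    # Return the (config, model, criterion) entries registered for model_type.
    try:
        return MODEL_CLASSES[model_type]
    except KeyError:
        raise ValueError(f"unknown model_type: {model_type!r}") from None

# "llama_network" now resolves to the renamed intermediate-API classes.
print(resolve_model_classes("llama_network"))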
176 changes: 166 additions & 10 deletions paddlenlp/peft/lora/auto_lora_model.py
@@ -126,7 +126,7 @@ def auto_dist_config(self, prefix=""):
]


class LoRAAutoModel(PretrainedModel):
class LoRAAutoModel(nn.Layer):
# TODO:lugimzzz support restore in following PR
restore_layer_map: Dict[nn.Layer, nn.Layer] = {
LoRAAutoLinear: nn.Linear,
@@ -137,7 +137,6 @@ def __init__(self, model, lora_config: LoRAAutoConfig) -> None:
self.model_config = AutoConfig.from_pretrained(lora_config.base_model_name_or_path)
self.quantized = False
self.lora_config = lora_config

if self.lora_config.dtype is None:
self.lora_config.dtype = paddle.get_default_dtype()
with dtype_guard(self.lora_config.dtype):
@@ -506,13 +505,6 @@ def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: L
raise TypeError(
f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
)
if lora_config.use_intermediate_api:
assert hasattr(
model, "auto_dist_config"
), "train lora_model requires auto_dist_config when use intermediate api"
auto_dist_config = model.auto_dist_config()
if auto_dist_config["mp_config"] is not None:
mp_parallelize_plan = auto_dist_config["mp_config"]["parallelize_plan"]

def _match_layer(module_name, parallelize_plan):
# Match the layer to a plan.
@@ -531,14 +523,178 @@ def _match_layer(module_name, parallelize_plan):
):
return plan

if lora_config.use_intermediate_api:
assert hasattr(
model, "auto_dist_config"
), "train lora_model requires auto_dist_config when use intermediate api"
auto_dist_config = model.auto_dist_config()
if auto_dist_config["mp_config"] is not None:
mp_parallelize_plan = auto_dist_config["mp_config"]["parallelize_plan"]
for target_module, enable_lora in zip(target_modules, enable_lora_list):
for i in model.named_sublayers():
module_name = i[0]
if re.fullmatch(target_module, module_name):
layer_parallelize_plan = _match_layer(module_name, mp_parallelize_plan)
layer_parallelize_plan = None
if lora_config.use_intermediate_api:
layer_parallelize_plan = _match_layer(module_name, mp_parallelize_plan)
self._find_and_replace_module(model, module_name, lora_config, enable_lora, layer_parallelize_plan)
return model

def merge_auto_dist_configs(self, configs):
"""
Merge all auto dist configs into one config.
configs is a list of configs; each config is a dict holding one model's auto_dist_config, e.g.:
[
{
mp_config (dict): {
"parallelize_plan": dict, the plan to shard the layer.
}
pp_config (dict): {
"split_spec": OrderedDict|dict|str|list(str), The pipeline parallel split point.
"global_spec": str|list(str), make the output tensor of specific layers on global mesh.
}
},{
mp_config (dict): {
"parallelize_plan": dict, the plan to shard the layer.
}
pp_config (dict): {
"split_spec": OrderedDict|dict|str|list(str), The pipeline parallel split point.
"global_spec": str|list(str), make the output tensor of specific layers on global mesh.
}
},....
]
"""
assert isinstance(configs, (dict, list))
if isinstance(configs, dict):
return configs
final_config = {
"mp_config": None,
"sp_config": None,
"pp_config": None,
}
for config in configs:
if "mp_config" in config and config["mp_config"] is not None:
if final_config["mp_config"] is None:
final_config["mp_config"] = config["mp_config"]
else:
for k, v in config["mp_config"]["parallelize_plan"].items():
assert (
k not in final_config["mp_config"]["parallelize_plan"].keys()
), f"sublayer mp_config shuld be a subset of model but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
final_config["mp_config"]["parallelize_plan"][k] = v
if "sp_config" in config and config["sp_config"] is not None:
if final_config["sp_config"] is None:
final_config["sp_config"] = config["sp_config"]
else:
for k, v in config["sp_config"]["parallelize_plan"].items():
assert (
k not in final_config["sp_config"]["parallelize_plan"].keys()
), f"sublayer sp_config shuld be a subset of model but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
final_config["sp_config"]["parallelize_plan"][k] = v
if "pp_config" in config and config["pp_config"] is not None:

def process_spec(spec_name):
if isinstance(config["pp_config"][spec_name], str):
config["pp_config"][spec_name] = [config["pp_config"][spec_name]]
if final_config["pp_config"] is None:
final_config["pp_config"] = config["pp_config"]
elif config["pp_config"][spec_name] not in final_config["pp_config"][spec_name]:
final_config["pp_config"][spec_name] += config["pp_config"][spec_name]
elif isinstance(config["pp_config"][spec_name], (tuple, list)):
if final_config["pp_config"] is None:
final_config["pp_config"] = config["pp_config"]
elif config["pp_config"][spec_name] not in final_config["pp_config"][spec_name]:
final_config["pp_config"][spec_name] += config["pp_config"][spec_name]

process_spec("split_spec")
process_spec("global_spec")

if final_config["pp_config"] is not None:
if len(final_config["pp_config"]["split_spec"]) == 1:
final_config["pp_config"]["split_spec"] = final_config["pp_config"]["split_spec"][0]
elif len(final_config["pp_config"]["split_spec"]) > 1:
final_config["pp_config"]["split_spec"] = list(set(final_config["pp_config"]["split_spec"]))
if len(final_config["pp_config"]["global_spec"]) > 1:
final_config["pp_config"]["global_spec"] = list(set(final_config["pp_config"]["global_spec"]))
# final_config["pp_config"]["split_spec"] = final_config["pp_config"]["split_spec"][0]
# final_config["pp_config"]["global_spec"] = final_config["pp_config"]["global_spec"][0]
return final_config

def _generate_auto_dist_config(self, auto_dist_degree):
merged_config = {
"sp_config": None,
"mp_config": None,
"pp_config": None,
}
layer_name = []
for name, layer in self.named_sublayers(include_self=True):
if hasattr(layer, "auto_dist_config"):
if name != "":
prefix = name + "."
else:
prefix = ""
layer_config = layer.auto_dist_config(prefix)
merged_config = self.merge_auto_dist_configs([merged_config, layer_config])
layer_name.append(name)
# for _, deeper_layer in layer.named_sublayers():
# if hasattr(deeper_layer, "auto_dist_config"):
# # mask all `auto_dist_config` methods in deeper layer
# deeper_layer.auto_dist_config = lambda x: {}
final_config = {
"dp_config": None,
"mp_config": None,
"pp_config": None,
}
if "tensor_parallel" in auto_dist_degree and auto_dist_degree["tensor_parallel"]:
merged_config["mp_config"] is not None
final_config["mp_config"] = merged_config["mp_config"]

if "sequence_parallel" in auto_dist_degree and auto_dist_degree["sequence_parallel"]:
merged_config["sp_config"] is not None
final_config["mp_config"] = merged_config["sp_config"]

if "pipeline_parallel" in auto_dist_degree and auto_dist_degree["pipeline_parallel"]:
merged_config["pp_config"] is not None
final_config["pp_config"] = merged_config["pp_config"]
if final_config["pp_config"]["global_spec"] is not None:
# final_config["pp_config"]["global_spec"] = [spec_name for spec_name in final_config["pp_config"]["global_spec"] if spec_name in layer_name ]
temp_specs_name = final_config["pp_config"]["global_spec"]
for spec_name_i in temp_specs_name:
for spec_name_j in temp_specs_name:
if spec_name_i != spec_name_j and spec_name_i in spec_name_j:
final_config["pp_config"]["global_spec"].remove(spec_name_i)
break

if final_config["pp_config"]["split_spec"] is not None:
# final_config["pp_config"]["split_spec"] = [spec_name for spec_name in final_config["pp_config"]["split_spec"] if spec_name in layer_name ]
temp_specs_name = final_config["pp_config"]["split_spec"]
for spec_name_i in temp_specs_name:
for spec_name_j in temp_specs_name:
if spec_name_i != spec_name_j and spec_name_i in spec_name_j:
final_config["pp_config"]["split_spec"].remove(spec_name_i)
break

if "data_sharding_parallel" in auto_dist_degree and auto_dist_degree["data_sharding_parallel"]:
# to avoid a circular import
from paddlenlp.trainer.trainer_utils import ShardingOption

level = 0
if "sharding" in auto_dist_degree and auto_dist_degree["sharding"] is not None:
sharding = auto_dist_degree["sharding"]
if ShardingOption.SHARD_OP in sharding:
level = 1
if ShardingOption.SHARD_GRAD_OP in sharding:
level = 2
if ShardingOption.FULL_SHARD in sharding:
level = 3
final_config["dp_config"] = {
"sharding_level": level,
"sharding_mesh_dim": auto_dist_degree.get("sharding_mesh_dim", None),
}

return final_config

def restore_original_model(self):
# make sure W and lora weights are not merged before we restore the original model

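To make the merging behaviour of the new merge_auto_dist_configs method concrete, here is a minimal, self-contained sketch of its mp_config branch. This is illustrative code, not the library implementation: the plan values are plain strings standing in for dist.ColWiseParallel()/dist.RowWiseParallel() objects, and the LoRA key name is made up.

def merge_mp_configs(configs):
    # Fold the per-layer parallelize_plan of each sub-config into one mp_config,
    # refusing duplicate keys, mirroring the assert in the method above.
    merged = {"mp_config": None}
    for cfg in configs:
        mp = cfg.get("mp_config")
        if mp is None:
            continue
        if merged["mp_config"] is None:
            merged["mp_config"] = {"parallelize_plan": dict(mp["parallelize_plan"])}
            continue
        for key, plan in mp["parallelize_plan"].items():
            assert key not in merged["mp_config"]["parallelize_plan"], f"duplicate plan for {key!r}"
            merged["mp_config"]["parallelize_plan"][key] = plan
    return merged

base_model_plan = {"mp_config": {"parallelize_plan": {"llama.embed_tokens": "RowWiseParallel"}}}
lora_layer_plan = {"mp_config": {"parallelize_plan": {"llama.layers.0.self_attn.q_proj.lora_B": "ColWiseParallel"}}}
print(merge_mp_configs([base_model_plan, lora_layer_plan]))
# -> one mp_config whose parallelize_plan contains both entries

The sp_config branch of the real method follows the same pattern; pp_config entries are instead accumulated as lists of split/global spec names and de-duplicated later.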
2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/modeling_network.py
@@ -1062,7 +1062,7 @@ def auto_dist_config(self, prefix=""):
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
}
},
"pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": "llama.global_layer"},
"pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
}

return config
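The one-line change above makes global_spec honour the prefix argument the same way split_spec already did, which matters when LoRAAutoModel._generate_auto_dist_config calls auto_dist_config(prefix=name + ".") on a wrapped sublayer. A small sketch with an assumed prefix of "model." shows the effect:

prefix = "model."  # hypothetical prefix: the sublayer name plus a trailing dot
pp_config = {
    "split_spec": f"{prefix}llama.layers",         # "model.llama.layers"
    "global_spec": f"{prefix}llama.global_layer",  # "model.llama.global_layer"; was "llama.global_layer" before the fix
}
print(pp_config)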
2 changes: 1 addition & 1 deletion paddlenlp/trl/sft_auto_trainer.py
@@ -100,7 +100,7 @@ def loss_func(loss, outputs):
}
auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)
self.auto_dist_config = auto_dist_config
logger.info(f"auto_dist_config: {self.auto_dist_config['mp_config']}")
logger.info(f"auto_dist_config: {self.auto_dist_config}")
model = parallelize_model(
model,
config=self.auto_dist_config,
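For reference, a sketch of the auto_dist_degree dict that the trainer hands to model._generate_auto_dist_config above. The keys are the ones the new method branches on; the values and comments here are assumptions rather than the trainer's actual construction, which is elided in this diff.

auto_dist_degree = {
    "tensor_parallel": True,         # copy the merged mp_config into the final config
    "sequence_parallel": False,      # would place sp_config into the mp_config slot instead
    "pipeline_parallel": True,       # copy pp_config and de-duplicate split/global specs
    "data_sharding_parallel": True,  # build dp_config from the sharding options below
    "sharding": None,                # e.g. a list of ShardingOption values; None keeps sharding_level 0
    "sharding_mesh_dim": None,
}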
