[AutoParallel]: update autoparallel format #9747

Open: wants to merge 8 commits into base: develop
Changes from all commits
2 changes: 1 addition & 1 deletion llm/auto_parallel/gpt-3/gpt_with_intermediate.sh
@@ -46,7 +46,7 @@ to_static=1
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3" \
--log_dir ${log_dir} \
run_pretrain_auto.py \
../run_pretrain_auto.py \
--model_name_or_path gpt3-13B-en \
--tokenizer_name_or_path gpt3-13B-en \
--to_static ${to_static} \
2 changes: 1 addition & 1 deletion llm/auto_parallel/gpt-3/run_pretrain_auto_dp2mp2pp2.sh
@@ -31,7 +31,7 @@ rm -rf $output_dir
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir ${log_dir} \
run_pretrain_auto.py \
../run_pretrain_auto.py \
--model_name_or_path gpt2-medium-en \
--tokenizer_name_or_path gpt2-medium-en \
--input_dir "../data" \
2 changes: 1 addition & 1 deletion llm/auto_parallel/llama/llama_with_api.sh
@@ -27,7 +27,7 @@ export PYTHONPATH=../../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "single" \
./run_pretrain_auto.py \
../run_pretrain_auto.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
2 changes: 1 addition & 1 deletion llm/auto_parallel/llama/run_llama3.sh
@@ -37,7 +37,7 @@ export PYTHONPATH=../../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
./run_pretrain_auto.py \
../run_pretrain_auto.py \
--model_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
--tokenizer_name_or_path "meta-llama/Meta-Llama-3-8B-Instruct" \
--input_dir "./data" \
2 changes: 1 addition & 1 deletion llm/auto_parallel/llama/run_pretrain_auto.sh
@@ -38,7 +38,7 @@ to_static=0 # whether to enable dynamic-to-static training
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "auto_3d" \
run_pretrain_auto.py \
../run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
2 changes: 1 addition & 1 deletion llm/auto_parallel/qwen/run_intermediate_api.sh
@@ -47,7 +47,7 @@ export FLAGS_enable_pir_api=1
python -u -m paddle.distributed.launch \
--gpus "4,5" \
--log_dir "log/auto_3d_auto" \
run_pretrain_3D_auto.py \
../run_pretrain_auto.py \
--model_name_or_path "qwen/qwen-14b" \
--tokenizer_name_or_path "qwen/qwen-14b" \
--model_type "qwen_network" \
2 changes: 1 addition & 1 deletion llm/auto_parallel/qwen/run_pretrain_3D_auto.sh
@@ -38,4 +38,4 @@ export FLAGS_enable_pir_in_executor=0
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "auto_3d" \
run_pretrain_3D_auto.py ./pretrain_argument_auto_dp2tp2pp2.json
../run_pretrain_auto.py ./pretrain_argument_auto_dp2tp2pp2.json
llm/auto_parallel/run_pretrain_auto.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
GPT/Llama auto parallel pretraining scripts.
GPT/Llama/Qwen auto parallel pretraining scripts.
"""
import os
import random
@@ -24,6 +24,7 @@
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

from paddlenlp.ops import Topology
from paddlenlp.trainer import (
@@ -42,12 +43,26 @@
GPTPretrainingCriterionAuto,
GPTPretrainingCriterionNet,
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLM3DAuto,
LlamaForCausalLMNet,
LlamaPretrainingCriterion3DAuto,
LlamaPretrainingCriterionNet,
QWenConfig,
QWenForCausalLM3DAuto,
QWenForCausalLMNet,
QWenPretrainingCriterionAuto,
QWenPretrainingCriterionNet,
)
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
"gpt": (GPTConfig, GPTForCausalLMAuto, GPTPretrainingCriterionAuto),
"gpt_network": (GPTConfig, GPTForCausalLMNet, GPTPretrainingCriterionNet),
"llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
"llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
"qwen": (QWenConfig, QWenForCausalLM3DAuto, QWenPretrainingCriterionAuto),
"qwen_network": (QWenConfig, QWenForCausalLMNet, QWenPretrainingCriterionNet),
}
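# Illustrative sketch (assumed usage, for illustration only): main() below looks up this registry,
# presumably by model_args.model_type, and unpacks the tuple; the key and checkpoint are examples.
#   config_class, model_class, criterion_class = MODEL_CLASSES["qwen_network"]
#   config = config_class.from_pretrained("qwen/qwen-14b")
#   model = model_class.from_config(config, dtype="bfloat16")
#   criterion = criterion_class(config)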

from paddlenlp.data.causal_dataset import (
@@ -57,6 +72,9 @@
)
from paddlenlp.trainer.utils.doc import add_start_docstrings

# Pretraining environment variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "True"


@dataclass
@add_start_docstrings(AutoTrainingArguments.__doc__)
@@ -71,6 +89,12 @@ class PreTrainingArguments(AutoTrainingArguments):
"help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate."
},
)
enable_linear_fused_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -91,6 +115,10 @@ class PreTrainingArguments(AutoTrainingArguments):
default=False,
metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."},
)
fine_grained_log: bool = field(
default=False,
metadata={"help": "whether print find-grained performance log"},
)
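# Illustrative sketch (assumption): PdArgumentParser exposes these dataclass fields as config keys,
# so the new options could be enabled from a pretrain_argument_*.json along the lines of
#   "enable_linear_fused_grad_add": true,
#   "fine_grained_log": true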

def __post_init__(self):
super().__post_init__()
@@ -198,14 +226,13 @@ class ModelArguments:
default=False,
metadata={"help": "whether to fuse first up and gate proj in mlp block"},
)
# this optional can be use in run_pretrain.py
use_fast_layer_norm: bool = field(
default=False,
metadata={"help": "GPT3 model, use fast layernorm"},
metadata={"help": "whether to use fast layernorm"},
)
use_fused_dropout_add: bool = field(
default=False,
metadata={"help": "Gpt3 model, use_fused_dropout_add"},
metadata={"help": "whether to use_fused_dropout_add"},
)
recompute_granularity: str = field(
default="full",
@@ -221,7 +248,6 @@ class ModelArguments:
"help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models."
},
)

hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."})
use_fused_rope: Optional[bool] = field(
@@ -256,14 +282,14 @@ def create_pretrained_dataset(

train_val_test_num_samples = [
training_args.per_device_train_batch_size
* training_args.data_parallel_degree
* training_args.dataset_world_size
* training_args.max_steps
* training_args.gradient_accumulation_steps,
training_args.per_device_eval_batch_size
* training_args.data_parallel_degree
* training_args.dataset_world_size
* training_args.eval_iters
* (training_args.max_steps // training_args.eval_steps + 1),
training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters,
training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
]
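# Worked example with assumed values: per_device_train_batch_size=1, dataset_world_size=8,
# max_steps=1000, gradient_accumulation_steps=4 -> 1 * 8 * 1000 * 4 = 32000 training samples requested.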

print_rank_0(" > datasets target sizes (minimum size):")
@@ -354,6 +380,24 @@ def _wrap_for_dist_loader(self, train_dataloader):
dist_loader._input_keys = ["input_ids", "labels"]
return dist_loader

def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
if self.train_dataset is None:
return None

total_batch_size_per_acc_step = self.args.per_device_train_batch_size * self.args.dataset_world_size
total_batch_size = total_batch_size_per_acc_step

# In llm/llama/run_pretrain.py, paddlenlp.utils.batch_sampler.DistributedBatchSampler is used,
# which does not shuffle even when shuffle is set to True.
sampler = paddle.io.BatchSampler(
dataset=self.train_dataset,
shuffle=False,
batch_size=total_batch_size,
drop_last=self.args.dataloader_drop_last,
)
sampler._acc_steps = self.args.gradient_accumulation_steps
return sampler
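# Worked example with assumed values: per_device_train_batch_size=2 and dataset_world_size=4
# give a sampler batch of 2 * 4 = 8 samples; with gradient_accumulation_steps=8 stored on
# sampler._acc_steps (presumably consumed by the dist loader), one optimizer step sees
# 8 * 8 = 64 samples in total.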


def print_config(args, key=""):
"""
@@ -393,10 +437,10 @@ def init_seed(seed: int = 1234, args=None):
topo = Topology(
dist.get_rank(),
dist.get_world_size(),
dp_degree=max(args.data_parallel_degree, args.sharding_parallel_degree),
dp_degree=args.dataset_world_size,
pp_degree=args.pipeline_parallel_degree,
mp_degree=args.tensor_parallel_degree,
sharding_degree=1,
sharding_degree=1, # auto_parallel's sharding is not orthogonal with dp, mp and pp
order=order,
)
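# Assumed semantics, for illustration: in PaddleNLP training args, dataset_world_size is
# data_parallel_degree * sharding_parallel_degree, so e.g. dp=2 with sharding=2 builds this
# Topology with dp_degree=4 and sharding_degree=1.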

@@ -416,6 +460,13 @@ def init_seed(seed: int = 1234, args=None):
paddle.seed(args.seed)


def get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp")[pp_idx]
return mesh
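# Illustrative sketch (assumed usage): fetch the sub-mesh of one pipeline stage and place a
# tensor on it with paddle.distributed's auto-parallel API; the placements below are assumptions.
#   stage0_mesh = get_mesh(pp_idx=0)
#   x = paddle.ones([4, 8])
#   x = dist.shard_tensor(x, stage0_mesh, [dist.Replicate() for _ in range(len(stage0_mesh.shape))])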


def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -462,7 +513,7 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)

config = config_class.from_pretrained(model_args.model_name_or_path)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
# There are techniques that extend the RotaryEmbedding context, so don't change max_position_embeddings.
if not model_args.continue_training:
@@ -486,7 +537,7 @@ def main():
config.num_attention_heads = (
model_args.num_attention_heads if model_args.num_attention_heads is not None else config.num_attention_heads
)
config.use_fused_dropout_add = model_args.use_fused_dropout_add

config.use_flash_attention = model_args.use_flash_attention
config.use_fused_rms_norm = model_args.use_fused_rms_norm
config.fuse_attention_qkv = model_args.fuse_attention_qkv
@@ -503,40 +554,43 @@ def main():
config.use_recompute = training_args.recompute
config.tensor_parallel_degree = training_args.tensor_parallel_degree
config.tensor_parallel_rank = training_args.tensor_parallel_rank
config.sharding_parallel_degree = training_args.sharding_parallel_degree

if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
pipeline = training_args.strategy.pipeline
pipeline.vpp_degree = config.virtual_pp_degree
pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method

config.hidden_dropout_prob = model_args.hidden_dropout_prob
config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
print("Final pre-training config:", config)

# Set the dtype for loading model
# # Set the dtype for loading model
dtype = "float32"
if training_args.fp16_opt_level == "O2":
if training_args.fp16:
dtype = "float16"
if training_args.bf16:
dtype = "bfloat16"

with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
criterion = criterion_class(config)

if training_args.recompute:

def fn(layer):
if hasattr(layer, "enable_recompute") and (layer.enable_recompute is False or layer.enable_recompute == 0):
layer.enable_recompute = True
if hasattr(layer, "layerwise_recompute"):
layer.layerwise_recompute = True

model.apply(fn)

# Create the learning_rate scheduler and optimizer
if training_args.decay_steps is None:
training_args.decay_steps = training_args.max_steps
warmup_steps = training_args.warmup_ratio * training_args.max_steps

if training_args.warmup_steps > 0:
warmup_steps = training_args.warmup_steps
else:
warmup_steps = training_args.warmup_ratio * training_args.max_steps
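# Worked example with assumed values: warmup_steps=0, warmup_ratio=0.01, max_steps=2000 -> warmup_steps = 20.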

lr_scheduler = None
if training_args.lr_scheduler_type.value == "cosine":
@@ -564,7 +618,6 @@ def fn(layer):
tokenizer,
need_data=training_args.should_load_dataset,
)

trainer = PretrainingTrainer(
model=model,
criterion=criterion,
@@ -585,6 +638,7 @@ def fn(layer):
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)

# NOTE(gongenlei): new add
if not training_args.autotuner_benchmark:
metrics = train_result.metrics
@@ -594,6 +648,15 @@ def fn(layer):
trainer.save_metrics("train", metrics)
trainer.save_state()

if training_args.do_predict:
test_ret = trainer.predict(test_dataset)
trainer.log_metrics("test", test_ret.metrics)

# if training_args.should_load_dataset:
# effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"]
# print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}")
# print(f"ips: {effective_tokens_per_second:.2f} tokens/s")


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -83,6 +83,11 @@
]


def get_use_casual_mask():
"""Get the value of the 'USE_CASUAL_MASK' environment variable."""
return os.getenv("USE_CASUAL_MASK", "False") == "True"
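# Illustrative sketch (assumed interplay): run_pretrain_auto.py above exports the flag before
# building the model, and this helper reads it inside the modeling code:
#   os.environ["USE_CASUAL_MASK"] = "True"   # set in the training entry script
#   get_use_casual_mask()                    # -> True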

Codecov / codecov/patch check warning on line 88 in paddlenlp/transformers/llama/modeling_auto.py: added line #L88 was not covered by tests.


def enable_fuse_ffn_qkv_pass():
if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [
"True",