[BUG]: Llama3.1-70B-instruct save model #6108
Comments
Hi, is your final saved model correct? Since you used tp = 8, the embed_tokens weight is saved in its sharded form.
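(With tensor parallelism the embedding table is typically split along the vocabulary dimension across the tp group, so until the shards are gathered each rank holds only roughly vocab_size / tp_size rows.)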
The final saved model is not correct!
Hi, could you please tell me how to solve this issue?
I couldn't reproduce your bug locally. I suggest upgrading ColossalAI to the latest main branch and then performing a source build with:
You can reproduce it by running:
tp=8, pp=1, sp=1, Llama3.1-8b-instruct:
Traceback (most recent call last):
File "/data1/Projects/mcts-llm/ColossalAI/applications/ColossalChat/examples/training_scripts/train_llama_rm.py", line 403, in <module>
train(args)
File "/data1/Projects/mcts-llm/ColossalAI/applications/ColossalChat/examples/training_scripts/train_llama_rm.py", line 326, in train
trainer.fit(
File "/data1/Projects/mcts-llm/ColossalAI/applications/ColossalChat/coati/trainer/base.py", line 67, in fit
self._train(epoch)
File "/data1/Projects/mcts-llm/ColossalAI/applications/ColossalChat/coati/trainer/rm.py", line 381, in _train
save_checkpoint(
File "/data1/Projects/mcts-llm/ColossalAI/applications/ColossalChat/coati/utils/ckpt_io.py", line 61, in save_checkpoint
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
File "/usr/local/lib/python3.10/site-packages/colossalai/booster/booster.py", line 374, in save_optimizer
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/checkpoint_io_base.py", line 197, in save_optimizer
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py", line 447, in save_sharded_optimizer
total_size = save_state_dict_shards(
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/utils.py", line 243, in save_state_dict_shards
for idx, shard_pair in enumerate(sharded_state_dict):
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py", line 155, in _optimizer_sharder
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py", line 873, in gather_from_sharded_optimizer_state
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
File "/usr/local/lib/python3.10/site-packages/colossalai/checkpoint_io/utils.py", line 116, in search_tp_partition_dim
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
AssertionError: The parameter isn't evenly distributed among tensor parallel group: shape before sharding torch.Size([128256, 4096]), shape after sharding torch.Size([16064, 4096])
Thank you for your reply. We are refactoring this part; you can try earlier versions such as 0.3.1, with which I have run Llama3 models before. Once the refactoring is finished, we will test it and remind you if needed.
Has there been any progress on this issue?
Hi, you can try modifying it according to this PR and retrying.
The key point is that the embedding dimension should not be sharded into 16064 rows; 128256 / 8 = 16032.
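For what it's worth, the 16064 in the error message is consistent with a Megatron-style vocab-padding convention: pad the vocabulary to a multiple of make_vocab_size_divisible_by * tp_size before splitting it across ranks. A minimal sketch of that arithmetic (the helper name is made up for illustration, and the convention itself is an assumption, though a divisor of 64 reproduces the observed 16064 exactly):

import math

def padded_vocab_and_shard(vocab_size, tp_size, make_vocab_size_divisible_by):
    # Assumed convention: pad the vocab to a multiple of divisor * tp_size, then split evenly.
    multiple = make_vocab_size_divisible_by * tp_size
    padded = math.ceil(vocab_size / multiple) * multiple
    return padded, padded // tp_size

print(padded_vocab_and_shard(128256, 8, 64))  # (128512, 16064) -> the shard size seen in the error
print(padded_vocab_and_shard(128256, 8, 32))  # (128256, 16032) -> no padding, even split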
What’s the progress like? |
Set:

plugin = HybridParallelPlugin(
    tp_size=args.tp,
    pp_size=args.pp,
    sp_size=args.sp,
    sequence_parallelism_mode=args.sp_mode,
    zero_stage=args.zero_stage,
    enable_flash_attention=args.use_flash_attn,
    enable_sequence_parallelism=args.enable_sequence_parallelism,
    cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
    parallel_output=False,
    max_norm=args.grad_clip,
    precision=args.mixed_precision,
    # set make_vocab_size_divisible_by=32
    make_vocab_size_divisible_by=32,
    custom_policy=get_autopolicy(model.model),
)
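With tp_size=8 this makes the padding multiple 32 × 8 = 256, and 128256 / 256 = 501 exactly, so the vocabulary should need no padding: each rank then holds 16032 embedding rows, which should gather back to the full [128256, 8192] weight on save.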
Is there an existing issue for this bug?
🐛 Describe the bug
I trained a reward model based on Llama3.1-70B-instruct on 48 H100s (3D parallelism: tp=8, pp=1).
When executing booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True), the saved model.embed_tokens.weight has shape [16064, 8192] rather than [128256, 8192]. However, the shapes of the other weights are correct. Please help me!
Thank you!
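One way to see the problem directly is to inspect what was written to disk. A minimal check script (a sketch only: it assumes the sharded checkpoint was written as *.safetensors files under the "modeling" directory, and the save_dir path is a placeholder):

import glob
import os
from safetensors import safe_open

save_dir = "path/to/save_dir"  # placeholder, not the actual path used above

for path in sorted(glob.glob(os.path.join(save_dir, "modeling", "*.safetensors"))):
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if "embed_tokens" in key:
                # Expected torch.Size([128256, 8192]); the bug produces [16064, 8192]
                print(path, key, f.get_tensor(key).shape)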
Environment
transformers 4.44.1
colossalai 0.4.5
flash-attn 2.6.3