fix ppo_freeze mat1 mat2 should have the same dtype #5480

Open
wants to merge 1 commit into main
Conversation

ex-yanminmin001

What does this PR do?

Fixes # (issue)

Before submitting

@hiyouga
Owner

hiyouga commented Sep 19, 2024

Please provide more detailed information. At the moment this PR causes a significant increase in VRAM usage.

@ex-yanminmin001
Author

ex-yanminmin001 commented Sep 23, 2024

### model
model_name_or_path: /LLMs/Qwen1.5-0.5B-Chat
reward_model: /output/Qwen1.5-0.5B-Chat/full/reward
reward_model_type: full

### method
stage: ppo
do_train: true
finetuning_type: freeze
lora_target: all

### dataset
dataset: identity
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
bf16: true

With this configuration file, training fails at runtime with the following error:
run_exp()
File "/workspace/src/llamafactory/train/tuner.py", line 55, in run_exp
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
File "/workspace/src/llamafactory/train/ppo/workflow.py", line 77, in run_ppo
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/workspace/src/llamafactory/train/ppo/trainer.py", line 249, in ppo_train
mini_batch_queries, mini_batch_responses = self.get_inputs(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/workspace/src/llamafactory/train/ppo/trainer.py", line 354, in get_inputs
generate_output: "torch.Tensor" = unwrapped_model.generate(
File "/usr/local/lib/python3.10/dist-packages/trl/models/modeling_value_head.py", line 209, in generate
return self.pretrained_model.generate(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1989, in generate
result = self._sample(
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2932, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 1068, in forward
logits = self.lm_head(hidden_states)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

After the change, training can continue.

@ex-yanminmin001
Author

While debugging, the failure is in
F.linear(input, self.weight, self.bias)
where input (the model's output hidden states) is float32 and weight is bf16, which causes the error. Adding

if finetuning_args.stage == "ppo":
    for name, param in model.named_parameters():
        if not any(forbidden_module in name for forbidden_module in forbidden_modules):
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)

forces the model parameters to float32, and the error disappears.
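
A minimal standalone repro of the same mismatch, using plain PyTorch (not tied to LLaMA-Factory):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, dtype=torch.float32)   # hidden states produced in fp32
w = torch.randn(8, 4, dtype=torch.bfloat16)  # lm_head weight loaded in bf16

F.linear(x, w)              # RuntimeError: mat1 and mat2 must have the same dtype
F.linear(x.to(w.dtype), w)  # works once the dtypes agree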

@ex-yanminmin001
Author

> Please provide more detailed information. At the moment this PR causes a significant increase in VRAM usage.

Details provided above, please take a look.
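
On the VRAM point: casting every parameter to float32 roughly doubles parameter memory for a bf16 model, which would explain the increase. A sketch of a possible lower-memory alternative (an assumption on my side, not what this commit does) would be to cast only the lm_head input to the weight dtype with a forward pre-hook:

import torch

def cast_input_to_weight_dtype(module, args):
    # hypothetical helper: cast floating-point inputs to the module's weight dtype
    return tuple(
        a.to(module.weight.dtype) if torch.is_tensor(a) and a.is_floating_point() else a
        for a in args
    )

# usage sketch: `model` is the loaded CausalLM whose lm_head weight is bf16
# model.lm_head.register_forward_pre_hook(cast_input_to_weight_dtype)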

@@ -133,6 +133,25 @@ def _setup_freeze_tuning(
else:
param.requires_grad_(False)

'''When using ppo_freeze with a qwen1.5-0.5b model, add a dtype cast the first time the actor model is loaded
Contributor

can we change to english? :)
