RLVR multinode (2 nodes) issue for mistral-nemo-12B. #526

Open
palash04 opened this issue Jan 26, 2025 · 5 comments

@palash04

I am trying to reproduce the tulu-3-70B multinode setting with a Mistral Nemo 12B model (mistralai/Mistral-Nemo-Instruct-2407). A single node doesn't work (it runs out of memory). For 2 nodes I am trying the following config:

source configs/beaker_configs/ray_node_setup.sh && python open_instruct/ppo_vllm_thread_ray_gtrl.py \
    --dataset_mixer '{"ai2-adapt-dev/gsm8k_math_ifeval_ground_truth_mixed": 1.0}' \
    --dataset_train_splits train \
    --dataset_eval_mixer '{"ai2-adapt-dev/gsm8k_math_ground_truth": 1.0}' \
    --dataset_eval_splits test \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path mistralai/Mistral-Nemo-Instruct-2407 \
    --reward_model_path allenai/Llama-3.1-Tulu-3-8B-RM \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 3e-7 \
    --total_episodes 10000000 \
    --penalty_reward_value -10.0 \
    --deepspeed_stage 3 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --local_mini_batch_size 32 \
    --local_rollout_batch_size 32 \
    --actor_num_gpus_per_node 7 8 \
    --vllm_tensor_parallel_size 1 \
    --beta 0.05 \
    --apply_verifiable_reward true \
    --output_dir output/rlvr_8b \
    --seed 3 \
    --num_evals 3 \
    --save_freq 100 \
    --reward_model_multiplier 0.0 \
    --gradient_checkpointing \
    --with_tracking

But this also gives OOM -

(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffec342b89f90b75085d518efb01000000 Worker ID: adcce158980aef9684362cb81c4696f40f6c990ad20e795c0cb2bbe1 Node ID: d41950ecc80570ac5de0e8fc785ef3ef48bc502b397350f255e1eca4 Worker IP address: 10.232.80.199 Worker port: 10015 Worker PID: 3825210 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.

If I set vllm_tensor_parallel_size = 2, I get the following errors, depending on actor_num_gpus_per_node:

  • actor_num_gpus_per_node 8 7:
    waits at ray.get(pg.ready()) forever because it can't get one more GPU resource.
  • actor_num_gpus_per_node 7 6 (leaving one GPU for inference and two for vllm_tensor_parallel_size):
    gets past scheduling, but then runs into:
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464] Error executing method init_process_group. This might cause deadlock in distributed execution.
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464] Traceback (most recent call last):
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]   File "/home/user/miniconda3/envs/SFT/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 456, in execute_method
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]     return executor(*args, **kwargs)
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]   File "/home/user/SFT_data/open-instruct/open_instruct/vllm_utils2.py", line 100, in init_process_group
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]     self._model_update_group = init_process_group(
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]   File "/home/user/SFT_data/open-instruct/open_instruct/vllm_utils2.py", line 77, in init_process_group
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]     pg, _ = _new_process_group_helper(
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]   File "/home/user/miniconda3/envs/SFT/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1569, in _new_process_group_helper
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464]     backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout)
(LLMRayActor pid=3431067, ip=10.232.80.204) ERROR 01-26 13:44:56 worker_base.py:464] RuntimeError: Gloo connectFullMesh failed with [../third_party/gloo/gloo/transport/tcp/pair.cc:144] no error
(LLMRayActor pid=3431067, ip=10.232.80.204) [rank0]:[E126 13:44:56.572231576 ProcessGroupGloo.cpp:143] Gloo connectFullMesh failed with [../third_party/gloo/gloo/transport/tcp/pair.cc:144] no error
(LLMRayActor pid=3431067, ip=10.232.80.204) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::RayWorkerWrapper.execute_method() (pid=3432378, ip=10.232.80.204, actor_id=3c8418d0c30d35b0e755512001000000, repr=<open_instruct.vllm_utils2.RayWorkerWrapper object at 0x7efd8a2173a0>)
(LLMRayActor pid=3431067, ip=10.232.80.204)   File "/home/user/miniconda3/envs/SFT/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 465, in execute_method
(LLMRayActor pid=3431067, ip=10.232.80.204)     raise e
(LLMRayActor pid=3431067, ip=10.232.80.204)   File "/home/user/miniconda3/envs/SFT/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 456, in execute_method
(LLMRayActor pid=3431067, ip=10.232.80.204)     return executor(*args, **kwargs)
(LLMRayActor pid=3431067, ip=10.232.80.204)   File "/home/user/SFT_data/open-instruct/open_instruct/vllm_utils2.py", line 100, in init_process_group
(LLMRayActor pid=3431067, ip=10.232.80.204)     self._model_update_group = init_process_group(
(LLMRayActor pid=3431067, ip=10.232.80.204)   File "/home/user/SFT_data/open-instruct/open_instruct/vllm_utils2.py", line 77, in init_process_group
(LLMRayActor pid=3431067, ip=10.232.80.204)     pg, _ = _new_process_group_helper(
(LLMRayActor pid=3431067, ip=10.232.80.204)   File "/home/user/miniconda3/envs/SFT/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1569, in _new_process_group_helper
(LLMRayActor pid=3431067, ip=10.232.80.204)     backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout)
(LLMRayActor pid=3431067, ip=10.232.80.204) RuntimeError: Gloo connectFullMesh failed with [../third_party/gloo/gloo/transport/tcp/pair.cc:144] no error

Can anybody point out the correct config for 2 nodes to run mistral-nemo-12B model, with RM kept as is (Tulu-RM)?

@vwxyzjn
Collaborator

vwxyzjn commented Jan 26, 2025

Thanks for reporting the issue. However, I am unable to reproduce the error.

    --actor_num_gpus_per_node 7 8 \
    --vllm_tensor_parallel_size 1 \

The actor_num_gpus_per_node values plus the GPUs used by vLLM should sum to at most the number of GPUs you have (7 + 8 + 1 = 16 <= 16). It looks like you have some gloo multinode connection issues, so I'd suggest investigating there.
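The GPU arithmetic can be sketched as a quick check. This is only an illustration, not code from open-instruct: the helper name `gpu_budget`, the `num_vllm_engines` parameter (assumed to be 1), and `gpus_per_node=8` are assumptions made for the example.

```python
def gpu_budget(actor_num_gpus_per_node, vllm_tensor_parallel_size,
               num_vllm_engines=1, gpus_per_node=8):
    """Return (requested, available) GPU counts for a given config.

    Assumes one entry in actor_num_gpus_per_node per node and that each
    vLLM engine needs vllm_tensor_parallel_size GPUs (hypothetical model
    of the scheduling, for illustration only).
    """
    requested = sum(actor_num_gpus_per_node) + num_vllm_engines * vllm_tensor_parallel_size
    available = gpus_per_node * len(actor_num_gpus_per_node)
    return requested, available


# The three configurations mentioned in this thread.
for actors, tp in [([7, 8], 1), ([8, 7], 2), ([7, 6], 2)]:
    requested, available = gpu_budget(actors, tp)
    verdict = "fits" if requested <= available else "over budget"
    print(f"actors={actors} tp={tp}: {requested}/{available} GPUs, {verdict}")
```

Under these assumptions, `8 7` with tensor parallel size 2 asks for 17 GPUs on a 16-GPU cluster, which is consistent with the placement group never becoming ready.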

This is the exact command I ran

python mason.py \
    --cluster ai2/jupiter-cirrascale-2 \
    --workspace ai2/tulu-3-dev \
    --priority high \
    --preemptible \
    --num_nodes 2 \
    --budget ai2/oe-adapt \
    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/ppo_vllm_thread_ray_gtrl.py \
    --dataset_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 1.0}' \
    --dataset_train_splits train \
    --dataset_eval_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 128}' \
    --dataset_eval_splits train \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path mistralai/Mistral-Nemo-Instruct-2407 \
    --reward_model_path mistralai/Mistral-Nemo-Instruct-2407 \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 3e-7 \
    --total_episodes 10000000 \
    --penalty_reward_value -10.0 \
    --deepspeed_stage 3 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --local_mini_batch_size 32 \
    --local_rollout_batch_size 32 \
    --actor_num_gpus_per_node 7 8 \
    --vllm_tensor_parallel_size 1 \
    --beta 0.05 \
    --apply_verifiable_reward true \
    --output_dir output/rlvr_8b \
    --seed 3 \
    --num_evals 3 \
    --save_freq 100 \
    --reward_model_multiplier 0.0 \
    --gradient_checkpointing \
    --with_tracking

Which in our infra translates to the following, but you should configure it appropriately.

echo 'Running on host: $BEAKER_REPLICA_RANK' && echo 'Running on host: $BEAKER_LEADER_REPLICA_HOSTNAME' && git config --global safe.directory '*' && umask 000 && cd /weka/oe-adapt-default/costah/open-instruct-uv && source configs/beaker_configs/ray_node_setup.sh && uv run python open_instruct/ppo_vllm_thread_ray_gtrl.py --dataset_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 1.0}' --dataset_train_splits train --dataset_eval_mixer '{"allenai/RLVR-GSM-MATH-IF-Mixed-Constraints": 128}' --dataset_eval_splits train --max_token_length 2048 --max_prompt_token_length 2048 --response_length 2048 --model_name_or_path mistralai/Mistral-Nemo-Instruct-2407 --reward_model_path mistralai/Mistral-Nemo-Instruct-2407 --non_stop_penalty --stop_token eos --temperature 1.0 --ground_truths_key ground_truth --chat_template tulu --sft_messages_key messages --learning_rate 3e-7 --total_episodes 10000000 --penalty_reward_value -10.0 --deepspeed_stage 3 --per_device_train_batch_size 2 --local_rollout_forward_batch_size 2 --local_mini_batch_size 32 --local_rollout_batch_size 32 --actor_num_gpus_per_node 7 8 --vllm_tensor_parallel_size 1 --beta 0.05 --apply_verifiable_reward true --output_dir output/rlvr_8b --seed 3 --num_evals 3 --save_freq 100 --reward_model_multiplier 0.0 --gradient_checkpointing --with_tracking

which runs fine

Image

@palash04
Author

I see. I will first investigate the gloo multinode issue then. Thanks!
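One possible first step for that kind of investigation is a plain TCP reachability check between the two nodes, since the gloo error above comes from its TCP transport. The sketch below uses only the Python standard library; the helper name `can_connect` is hypothetical, and which port to probe depends on the setup.

```python
import socket


def can_connect(host, port, timeout=3.0):
    """Return True if a TCP connection to (host, port) succeeds within timeout."""
    try:
        # create_connection resolves the host and performs the TCP handshake.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False
```

For example, with a listener started on one node, calling `can_connect("10.232.80.204", port)` from the other node shows whether that port is reachable. A successful check does not prove gloo will work, but a failure points at networking or firewall configuration rather than the training code.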

@palash04
Author

Hi @vwxyzjn, it is working now. The issue was not gloo related; it went away once I changed the reward model to be the same as the base model instead of Tulu-3-8B-RM. Does that mean a different reward model can't be used? That shouldn't be the case though, right?

Image

@palash04
Author

palash04 commented Jan 27, 2025

Also, is it required to train an RM beforehand, or can I pass the same DPOed model in the reward model args so that it is trained on the fly?

@vwxyzjn
Collaborator

vwxyzjn commented Jan 29, 2025

Yep, you should use a trained RM. No, you can't pass in a DPOed model as RM.

The primary purpose of the RM is to initialize the value model, so it needs to have the same architecture as the base model. If you don't want to initialize from an RM, you could also initialize from the DPO / SFT model, but in our paper we show this could lead to worse average performance.
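A rough way to sanity-check the "same architecture" requirement before launching is to compare a few fields from each model's config. This is a hypothetical helper, not part of open-instruct; the dictionaries below are illustrative placeholders and should be replaced with the real contents of each checkpoint's `config.json`.

```python
# Fields compared here are an assumption about what "same arch" needs;
# extend the tuple as required.
ARCH_KEYS = ("model_type", "hidden_size", "num_hidden_layers")


def arch_mismatches(policy_cfg, reward_cfg, keys=ARCH_KEYS):
    """Return {field: (policy_value, reward_value)} for every field that differs."""
    return {
        key: (policy_cfg.get(key), reward_cfg.get(key))
        for key in keys
        if policy_cfg.get(key) != reward_cfg.get(key)
    }


# Illustrative values only; verify against the actual config.json files.
policy = {"model_type": "mistral", "hidden_size": 5120, "num_hidden_layers": 40}
reward = {"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32}

print(arch_mismatches(policy, reward))  # non-empty dict: architectures differ
print(arch_mismatches(policy, policy))  # empty dict: same architecture
```

An empty result is what you would expect when the reward model path points at the base model itself, as in the command that ran fine earlier in the thread.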
