diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py
index ce39d0ff6..085133203 100644
--- a/open_instruct/ppo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -251,6 +251,8 @@ class Args:
     """number of vLLM Engines, set to 0 to disable vLLM"""
     vllm_tensor_parallel_size: int = 1
     """tensor parallel size of vLLM Engine for multi-GPU inference"""
+    vllm_enforce_eager: bool = False
+    """whether to enforce eager mode for vLLM -- slow inference but needed for multi-node"""
     vllm_sync_backend: str = "nccl"
     """DeepSpeed -> vLLM weight sync backend"""
     enable_prefix_caching: bool = False
@@ -1683,6 +1685,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
     vllm_engines = create_vllm_engines(
         args.vllm_num_engines,
         args.vllm_tensor_parallel_size,
+        args.vllm_enforce_eager,
         model_config.model_name_or_path,
         model_config.model_revision,
         args.seed,
diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py
index 7dbdf6a54..0d824f8c0 100644
--- a/open_instruct/vllm_utils2.py
+++ b/open_instruct/vllm_utils2.py
@@ -192,6 +192,7 @@ def stop_remote_worker_execution_loop(self):
 def create_vllm_engines(
     num_engines: int,
     tensor_parallel_size: int,
+    enforce_eager: bool,
     pretrain: str,
     revision: str,
     seed: int,
@@ -224,6 +225,7 @@ def create_vllm_engines(
             tokenizer_revision=revision,
             trust_remote_code=True,
             tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
             dtype="bfloat16",
             seed=seed + i,
             enable_prefix_caching=enable_prefix_caching,
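
For context: this patch threads a new `vllm_enforce_eager` flag from the training `Args` through `create_vllm_engines` into vLLM's `LLM` constructor. `enforce_eager=True` disables vLLM's CUDA graph capture so the engine always runs in PyTorch eager mode, which slows generation but avoids the graph-capture issues that can surface in multi-node setups. Below is a minimal standalone sketch of the effect of the flag, not part of the patch itself; the model name is a hypothetical placeholder, and the other kwargs mirror what `create_vllm_engines` passes.

```python
from vllm import LLM, SamplingParams

# Illustrative only: mirrors the kwargs create_vllm_engines forwards to vllm.LLM.
# enforce_eager=True skips CUDA graph capture -- slower decoding, but needed
# for the multi-node case this change targets.
llm = LLM(
    model="allenai/some-model",  # hypothetical placeholder
    trust_remote_code=True,
    tensor_parallel_size=1,
    enforce_eager=True,
    dtype="bfloat16",
    seed=42,
)

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```

Assuming the script's dataclass-based CLI parsing, the new option would be enabled at launch time with a `--vllm_enforce_eager` argument; it defaults to `False`, so single-node runs keep CUDA graphs and their faster inference.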