diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index c1a008ba..8bd430f3 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -89,19 +89,20 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
         assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
             "model context length should be greater than total sequence length"
 
-        self.inference_engine = LLM(actor_module,
-                                    tokenizer=tokenizer,
-                                    model_hf_config=model_hf_config,
-                                    tensor_parallel_size=tensor_parallel_size,
-                                    dtype=config.dtype,
-                                    enforce_eager=config.enforce_eager,
-                                    gpu_memory_utilization=config.gpu_memory_utilization,
-                                    skip_tokenizer_init=False,
-                                    max_model_len=config.prompt_length + config.response_length,
-                                    load_format=config.load_format,
-                                    disable_log_stats=False,
-                                    max_num_batched_tokens=max_num_batched_tokens,
-                                    )
+        self.inference_engine = LLM(
+            actor_module,
+            tokenizer=tokenizer,
+            model_hf_config=model_hf_config,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=config.dtype,
+            enforce_eager=config.enforce_eager,
+            gpu_memory_utilization=config.gpu_memory_utilization,
+            skip_tokenizer_init=False,
+            max_model_len=config.prompt_length + config.response_length,
+            load_format=config.load_format,
+            disable_log_stats=False,
+            max_num_batched_tokens=max_num_batched_tokens,
+        )
 
         # Offload vllm model to reduce peak memory usage
         self.inference_engine.offload_model_weights()