Merged

Changes from all commits (33 commits)
d00dcdd  uses add_request  (finbarrtimbers, Sep 2, 2025)
7b74746  ran linter  (finbarrtimbers, Sep 2, 2025)
1408569  Clean up  (finbarrtimbers, Sep 2, 2025)
dba447a  Fixed bug  (finbarrtimbers, Sep 2, 2025)
3932bf1  Added duplication  (finbarrtimbers, Sep 2, 2025)
7becde9  Added prompt_tokens to metadata.  (finbarrtimbers, Sep 2, 2025)
d2cb9f7  Added missing key to metadata  (finbarrtimbers, Sep 2, 2025)
01dd23b  Fixed bug where we weren't returning properly.  (finbarrtimbers, Sep 2, 2025)
f47879f  Fix script  (finbarrtimbers, Sep 2, 2025)
ca5f07d  Added logging  (finbarrtimbers, Sep 2, 2025)
3f626fb  fix bug  (finbarrtimbers, Sep 2, 2025)
ca79b8b  use clone for SamplingParams  (finbarrtimbers, Sep 2, 2025)
084ad77  Fixes to duplication  (finbarrtimbers, Sep 2, 2025)
d2e6041  Removed logging.  (finbarrtimbers, Sep 2, 2025)
12a4ce7  Cleaned up PR.  (finbarrtimbers, Sep 2, 2025)
0813b85  Clean PR  (finbarrtimbers, Sep 2, 2025)
e9d6cfb  Removed whitespace  (finbarrtimbers, Sep 2, 2025)
417748b  Cleaned up PR  (finbarrtimbers, Sep 2, 2025)
1ff890c  Merge branch 'main' into combined-llm-loop  (finbarrtimbers, Sep 2, 2025)
d96e7b2  Added comment for cleaner PR.  (finbarrtimbers, Sep 2, 2025)
fe8e1bf  Merge branch 'main' into combined-llm-loop  (finbarrtimbers, Sep 2, 2025)
f133a8e  Cleaning up PR  (finbarrtimbers, Sep 2, 2025)
341a77b  Revert "load pretokenized user query (v0) (#965)"  (finbarrtimbers, Sep 3, 2025)
8ebfdf9  Bug fix.  (finbarrtimbers, Sep 3, 2025)
e88c2c2  Fixed issue where we weren't setting params right in tools.  (finbarrtimbers, Sep 3, 2025)
a929c63  Updated descriptions.  (finbarrtimbers, Sep 3, 2025)
ba441f0  Fix ordering.  (finbarrtimbers, Sep 3, 2025)
d4e6fd9  Updated tool script with description.  (finbarrtimbers, Sep 3, 2025)
4e7cbe3  Fixed use of wrong vllm.SamplingParams.  (finbarrtimbers, Sep 3, 2025)
d425a42  Now, tool use should run.  (finbarrtimbers, Sep 3, 2025)
826f199  Reapply "load pretokenized user query (v0) (#965)"  (finbarrtimbers, Sep 3, 2025)
e214dba  Merge branch 'main' into combined-llm-loop  (finbarrtimbers, Sep 3, 2025)
74fb0a6  minor clean up.  (finbarrtimbers, Sep 3, 2025)
2 changes: 2 additions & 0 deletions open_instruct/grpo_fast.py
@@ -2919,6 +2919,8 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
actor_manager,
checkpoint_state,
)
except Exception as e:
logger.error(f"Error in run_training: {e}", exc_info=True)
finally:
cleanup_training_resources(
stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
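The two added lines above log the full traceback before the finally block tears down training resources. A minimal sketch of the same pattern, with run_training and cleanup_training_resources standing in for the real calls in grpo_fast.py:

import logging

logger = logging.getLogger(__name__)

def train_with_cleanup(run_training, cleanup_training_resources):
    """Log any training failure with its traceback, then always clean up."""
    try:
        run_training()
    except Exception as e:
        # exc_info=True attaches the traceback to the log record.
        logger.error(f"Error in run_training: {e}", exc_info=True)
    finally:
        # Runs whether training succeeded or raised, mirroring the hunk above.
        cleanup_training_resources()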
175 changes: 101 additions & 74 deletions open_instruct/vllm_utils3.py
@@ -15,7 +15,6 @@

"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""

import copy
import os
import queue
import time
@@ -43,7 +42,7 @@
from vllm.v1.core import kv_cache_utils

from open_instruct import logger_utils
from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
from open_instruct.utils import ray_get_with_progress

@@ -93,7 +92,7 @@ def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, exe
if not tools:
return output

assert len(output.outputs) <= 1 # In tool mode, sampling_params.n == 1
assert len(output.outputs) <= 1, f"{len(output.outputs)=}" # In tool mode, sampling_params.n == 1
o = output.outputs[0]

# Update concatenated outputs
@@ -203,7 +202,6 @@ def _process_outputs_with_tools(
def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
"""Prepare final outputs based on whether tools were used."""
if not tools:
outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
return _process_outputs(
outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
)
@@ -223,14 +221,14 @@ def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=
# Merge n completions into the same outputs
merged_outputs = {}
for req_id in tracking["concat_outputs"]:
real_req_id, _ = req_id.split("-")
real_req_id = "_".join(req_id.split("_")[:-1])
if real_req_id not in merged_outputs:
merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
else:
merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])

final_outputs = sorted(
merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
merged_outputs.values(), key=lambda x: (int(x.request_id.split("_")[1]), int(x.request_id.split("_")[2]))
)

return _process_outputs_with_tools(
@@ -317,6 +315,32 @@ def init_process_group(
return pg


def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict):
"""Add a request to the LLM engine."""
prefix = "eval" if request.is_eval else "train"

for batch_idx, prompt in enumerate(request.prompts):
request_id = f"{prefix}_{request.training_step}_{batch_idx}"
sampling_params = request.generation_config.clone()
sampling_params.n = 1 # Use n=1 for tool processing
request_metadata[request_id] = {
"is_eval": request.is_eval,
"dataset_index": request.dataset_index[batch_idx],
"training_step": request.training_step,
"sampling_params": sampling_params,
"prompt_tokens": len(prompt),
"start_time": time.perf_counter(),
}

tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)

for j in range(request.generation_config.n):
sub_sampling_params = sampling_params.clone() # Already has n=1
if request.generation_config.seed is not None:
sub_sampling_params.seed = request.generation_config.seed + j
llm_engine.add_request(f"{request_id}_{j}", tokens_prompt, sub_sampling_params)


class LLMRayActor:
"""Ray actor for LLM generation with optional tool support."""

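To make the new request-ID scheme concrete, here is a small standalone sketch (plain Python, no vLLM dependency) of how add_request fans one batch out into per-sample sub-requests and how the trailing sample index is later stripped to recover the base ID. FakeEngine and the sample values are illustrative only, not part of the PR:

from dataclasses import dataclass, field

@dataclass
class FakeEngine:
    """Stand-in for vllm.LLMEngine that only records the request IDs it receives."""
    request_ids: list = field(default_factory=list)

    def add_request(self, request_id, prompt, params):
        self.request_ids.append(request_id)

# One training batch: 2 prompts, n=3 samples per prompt, training step 7.
engine = FakeEngine()
prefix, training_step, n = "train", 7, 3
prompts = [[101, 102], [103, 104, 105]]  # already-tokenized prompts

for batch_idx, prompt in enumerate(prompts):
    base_id = f"{prefix}_{training_step}_{batch_idx}"        # e.g. "train_7_0"
    for j in range(n):
        engine.add_request(f"{base_id}_{j}", prompt, None)   # e.g. "train_7_0_2"

# Sub-request outputs are later grouped by dropping the trailing "_{j}" suffix.
base_ids = {"_".join(rid.split("_")[:-1]) for rid in engine.request_ids}
assert base_ids == {"train_7_0", "train_7_1"}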
@@ -384,6 +408,15 @@ def _should_stop(self) -> bool:
ray.cancel(should_stop_ref)
return self._should_stop_value

def _insert_result_to_queue(self, result, is_eval: bool):
"""Insert result into the appropriate queue with error handling."""
try:
results_queue = self.eval_results_queue if is_eval else self.results_queue
results_queue.put(result, timeout=10)
except queue.Full:
queue_name = "eval" if is_eval else "train"
self.logger.warning(f"{queue_name} results queue is full, discarding result.")

def process_from_queue(self, timeout: float = 60.0):
"""Run generation loop using LLMEngine directly, with optional tool support.

Expand All @@ -401,37 +434,20 @@ def process_from_queue(self, timeout: float = 60.0):

result = self._process_request(request)

try:
if request.is_eval:
self.eval_results_queue.put(result, timeout=10)
else:
self.results_queue.put(result, timeout=10)
return 1 # Successfully processed one request
except queue.Full:
self.logger.warning("Results queue is full, discarding result.")
return 0
self._insert_result_to_queue(result, is_eval=request.is_eval)
return 1

def _process_request(self, request):
"""Unified processing for both tool and non-tool generation."""
prompts = request.prompts
sampling_params = request.generation_config
start_time = request.start_time

self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
self.logger.info(
f"[LLMRayActor] Processing request with {len(request.prompts)} prompts, tools={bool(self.tools)}"
)

if self.tools:
# Need n=1 for individual tool tracking
sampling_params = copy.deepcopy(sampling_params)
original_n = request.generation_config.n
sampling_params.n = 1
tracking = _init_tool_tracking()
tokenizer = self.llm_engine.tokenizer
else:
original_n = 1
tracking = None
tokenizer = None
tracking = _init_tool_tracking() if self.tools else None
tokenizer = self.llm_engine.tokenizer

self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)

outputs = []
iteration = 0
@@ -441,18 +457,19 @@ def _process_request(self, request):

# Poll tool futures first (matching ToolUseLLM order)
if tracking and tracking.get("pending_tool_futures"):
self._poll_tool_futures(tracking, sampling_params, tokenizer)
outputs.extend(self._poll_tool_futures(tracking, tokenizer))

# Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
if self.llm_engine.has_unfinished_requests():
step_outputs = list(self.llm_engine.step())
step_outputs = [o for o in self.llm_engine.step() if o.finished]
for output in step_outputs:
if output.finished:
result = _handle_output(
output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
)
if result is not None:
outputs.append(result)
self.logger.info(f"{len(output.outputs)=}")
result = _handle_output(
output, self.tools, tracking, request.generation_config, self.max_tool_calls, self.executor
)
# Result is None when we do more tool processing.
if result is not None:
outputs.append(result)

# Check termination condition (matching ToolUseLLM exactly)
pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
@@ -465,23 +482,40 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
total_generation_tokens = 0
earliest_start_time = float("inf")

# Now, we combine outputs:
combined_outputs = defaultdict(list)
for output in outputs:
request_id = output.request_id
if request_id in self.request_metadata:
metadata = self.request_metadata[request_id]
total_prompt_tokens += metadata["prompt_tokens"]
earliest_start_time = min(earliest_start_time, metadata["start_time"])

# Remove the sub_idx.
request_id = "_".join(output.request_id.split("_")[:-1])
combined_outputs[request_id].append(output)
# Preserve original order from request.dataset_index
prefix = "eval" if request.is_eval else "train"
# request_id format is {prefix}_{training_step}_{within_batch_idx}_{repetition_idx};
# we order the final outputs by within_batch_idx.
ordered_ids = [f"{prefix}_{request.training_step}_{batch_idx}" for batch_idx in range(len(request.prompts))]
final_outputs = []
for request_id in ordered_ids:
outs = combined_outputs[request_id]
assert len(outs) == request.generation_config.n, f"{len(outs)=} != {request.generation_config.n=}"
final_outputs.append(
vllm.RequestOutput(
request_id=request_id,
prompt=outs[0].prompt,
prompt_token_ids=outs[0].prompt_token_ids,
prompt_logprobs=outs[0].prompt_logprobs,
outputs=[completion for out in outs for completion in out.outputs],
finished=outs[0].finished,
)
)
metadata = self.request_metadata.pop(request_id)
total_prompt_tokens += metadata["prompt_tokens"]
earliest_start_time = min(earliest_start_time, metadata["start_time"])
for output in outs:
for completion in output.outputs:
total_generation_tokens += len(completion.token_ids)

generation_time = end_time - earliest_start_time

for output in outputs:
self.request_metadata.pop(output.request_id, None)

result = _finalize_outputs(
outputs,
final_outputs,
tracking,
request.dataset_index,
self.tools,
@@ -490,33 +524,17 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
num_response_tokens=total_generation_tokens,
generation_time=generation_time,
),
start_time=start_time,
start_time=request.start_time,
)
return result

def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
"""Add initial requests to the engine."""
for i, prompt in enumerate(prompts):
if self.tools:
# Create individual requests for each sample when using tools
for j in range(n_samples):
request_id = f"{training_step}_{i}-{j}"
self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
else:
# Standard request format for non-tool mode
request_id = f"batch_{training_step}_{i}"
self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)

def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
def _poll_tool_futures(self, tracking, tokenizer):
"""Poll and handle completed tool executions."""
if not self.tools or not tracking["pending_tool_futures"]:
return
return []

dict_keys_to_delete = []
completed_outputs = []

for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
if not future.done():
@@ -525,6 +543,11 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
# Tool future is done, process it
tool_result = future.result() # Get the tool result

# Get sampling params from request metadata for this request
# Extract the base request ID by removing the sub-request suffix
base_req_id = "_".join(req_id.split("_")[:-1])
sampling_params = self.request_metadata[base_req_id]["sampling_params"]

last_prompt_token_ids = last_output.prompt_token_ids
last_token_ids = last_o.token_ids
tool_output_token_ids = tokenizer.encode(
@@ -559,7 +582,7 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
can_make_new_request = can_make_new_request and new_sample_tokens > 0

if can_make_new_request:
new_sampling_params = copy.deepcopy(sampling_params)
new_sampling_params = sampling_params.clone()
new_sampling_params.max_tokens = new_sample_tokens

try:
@@ -569,12 +592,16 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
except Exception as e:
# Match original ToolUseLLM behavior - just log and continue
self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
else:
# If we can't make a new request, this tool execution is complete
completed_outputs.append(tracking["concat_outputs"][req_id])

dict_keys_to_delete.append(req_id)

for req_id in dict_keys_to_delete:
if req_id in tracking["pending_tool_futures"]:
del tracking["pending_tool_futures"][req_id]
tracking["pending_tool_futures"].pop(req_id, None)

return completed_outputs

def init_process_group(
self,
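The reassembly loop in _process_request above groups finished sub-requests by base ID and emits one combined output per prompt in within-batch order. A simplified sketch of that grouping using plain dictionaries; the IDs and completion strings below are made up for illustration:

from collections import defaultdict

# Finished sub-request outputs arrive keyed by full request ID, in arbitrary order.
finished = {
    "train_7_1_0": "completion B0",
    "train_7_0_2": "completion A2",
    "train_7_0_0": "completion A0",
    "train_7_1_1": "completion B1",
    "train_7_0_1": "completion A1",
    "train_7_1_2": "completion B2",
}

# Group by base ID, i.e. drop the trailing repetition index.
combined = defaultdict(list)
for sub_id, completion in finished.items():
    combined["_".join(sub_id.split("_")[:-1])].append(completion)

# Emit results in the original prompt order rather than completion order.
prefix, training_step, num_prompts, n = "train", 7, 2, 3
ordered_ids = [f"{prefix}_{training_step}_{i}" for i in range(num_prompts)]
for base_id in ordered_ids:
    assert len(combined[base_id]) == n  # every prompt must come back with all n samples
print([combined[base_id] for base_id in ordered_ids])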
4 changes: 2 additions & 2 deletions scripts/train/debug/large_test_script.sh
@@ -12,7 +12,7 @@ uv run python mason.py \
--priority urgent \
--preemptible \
--num_nodes 2 \
--description "rlvr ace fn and og ocr stdio from base with perf penalty" \
--description "Large (multi-node) test script." \
--max_retries 0 \
--env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
--budget ai2/oe-adapt \
@@ -39,7 +39,7 @@ uv run python mason.py \
--stop_strings "</answer>" \
--non_stop_penalty False \
--temperature 1.0 \
--verbose False \
--verbose False \
--ground_truths_key ground_truth \
--sft_messages_key messages \
--total_episodes 10_000 \
1 change: 1 addition & 0 deletions scripts/train/debug/single_gpu_integration_test.sh
@@ -11,6 +11,7 @@ uv run python mason.py \
--cluster ai2/augusta-google-1 \
--cluster ai2/saturn-cirrascale \
--image "$BEAKER_IMAGE" \
--description "Single GPU on Beaker integration test." \
--pure_docker_mode \
--workspace ai2/open-instruct-dev \
--priority high \
1 change: 1 addition & 0 deletions scripts/train/debug/single_gpu_on_beaker.sh
@@ -11,6 +11,7 @@ uv run python mason.py \
--cluster ai2/saturn-cirrascale \
--cluster ai2/ceres-cirrascale \
--image "$BEAKER_IMAGE" \
--description "Single GPU on Beaker test script." \
--pure_docker_mode \
--workspace ai2/open-instruct-dev \
--priority urgent \
1 change: 1 addition & 0 deletions scripts/train/debug/tool_grpo_fast.sh
@@ -14,6 +14,7 @@ uv run python mason.py \
--cluster ai2/augusta-google-1 \
--cluster ai2/saturn-cirrascale \
--image "$BEAKER_IMAGE" \
--description "Single GPU on Beaker with tool use test script." \
--pure_docker_mode \
--workspace ai2/tulu-thinker \
--priority high \