Skip to content

Commit 339582f

Browse files
rollout upgrades
1 parent 92c0c89 commit 339582f

File tree

13 files changed

+541
-532
lines changed

13 files changed

+541
-532
lines changed

examples/frozenlake/train_frozenlake_agent.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ python3 -m examples.frozenlake.train_frozenlake_agent \
5252
trainer.logger=['console','wandb'] \
5353
trainer.project_name='rllm-agent' \
5454
trainer.experiment_name='frozenlake-agent-0.6B' \
55-
trainer.val_before_train=False \
55+
trainer.val_before_train=True \
5656
trainer.n_gpus_per_node=8 \
5757
trainer.nnodes=1 \
5858
trainer.save_freq=40 \

examples/simple_math/train_hendrycks_math.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ python3 -m examples.simple_math.train_hendrycks_math \
5454
trainer.critic_warmup=0 \
5555
trainer.logger=['console','wandb'] \
5656
trainer.project_name='rllm-agent' \
57-
trainer.experiment_name='deepscaler-debug-math-fsdp1' \
57+
trainer.experiment_name='simple-math' \
5858
trainer.val_before_train=True \
5959
trainer.n_gpus_per_node=8 \
6060
trainer.nnodes=1 \

rllm/engine/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""
55

66
from .agent_execution_engine import AgentExecutionEngine, AsyncAgentExecutionEngine
7-
from .rollout.openai_engine import OpenAIEngine
7+
8+
# Avoid importing rollout submodules eagerly to prevent circular imports with workflows
9+
# Import base class only (no side effects) and lazy-load specific engines via __getattr__
810
from .rollout.rollout_engine import RolloutEngine
911

1012
__all__ = [
@@ -13,20 +15,24 @@
1315
"AgentWorkflowEngine",
1416
"RolloutEngine",
1517
"OpenAIEngine",
18+
"VerlEngine",
1619
]
1720

18-
# VerlEngine is optional; only export if verl is installed
19-
try:
20-
from .rollout.verl_engine import VerlEngine
21-
22-
__all__.append("VerlEngine")
23-
except Exception:
24-
VerlEngine = None
25-
2621

2722
def __getattr__(name):
2823
if name == "AgentWorkflowEngine":
2924
from .agent_workflow_engine import AgentWorkflowEngine as _AgentWorkflowEngine
3025

3126
return _AgentWorkflowEngine
27+
if name == "OpenAIEngine":
28+
from .rollout.openai_engine import OpenAIEngine as _OpenAIEngine
29+
30+
return _OpenAIEngine
31+
if name == "VerlEngine":
32+
try:
33+
from .rollout.verl_engine import VerlEngine as _VerlEngine
34+
35+
return _VerlEngine
36+
except Exception:
37+
raise AttributeError(name) from None
3238
raise AttributeError(name)

rllm/engine/agent_execution_engine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def __init__(
102102
**rollout_engine_args,
103103
api_retries=api_retries,
104104
tokenizer=self.tokenizer,
105+
max_prompt_length=self.max_prompt_length,
106+
max_response_length=self.max_response_length,
105107
disable_thinking=kwargs.get("disable_thinking", False),
106108
)
107109
elif self.engine_name == "verl":
@@ -140,12 +142,12 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
140142
sampling_params.update(kwargs)
141143

142144
if self.engine_name == "openai":
143-
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, **sampling_params)
145+
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
144146
return output.text
145147
elif self.engine_name == "verl":
146148
meta_data = sampling_params.pop("meta_info", {})
147149
validate = meta_data.get("validate", False)
148-
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, **sampling_params)
150+
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
149151
return output.text
150152
else:
151153
raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")

rllm/engine/rollout/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
1-
from .openai_engine import OpenAIEngine
1+
# Avoid importing concrete engines at module import time to prevent circular imports
22
from .rollout_engine import ModelOutput, RolloutEngine
33

44
__all__ = [
55
"ModelOutput",
66
"RolloutEngine",
77
"OpenAIEngine",
8+
"VerlEngine",
89
]
10+
11+
12+
def __getattr__(name):
13+
if name == "OpenAIEngine":
14+
from .openai_engine import OpenAIEngine as _OpenAIEngine
15+
16+
return _OpenAIEngine
17+
if name == "VerlEngine":
18+
try:
19+
from .verl_engine import VerlEngine as _VerlEngine
20+
21+
return _VerlEngine
22+
except Exception:
23+
raise AttributeError(name) from None
24+
raise AttributeError(name)

rllm/engine/rollout/openai_engine.py

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,80 @@
66

77
from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
88
from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START
9-
from rllm.parser import ChatTemplateParser, ToolParser
9+
from rllm.parser import ChatTemplateParser
10+
from rllm.tools.tool_base import Tool
11+
from rllm.workflows import TerminationEvent, TerminationReason
1012

1113

1214
class OpenAIEngine(RolloutEngine):
13-
def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, **kwargs):
15+
def __init__(self, model: str = "", tokenizer=None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int | None = None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, tools: list[Tool | dict] = None, accumulate_reasoning: bool = False, **kwargs):
1416
self.model = model
17+
self.max_prompt_length = max_prompt_length
18+
self.max_response_length = max_response_length
19+
self.max_model_length = max_model_length - 1 if max_model_length is not None else max_prompt_length + max_response_length - 1
1520
self.api_retries = api_retries
1621
self.sampling_params = sampling_params or {}
22+
self.tools = tools or []
23+
self.accumulate_reasoning = accumulate_reasoning
1724

1825
self.tokenizer = tokenizer
1926
if self.tokenizer is not None:
2027
self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=kwargs.get("disable_thinking", False))
21-
try:
22-
self.tool_parser = ToolParser.get_parser(self.tokenizer)
23-
except Exception:
24-
print(f"Warning: No tool parser found for {self.tokenizer.name_or_path}. Tool calls not be parsed.")
25-
self.tool_parser = None
2628
self._use_chat_completions = False
2729
else:
28-
print("No tokenizer provided, will use the chat completions endpoint. This is not recommended.")
30+
# In this case, we cannot enforce max prompt length or dynamically adjust max_tokens <= max_response_length if needed
31+
print("No tokenizer provided to OpenAIEngine, will use the chat completions endpoint.")
2932
self._use_chat_completions = True
3033

3134
self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
3235
logging.getLogger("httpx").setLevel(logging.WARNING)
3336

3437
async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
38+
kwargs.pop("application_id", None)
39+
kwargs.pop("validate", None)
40+
kwargs.pop("model", None)
41+
kwargs.pop("enforce_max_prompt_length", None)
42+
3543
sampling_params = self.sampling_params.copy()
3644
sampling_params.update(kwargs)
37-
sampling_params.pop("model", None)
45+
46+
max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))
47+
3848
retries = self.api_retries
3949
while retries > 0:
4050
try:
41-
response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, **sampling_params)
42-
text = response.choices[0].message.content
43-
if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str):
44-
text = f"{THOUGHT_DELIMITER_START}\n{response.choices[0].message.reasoning}\n{THOUGHT_DELIMITER_END}\n\n{text}"
45-
return ModelOutput(text=text, tool_calls=response.choices[0].message.tool_calls, finish_reason=response.choices[0].finish_reason, completion_tokens=response.usage.completion_tokens, prompt_tokens=response.usage.prompt_tokens)
51+
response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, max_tokens=max_tokens, **sampling_params)
52+
53+
content = response.choices[0].message.content
54+
reasoning = response.choices[0].message.reasoning if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str) else ""
55+
tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") and isinstance(response.choices[0].message.tool_calls, list) else []
56+
57+
if reasoning:
58+
text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}" # best guess
59+
60+
prompt_length = response.usage.prompt_tokens
61+
completion_length = response.usage.completion_tokens
62+
finish_reason = response.choices[0].finish_reason
63+
64+
return ModelOutput(
65+
text=text,
66+
content=content,
67+
reasoning=reasoning,
68+
tool_calls=tool_calls,
69+
prompt_ids=[],
70+
completion_ids=[],
71+
prompt_length=prompt_length,
72+
completion_length=completion_length,
73+
finish_reason=finish_reason,
74+
)
75+
4676
except openai.RateLimitError:
4777
retries -= 1
4878
if retries == 0:
4979
raise Exception("Rate limit reached and retries exhausted.") from None
5080
print("Sleep for 5 seconds for API limit.")
5181
await asyncio.sleep(5)
82+
5283
except Exception as e:
5384
retries -= 1
5485
if retries == 0:
@@ -57,20 +88,58 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
5788
await asyncio.sleep(1)
5889

5990
async def completion(self, prompt: str, **kwargs) -> ModelOutput:
91+
kwargs.pop("application_id", None)
92+
kwargs.pop("validate", None)
93+
kwargs.pop("model", None)
94+
enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True)
95+
6096
sampling_params = self.sampling_params.copy()
6197
sampling_params.update(kwargs)
62-
sampling_params.pop("model", None)
98+
99+
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
100+
prompt_length = len(prompt_ids)
101+
102+
if enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length):
103+
raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED)
104+
105+
max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))
106+
remaining_tokens = self.max_model_length - prompt_length
107+
if remaining_tokens <= max_tokens:
108+
max_tokens = remaining_tokens
109+
print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length")
110+
63111
retries = self.api_retries
64112
while retries > 0:
65113
try:
66-
response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, **sampling_params)
67-
return ModelOutput(text=response.choices[0].text, tool_calls=[], finish_reason=response.choices[0].finish_reason, completion_tokens=response.usage.completion_tokens, prompt_tokens=response.usage.prompt_tokens)
114+
response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, max_tokens=max_tokens, **sampling_params)
115+
116+
text = response.choices[0].text
117+
completion_ids = self.tokenizer.encode(text, add_special_tokens=False)
118+
parsed_output = self.chat_parser.parse_completion(completion_ids)
119+
120+
prompt_length = response.usage.prompt_tokens
121+
completion_length = response.usage.completion_tokens
122+
finish_reason = response.choices[0].finish_reason
123+
124+
return ModelOutput(
125+
text=text,
126+
content=parsed_output["content"],
127+
reasoning=parsed_output["reasoning"],
128+
tool_calls=parsed_output["tool_calls"],
129+
prompt_ids=prompt_ids,
130+
completion_ids=completion_ids,
131+
prompt_length=prompt_length,
132+
completion_length=completion_length,
133+
finish_reason=finish_reason,
134+
)
135+
68136
except openai.RateLimitError:
69137
retries -= 1
70138
if retries == 0:
71139
raise Exception("Rate limit reached and retries exhausted.") from None
72140
print("Sleep for 5 seconds for API limit.")
73141
await asyncio.sleep(5)
142+
74143
except Exception as e:
75144
retries -= 1
76145
if retries == 0:
@@ -79,13 +148,10 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
79148
await asyncio.sleep(1)
80149

81150
async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
82-
kwargs.pop("application_id", None) # only needed for verl engine
83-
kwargs.pop("validate", None) # only needed for verl engine
84151
if self._use_chat_completions:
85152
return await self.chat_completion(messages, **kwargs)
86153
else:
87-
prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True)
88-
output = await self.completion(prompt, **kwargs)
89-
if self.tool_parser is not None:
90-
output.tool_calls = self.tool_parser.parse(output.text)
91-
return output
154+
tools = kwargs.pop("tools", self.tools)
155+
accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning)
156+
prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning)
157+
return await self.completion(prompt, **kwargs)

rllm/engine/rollout/rollout_engine.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,51 @@
11
from dataclasses import dataclass
22

3+
from rllm.tools.tool_base import ToolCall
4+
35

46
@dataclass
57
class ModelOutput:
68
text: str
7-
tool_calls: list
9+
content: str
10+
reasoning: str
11+
tool_calls: list[ToolCall]
12+
prompt_ids: list[int]
13+
completion_ids: list[int]
14+
prompt_length: int
15+
completion_length: int
816
finish_reason: str
9-
completion_tokens: int
10-
prompt_tokens: int
17+
18+
def to_dict(self):
19+
return {
20+
"text": self.text,
21+
"content": self.content,
22+
"reasoning": self.reasoning,
23+
"tool_calls": [tool_call.to_dict() for tool_call in self.tool_calls],
24+
"prompt_ids": self.prompt_ids,
25+
"completion_ids": self.completion_ids,
26+
"prompt_length": self.prompt_length,
27+
"completion_length": self.completion_length,
28+
"finish_reason": self.finish_reason,
29+
}
30+
31+
@classmethod
32+
def from_dict(cls, data: dict):
33+
return cls(
34+
text=data["text"],
35+
content=data["content"],
36+
reasoning=data["reasoning"],
37+
tool_calls=[ToolCall(**tool_call) for tool_call in data["tool_calls"]],
38+
prompt_ids=data["prompt_ids"],
39+
completion_ids=data["completion_ids"],
40+
prompt_length=data["prompt_length"],
41+
completion_length=data["completion_length"],
42+
finish_reason=data["finish_reason"],
43+
)
1144

1245

1346
class RolloutEngine:
14-
def __init__(self, model: str, tokenizer=None, **kwargs):
15-
self.model = model
16-
self.tokenizer = tokenizer
47+
def __init__(self, *args, **kwargs):
48+
pass
1749

1850
async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
1951
raise NotImplementedError("get_model_response is not implemented")

0 commit comments

Comments
 (0)