
Commit 9aa1e97

yayashuxue and claude committed
refactor: simplify OpenAI engine parameter handling and fix HLE output directory
- Simplified unsupported parameter handling in OpenAIEngine from 210 to 132 lines
- Removed complex parse_openai_error_for_unsupported_param function and duplicate code
- Extracted common logic into single _fix_unsupported_param helper method
- Fixed HLE evaluation script to always output to examples/deepresearch/hle_outputs/
- Ensures outputs go to gitignored location regardless of where script is run

This addresses reviewer feedback about overly complex error handling with code duplication. Tested with GPT-4o and O3-mini models.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 0ec7b65 commit 9aa1e97
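
To make the intent concrete, here is a minimal usage sketch of the simplified engine. The import path is assumed from this commit's file layout, and the model name and sampling parameters are illustrative, not part of the commit:

import asyncio

from rllm.engine.rollout.openai_engine import OpenAIEngine  # path assumed from the repo layout

async def main():
    # Reasoning models such as o3-mini reject `max_tokens` and non-default
    # `temperature`; after this commit both are corrected in-flight on the
    # first BadRequestError and the request is retried.
    # Requires OPENAI_API_KEY in the environment (the constructor's default api_key).
    engine = OpenAIEngine(model="o3-mini", sampling_params={"max_tokens": 1024, "temperature": 0.6})
    output = await engine.chat_completion([{"role": "user", "content": "Say hello."}])
    print(output)

asyncio.run(main())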

File tree

2 files changed: +26 -103 lines changed

examples/deepresearch/evaluate_hle.py

Lines changed: 5 additions & 1 deletion
@@ -518,8 +518,12 @@ async def main():
     parser.add_argument(
         "--parallel-tasks", type=int, default=4, help="Number of parallel tasks"
     )
+    # Default output directory relative to script location
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    default_output_dir = os.path.join(script_dir, "hle_outputs")
+
     parser.add_argument(
-        "--output-dir", default="./hle_outputs", help="Output directory for results"
+        "--output-dir", default=default_output_dir, help="Output directory for results"
     )
 
     args = parser.parse_args()
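
A brief standalone sketch of the path logic above, showing why the script-relative default matters (the repo-root scenario is illustrative):

import os

# The old default "./hle_outputs" resolved against the current working
# directory: running the script from the repo root wrote results to
# <repo>/hle_outputs rather than the gitignored location.
# Anchoring on __file__ pins the default to the script's own directory:
script_dir = os.path.dirname(os.path.abspath(__file__))       # .../examples/deepresearch
default_output_dir = os.path.join(script_dir, "hle_outputs")  # .../examples/deepresearch/hle_outputs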

rllm/engine/rollout/openai_engine.py

Lines changed: 21 additions & 102 deletions
@@ -9,45 +9,11 @@
 from rllm.parser import ChatTemplateParser, ToolParser
 
 
-def parse_openai_error_for_unsupported_param(error_message: str) -> tuple[str | None, str | None]:
-    """
-    Parse OpenAI API error to extract unsupported parameter and suggested replacement.
-
-    Returns: (unsupported_param, suggested_param) or (None, None) if not parseable
-
-    Example errors:
-    - "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead."
-    - "Unsupported value: 'temperature' does not support 0.6 with this model. Only the default (1) value is supported."
-    """
-    if "unsupported parameter" in error_message.lower():
-        # Extract parameter name from quotes
-        import re
-
-        match = re.search(r"'([^']+)'\s+is not supported", error_message, re.IGNORECASE)
-        if match:
-            unsupported = match.group(1)
-            # Check for suggested replacement
-            suggest_match = re.search(r"use\s+'([^']+)'\s+instead", error_message, re.IGNORECASE)
-            suggested = suggest_match.group(1) if suggest_match else None
-            return unsupported, suggested
-
-    if "unsupported value" in error_message.lower():
-        # Parameter exists but value not allowed - remove the param entirely
-        import re
-
-        match = re.search(r"'([^']+)'\s+does not support", error_message, re.IGNORECASE)
-        if match:
-            return match.group(1), None
-
-    return None, None
-
-
 class OpenAIEngine(RolloutEngine):
     def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, **kwargs):
         self.model = model
         self.api_retries = api_retries
         self.sampling_params = sampling_params or {}
-        self._param_fixes_logged = set()  # Track which param fixes we've already logged
 
         self.tokenizer = tokenizer
         if self.tokenizer is not None:
@@ -65,14 +31,29 @@ def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: s
         self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
         logging.getLogger("httpx").setLevel(logging.WARNING)
 
+    def _fix_unsupported_param(self, error_msg: str, sampling_params: dict) -> bool:
+        """Fix unsupported parameters based on error message. Returns True if fixed."""
+
+        # Try to extract unsupported parameter from error
+        if "max_tokens" in error_msg and "max_completion_tokens" in error_msg:
+            if "max_tokens" in sampling_params:
+                sampling_params["max_completion_tokens"] = sampling_params.pop("max_tokens")
+                return True
+
+        # Remove any unsupported parameter mentioned in error
+        for param in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]:
+            if param in error_msg.lower() and "not support" in error_msg.lower():
+                sampling_params.pop(param, None)
+                return True
+
+        return False
+
     async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
         sampling_params = self.sampling_params.copy()
         sampling_params.update(kwargs)
         sampling_params.pop("model", None)
 
         retries = self.api_retries
-        param_retry_budget = 10  # Allow up to 10 parameter fixes (reasoning models can reject many params)
-
         while retries > 0:
             try:
                 response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, **sampling_params)
@@ -87,39 +68,9 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
                 print("Sleep for 5 seconds for API limit.")
                 await asyncio.sleep(5)
             except openai.BadRequestError as e:
-                # Try to auto-fix unsupported parameters
                 error_msg = str(e)
-                unsupported_param, suggested_param = parse_openai_error_for_unsupported_param(error_msg)
-
-                if unsupported_param and param_retry_budget > 0:
-                    param_retry_budget -= 1
-
-                    # Only log this fix once per engine instance
-                    log_key = f"{unsupported_param}->{suggested_param}" if suggested_param else f"remove:{unsupported_param}"
-                    should_log = log_key not in self._param_fixes_logged
-                    if should_log:
-                        self._param_fixes_logged.add(log_key)
-                        print(f"⚠️ Model {self.model} doesn't support '{unsupported_param}', adjusting parameters...")
-
-                    if suggested_param:
-                        # Remap parameter (e.g., max_tokens -> max_completion_tokens)
-                        if unsupported_param in sampling_params:
-                            value = sampling_params.pop(unsupported_param)
-                            if suggested_param not in sampling_params:
-                                sampling_params[suggested_param] = value
-                            if should_log:
-                                print(f" Remapped '{unsupported_param}' -> '{suggested_param}'")
-                    else:
-                        # Just remove the unsupported parameter
-                        if unsupported_param in sampling_params:
-                            sampling_params.pop(unsupported_param)
-                            if should_log:
-                                print(f" Removed '{unsupported_param}'")
-
-                    # Retry immediately with fixed params (don't decrement retries)
-                    continue
-
-                # Can't auto-fix or out of param retry budget
+                if self._fix_unsupported_param(error_msg, sampling_params):
+                    continue  # Retry with fixed params
                 retries -= 1
                 if retries == 0:
                     raise Exception(f"Error processing content after retries: {e}") from e
@@ -138,8 +89,6 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
         sampling_params.pop("model", None)
 
         retries = self.api_retries
-        param_retry_budget = 10  # Allow up to 10 parameter fixes (reasoning models can reject many params)
-
        while retries > 0:
             try:
                 response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, **sampling_params)
@@ -151,39 +100,9 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
                 print("Sleep for 5 seconds for API limit.")
                 await asyncio.sleep(5)
             except openai.BadRequestError as e:
-                # Try to auto-fix unsupported parameters
                 error_msg = str(e)
-                unsupported_param, suggested_param = parse_openai_error_for_unsupported_param(error_msg)
-
-                if unsupported_param and param_retry_budget > 0:
-                    param_retry_budget -= 1
-
-                    # Only log this fix once per engine instance
-                    log_key = f"{unsupported_param}->{suggested_param}" if suggested_param else f"remove:{unsupported_param}"
-                    should_log = log_key not in self._param_fixes_logged
-                    if should_log:
-                        self._param_fixes_logged.add(log_key)
-                        print(f"⚠️ Model {self.model} doesn't support '{unsupported_param}', adjusting parameters...")
-
-                    if suggested_param:
-                        # Remap parameter (e.g., max_tokens -> max_completion_tokens)
-                        if unsupported_param in sampling_params:
-                            value = sampling_params.pop(unsupported_param)
-                            if suggested_param not in sampling_params:
-                                sampling_params[suggested_param] = value
-                            if should_log:
-                                print(f" Remapped '{unsupported_param}' -> '{suggested_param}'")
-                    else:
-                        # Just remove the unsupported parameter
-                        if unsupported_param in sampling_params:
-                            sampling_params.pop(unsupported_param)
-                            if should_log:
-                                print(f" Removed '{unsupported_param}'")
-
-                    # Retry immediately with fixed params (don't decrement retries)
-                    continue
-
-                # Can't auto-fix or out of param retry budget
+                if self._fix_unsupported_param(error_msg, sampling_params):
+                    continue  # Retry with fixed params
                 retries -= 1
                 if retries == 0:
                     raise Exception(f"Error processing content after retries: {e}") from e
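
To see the consolidated logic in isolation, here is a runnable sketch of the new helper, exercised with the two example error strings from the removed docstring. The body mirrors the diff above; it is written as a plain function (no self) purely for illustration:

def _fix_unsupported_param(error_msg: str, sampling_params: dict) -> bool:
    """Fix unsupported parameters based on error message. Returns True if fixed."""
    # Remap max_tokens -> max_completion_tokens when the error suggests it
    if "max_tokens" in error_msg and "max_completion_tokens" in error_msg:
        if "max_tokens" in sampling_params:
            sampling_params["max_completion_tokens"] = sampling_params.pop("max_tokens")
            return True
    # Drop a sampling parameter the model rejects outright
    for param in ["temperature", "top_p", "presence_penalty", "frequency_penalty"]:
        if param in error_msg.lower() and "not support" in error_msg.lower():
            sampling_params.pop(param, None)
            return True
    return False

params = {"max_tokens": 1024, "temperature": 0.6}
_fix_unsupported_param("Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead.", params)
_fix_unsupported_param("Unsupported value: 'temperature' does not support 0.6 with this model. Only the default (1) value is supported.", params)
print(params)  # {'max_completion_tokens': 1024}

Each fix mutates sampling_params before the retry, so a given error can fire at most once per parameter; this is why the separate param_retry_budget the old code tracked is no longer needed.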
