99from rllm .parser import ChatTemplateParser , ToolParser
1010
1111
12- def parse_openai_error_for_unsupported_param (error_message : str ) -> tuple [str | None , str | None ]:
13- """
14- Parse OpenAI API error to extract unsupported parameter and suggested replacement.
15-
16- Returns: (unsupported_param, suggested_param) or (None, None) if not parseable
17-
18- Example errors:
19- - "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead."
20- - "Unsupported value: 'temperature' does not support 0.6 with this model. Only the default (1) value is supported."
21- """
22- if "unsupported parameter" in error_message .lower ():
23- # Extract parameter name from quotes
24- import re
25-
26- match = re .search (r"'([^']+)'\s+is not supported" , error_message , re .IGNORECASE )
27- if match :
28- unsupported = match .group (1 )
29- # Check for suggested replacement
30- suggest_match = re .search (r"use\s+'([^']+)'\s+instead" , error_message , re .IGNORECASE )
31- suggested = suggest_match .group (1 ) if suggest_match else None
32- return unsupported , suggested
33-
34- if "unsupported value" in error_message .lower ():
35- # Parameter exists but value not allowed - remove the param entirely
36- import re
37-
38- match = re .search (r"'([^']+)'\s+does not support" , error_message , re .IGNORECASE )
39- if match :
40- return match .group (1 ), None
41-
42- return None , None
43-
44-
4512class OpenAIEngine (RolloutEngine ):
4613 def __init__ (self , model : str , tokenizer = None , api_retries : int = 3 , base_url : str = "https://api.openai.com/v1" , api_key : str = os .getenv ("OPENAI_API_KEY" ), sampling_params : dict | None = None , ** kwargs ):
4714 self .model = model
4815 self .api_retries = api_retries
4916 self .sampling_params = sampling_params or {}
50- self ._param_fixes_logged = set () # Track which param fixes we've already logged
5117
5218 self .tokenizer = tokenizer
5319 if self .tokenizer is not None :
@@ -65,14 +31,29 @@ def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: s
6531 self .client = openai .AsyncOpenAI (base_url = base_url , api_key = api_key )
6632 logging .getLogger ("httpx" ).setLevel (logging .WARNING )
6733
34+ def _fix_unsupported_param (self , error_msg : str , sampling_params : dict ) -> bool :
35+ """Fix unsupported parameters based on error message. Returns True if fixed."""
36+
37+ # Try to extract unsupported parameter from error
38+ if "max_tokens" in error_msg and "max_completion_tokens" in error_msg :
39+ if "max_tokens" in sampling_params :
40+ sampling_params ["max_completion_tokens" ] = sampling_params .pop ("max_tokens" )
41+ return True
42+
43+ # Remove any unsupported parameter mentioned in error
44+ for param in ["temperature" , "top_p" , "presence_penalty" , "frequency_penalty" ]:
45+ if param in error_msg .lower () and "not support" in error_msg .lower ():
46+ sampling_params .pop (param , None )
47+ return True
48+
49+ return False
50+
6851 async def chat_completion (self , messages : list [dict ], ** kwargs ) -> ModelOutput :
6952 sampling_params = self .sampling_params .copy ()
7053 sampling_params .update (kwargs )
7154 sampling_params .pop ("model" , None )
7255
7356 retries = self .api_retries
74- param_retry_budget = 10 # Allow up to 10 parameter fixes (reasoning models can reject many params)
75-
7657 while retries > 0 :
7758 try :
7859 response = await self .client .chat .completions .create (model = self .model , messages = messages , timeout = 3600 , ** sampling_params )
@@ -87,39 +68,9 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
8768 print ("Sleep for 5 seconds for API limit." )
8869 await asyncio .sleep (5 )
8970 except openai .BadRequestError as e :
90- # Try to auto-fix unsupported parameters
9171 error_msg = str (e )
92- unsupported_param , suggested_param = parse_openai_error_for_unsupported_param (error_msg )
93-
94- if unsupported_param and param_retry_budget > 0 :
95- param_retry_budget -= 1
96-
97- # Only log this fix once per engine instance
98- log_key = f"{ unsupported_param } ->{ suggested_param } " if suggested_param else f"remove:{ unsupported_param } "
99- should_log = log_key not in self ._param_fixes_logged
100- if should_log :
101- self ._param_fixes_logged .add (log_key )
102- print (f"⚠️ Model { self .model } doesn't support '{ unsupported_param } ', adjusting parameters..." )
103-
104- if suggested_param :
105- # Remap parameter (e.g., max_tokens -> max_completion_tokens)
106- if unsupported_param in sampling_params :
107- value = sampling_params .pop (unsupported_param )
108- if suggested_param not in sampling_params :
109- sampling_params [suggested_param ] = value
110- if should_log :
111- print (f" Remapped '{ unsupported_param } ' -> '{ suggested_param } '" )
112- else :
113- # Just remove the unsupported parameter
114- if unsupported_param in sampling_params :
115- sampling_params .pop (unsupported_param )
116- if should_log :
117- print (f" Removed '{ unsupported_param } '" )
118-
119- # Retry immediately with fixed params (don't decrement retries)
120- continue
121-
122- # Can't auto-fix or out of param retry budget
72+ if self ._fix_unsupported_param (error_msg , sampling_params ):
73+ continue # Retry with fixed params
12374 retries -= 1
12475 if retries == 0 :
12576 raise Exception (f"Error processing content after retries: { e } " ) from e
@@ -138,8 +89,6 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
13889 sampling_params .pop ("model" , None )
13990
14091 retries = self .api_retries
141- param_retry_budget = 10 # Allow up to 10 parameter fixes (reasoning models can reject many params)
142-
14392 while retries > 0 :
14493 try :
14594 response = await self .client .completions .create (model = self .model , prompt = prompt , timeout = 3600 , ** sampling_params )
@@ -151,39 +100,9 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
151100 print ("Sleep for 5 seconds for API limit." )
152101 await asyncio .sleep (5 )
153102 except openai .BadRequestError as e :
154- # Try to auto-fix unsupported parameters
155103 error_msg = str (e )
156- unsupported_param , suggested_param = parse_openai_error_for_unsupported_param (error_msg )
157-
158- if unsupported_param and param_retry_budget > 0 :
159- param_retry_budget -= 1
160-
161- # Only log this fix once per engine instance
162- log_key = f"{ unsupported_param } ->{ suggested_param } " if suggested_param else f"remove:{ unsupported_param } "
163- should_log = log_key not in self ._param_fixes_logged
164- if should_log :
165- self ._param_fixes_logged .add (log_key )
166- print (f"⚠️ Model { self .model } doesn't support '{ unsupported_param } ', adjusting parameters..." )
167-
168- if suggested_param :
169- # Remap parameter (e.g., max_tokens -> max_completion_tokens)
170- if unsupported_param in sampling_params :
171- value = sampling_params .pop (unsupported_param )
172- if suggested_param not in sampling_params :
173- sampling_params [suggested_param ] = value
174- if should_log :
175- print (f" Remapped '{ unsupported_param } ' -> '{ suggested_param } '" )
176- else :
177- # Just remove the unsupported parameter
178- if unsupported_param in sampling_params :
179- sampling_params .pop (unsupported_param )
180- if should_log :
181- print (f" Removed '{ unsupported_param } '" )
182-
183- # Retry immediately with fixed params (don't decrement retries)
184- continue
185-
186- # Can't auto-fix or out of param retry budget
104+ if self ._fix_unsupported_param (error_msg , sampling_params ):
105+ continue # Retry with fixed params
187106 retries -= 1
188107 if retries == 0 :
189108 raise Exception (f"Error processing content after retries: { e } " ) from e
0 commit comments