
from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START
-from rllm.parser import ChatTemplateParser, ToolParser
+from rllm.parser import ChatTemplateParser
+from rllm.tools.tool_base import Tool
+from rllm.workflows import TerminationEvent, TerminationReason


class OpenAIEngine(RolloutEngine):
-    def __init__(self, model: str, tokenizer=None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, **kwargs):
+    def __init__(self, model: str = "", tokenizer=None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int | None = None, api_retries: int = 3, base_url: str = "https://api.openai.com/v1", api_key: str = os.getenv("OPENAI_API_KEY"), sampling_params: dict | None = None, tools: list[Tool | dict] = None, accumulate_reasoning: bool = False, **kwargs):
        self.model = model
+        self.max_prompt_length = max_prompt_length
+        self.max_response_length = max_response_length
+        self.max_model_length = max_model_length - 1 if max_model_length is not None else max_prompt_length + max_response_length - 1
        self.api_retries = api_retries
        self.sampling_params = sampling_params or {}
+        self.tools = tools or []
+        self.accumulate_reasoning = accumulate_reasoning

        self.tokenizer = tokenizer
        if self.tokenizer is not None:
            self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=kwargs.get("disable_thinking", False))
-            try:
-                self.tool_parser = ToolParser.get_parser(self.tokenizer)
-            except Exception:
-                print(f"Warning: No tool parser found for {self.tokenizer.name_or_path}. Tool calls not be parsed.")
-                self.tool_parser = None
            self._use_chat_completions = False
        else:
-            print("No tokenizer provided, will use the chat completions endpoint. This is not recommended.")
+            # In this case, we cannot enforce max prompt length or dynamically adjust max_tokens <= max_response_length if needed
+            print("No tokenizer provided to OpenAIEngine, will use the chat completions endpoint.")
            self._use_chat_completions = True

        self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)
        logging.getLogger("httpx").setLevel(logging.WARNING)

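As a sanity check on the token budget computed in __init__ above: with the constructor defaults shown in the diff, the engine reserves one token of headroom below the combined prompt-plus-response window. A minimal sketch of the arithmetic, using only the default argument values:

    # Budget arithmetic from __init__, with the default arguments
    max_prompt_length = 4096
    max_response_length = 4096
    max_model_length = None  # not supplied by the caller

    budget = max_model_length - 1 if max_model_length is not None else max_prompt_length + max_response_length - 1
    assert budget == 8191  # one token below prompt + response capacity
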
    async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
+        kwargs.pop("application_id", None)
+        kwargs.pop("validate", None)
+        kwargs.pop("model", None)
+        kwargs.pop("enforce_max_prompt_length", None)
+
        sampling_params = self.sampling_params.copy()
        sampling_params.update(kwargs)
-        sampling_params.pop("model", None)
+
+        max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))
+
        retries = self.api_retries
        while retries > 0:
            try:
-                response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, **sampling_params)
-                text = response.choices[0].message.content
-                if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str):
-                    text = f"{THOUGHT_DELIMITER_START}\n{response.choices[0].message.reasoning}\n{THOUGHT_DELIMITER_END}\n\n{text}"
-                return ModelOutput(text=text, tool_calls=response.choices[0].message.tool_calls, finish_reason=response.choices[0].finish_reason, completion_tokens=response.usage.completion_tokens, prompt_tokens=response.usage.prompt_tokens)
+                response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, max_tokens=max_tokens, **sampling_params)
+
+                content = response.choices[0].message.content
+                reasoning = response.choices[0].message.reasoning if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str) else ""
+                tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") and isinstance(response.choices[0].message.tool_calls, list) else []
+
+                text = content
+                if reasoning:
+                    text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}"  # best guess
+
+                prompt_length = response.usage.prompt_tokens
+                completion_length = response.usage.completion_tokens
+                finish_reason = response.choices[0].finish_reason
+
+                return ModelOutput(
+                    text=text,
+                    content=content,
+                    reasoning=reasoning,
+                    tool_calls=tool_calls,
+                    prompt_ids=[],
+                    completion_ids=[],
+                    prompt_length=prompt_length,
+                    completion_length=completion_length,
+                    finish_reason=finish_reason,
+                )
+
            except openai.RateLimitError:
                retries -= 1
                if retries == 0:
                    raise Exception("Rate limit reached and retries exhausted.") from None
                print("Sleep for 5 seconds for API limit.")
                await asyncio.sleep(5)
+
            except Exception as e:
                retries -= 1
                if retries == 0:
@@ -57,20 +88,58 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput:
                await asyncio.sleep(1)

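The text field assembled in chat_completion re-wraps any provider-supplied reasoning in the thought delimiters; the in-code comment flags this as a best guess at the original token stream. A minimal sketch of that reassembly, assuming the delimiters are the <think>/</think> pair (the actual values live in rllm.globals):

    THOUGHT_DELIMITER_START = "<think>"   # assumed values; rllm.globals defines the real ones
    THOUGHT_DELIMITER_END = "</think>"

    content = "The answer is 42."
    reasoning = "6 * 7 = 42"

    text = content
    if reasoning:
        text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}"

    # text is now:
    # <think>
    # 6 * 7 = 42
    # </think>
    #
    # The answer is 42.
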
    async def completion(self, prompt: str, **kwargs) -> ModelOutput:
+        kwargs.pop("application_id", None)
+        kwargs.pop("validate", None)
+        kwargs.pop("model", None)
+        enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True)
+
        sampling_params = self.sampling_params.copy()
        sampling_params.update(kwargs)
-        sampling_params.pop("model", None)
+
+        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+        prompt_length = len(prompt_ids)
+
+        if enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length):
+            raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED)
+
+        max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))
+        remaining_tokens = self.max_model_length - prompt_length
+        if remaining_tokens <= max_tokens:
+            max_tokens = remaining_tokens
+            print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length")
+
        retries = self.api_retries
        while retries > 0:
            try:
-                response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, **sampling_params)
-                return ModelOutput(text=response.choices[0].text, tool_calls=[], finish_reason=response.choices[0].finish_reason, completion_tokens=response.usage.completion_tokens, prompt_tokens=response.usage.prompt_tokens)
+                response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, max_tokens=max_tokens, **sampling_params)
+
+                text = response.choices[0].text
+                completion_ids = self.tokenizer.encode(text, add_special_tokens=False)
+                parsed_output = self.chat_parser.parse_completion(completion_ids)
+
+                prompt_length = response.usage.prompt_tokens
+                completion_length = response.usage.completion_tokens
+                finish_reason = response.choices[0].finish_reason
+
+                return ModelOutput(
+                    text=text,
+                    content=parsed_output["content"],
+                    reasoning=parsed_output["reasoning"],
+                    tool_calls=parsed_output["tool_calls"],
+                    prompt_ids=prompt_ids,
+                    completion_ids=completion_ids,
+                    prompt_length=prompt_length,
+                    completion_length=completion_length,
+                    finish_reason=finish_reason,
+                )
+
            except openai.RateLimitError:
                retries -= 1
                if retries == 0:
                    raise Exception("Rate limit reached and retries exhausted.") from None
                print("Sleep for 5 seconds for API limit.")
                await asyncio.sleep(5)
+
            except Exception as e:
                retries -= 1
                if retries == 0:
@@ -79,13 +148,10 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput:
                await asyncio.sleep(1)

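The max_tokens clamp in completion keeps prompt plus completion within the model window. A worked example of the arithmetic, assuming the default 8191-token budget from __init__ and a hypothetical 5000-token prompt:

    max_model_length = 8191      # default budget: 4096 + 4096 - 1
    max_response_length = 4096
    prompt_length = 5000         # hypothetical tokenized prompt

    max_tokens = max_response_length
    remaining_tokens = max_model_length - prompt_length   # 8191 - 5000 = 3191
    if remaining_tokens <= max_tokens:
        max_tokens = remaining_tokens                     # clamped from 4096 down to 3191
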
    async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
-        kwargs.pop("application_id", None)  # only needed for verl engine
-        kwargs.pop("validate", None)  # only needed for verl engine
        if self._use_chat_completions:
            return await self.chat_completion(messages, **kwargs)
        else:
-            prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True)
-            output = await self.completion(prompt, **kwargs)
-            if self.tool_parser is not None:
-                output.tool_calls = self.tool_parser.parse(output.text)
-            return output
+            tools = kwargs.pop("tools", self.tools)
+            accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning)
+            prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning)
+            return await self.completion(prompt, **kwargs)
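
For context, a minimal usage sketch of the engine after this change. The model name is a placeholder and the import path is assumed from the diff; with no tokenizer supplied, get_model_response routes through the chat completions endpoint:

    import asyncio

    from rllm.engine.rollout.openai_engine import OpenAIEngine  # module path assumed

    async def main():
        # Reads OPENAI_API_KEY from the environment (the constructor default).
        engine = OpenAIEngine(
            model="gpt-4o-mini",                 # placeholder model name
            max_response_length=512,             # becomes the default max_tokens
            sampling_params={"temperature": 0.2},
        )
        output = await engine.get_model_response([{"role": "user", "content": "Say hi."}])
        print(output.finish_reason, output.text)

    asyncio.run(main())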