diff --git a/.gitignore b/.gitignore
index 7b8b8949..1ceebbe7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -202,3 +202,10 @@ CLAUDE.md
 examples/strands_outputs/*
 strands_outputs/*
 examples/strands/strands_outputs/*
+
+# Deepresearch outputs ignore
+examples/deepresearch/deepresearch_outputs/*
+deepresearch_outputs/*
+examples/deepresearch/hle_outputs/*
+*/hle_outputs/*
+examples/deepresearch/HLE_OUTPUT_EVOLUTION.md
diff --git a/examples/deepresearch/.env.example b/examples/deepresearch/.env.example
new file mode 100644
index 00000000..406695c4
--- /dev/null
+++ b/examples/deepresearch/.env.example
@@ -0,0 +1,28 @@
+# DeepResearch API Configuration
+# Copy this file to .env and fill in your API keys
+
+# OpenAI API (recommended for best performance)
+OPENAI_API_KEY=your_openai_api_key_here
+OPENAI_BASE_URL=https://api.openai.com/v1
+MODEL_NAME=gpt-4
+
+# Alternative: Together AI (cost-effective option)
+# TOGETHER_AI_API_KEY=your_together_ai_key_here
+# TOGETHER_AI_MODEL_NAME=Qwen/Qwen2.5-7B-Instruct-Turbo
+
+# Alternative: Custom OpenAI-compatible endpoint (for vLLM hosting)
+# OPENAI_API_KEY=your_custom_api_key
+# OPENAI_BASE_URL=http://your-vllm-server:8000/v1
+# MODEL_NAME=your-hosted-model-name
+
+# Search API keys for research tools
+# Serper API (required for web search functionality)
+SERPER_API_KEY=your_serper_api_key_from_serper.dev
+
+# Alternative: Google Custom Search API (if you prefer Google over Serper)
+# GOOGLE_SEARCH_SECRET_KEY=your_google_api_key
+# GOOGLE_SEARCH_ENGINE_ID=your_custom_search_engine_id
+
+# Evaluation settings
+# DEEPRESEARCH_TASK=Custom research question to test
+# GAIA_DATASET_PATH=path/to/gaia.json
\ No newline at end of file
diff --git a/examples/deepresearch/README.md b/examples/deepresearch/README.md
new file mode 100644
index 00000000..1db5865d
--- /dev/null
+++ b/examples/deepresearch/README.md
@@ -0,0 +1,260 @@
+# DeepResearch Integration for rLLM
+
+## Overview
+
+This module integrates Tongyi's DeepResearch ReAct agent into the rLLM framework, enabling evaluation on academic benchmarks such as HLE (Humanity's Last Exam). The integration demonstrates how to port an external agent architecture into rLLM's workflow system while maintaining compatibility with the training and evaluation infrastructure.
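+In practice, the port is a thin wrapper: the unchanged ReAct agent runs inside an rLLM `Workflow`, so `AgentWorkflowEngine` can schedule it and collect `Episode` objects. The snippet below is a simplified, illustrative sketch of that pattern (the class name is invented for illustration); the real `DeepResearchWorkflow` in `deepresearch_workflow.py` additionally converts the message history into a `Trajectory`, checks the prediction against the ground truth, and maps termination reasons.
+
+```python
+from deepresearch_agent import MultiTurnReactAgent
+
+from rllm.agents.agent import Episode
+from rllm.workflows.workflow import Workflow
+
+
+class MinimalDeepResearchWorkflow(Workflow):
+    """Bare-bones sketch of the wrapper pattern used by DeepResearchWorkflow."""
+
+    def __init__(self, rollout_engine, executor, tools=None, **kwargs):
+        super().__init__(rollout_engine, executor, **kwargs)
+        # The external agent keeps its own ReAct loop; rLLM only supplies the rollout engine.
+        self.agent = MultiTurnReactAgent(rollout_engine=rollout_engine, tools=tools or {})
+
+    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
+        # Run the agent's research loop, then package the result as an rLLM Episode.
+        result = await self.agent.run(question=task["question"], answer=task.get("answer", ""))
+        episode = Episode()
+        episode.id = uid
+        episode.task = task
+        episode.metrics = {"prediction": result.get("prediction", "")}
+        return episode
+```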
+
+## Architecture
+
+```
+DeepResearch Agent (ReAct with XML-based tool calling)
+    ↓
+DeepResearchWorkflow (rLLM Workflow wrapper)
+    ↓
+AgentWorkflowEngine (Parallel execution)
+    ↓
+Episode/Trajectory (rLLM data format)
+```
+
+### Key Components
+
+- **`deepresearch_agent.py`**: MultiTurnReactAgent implementing Tongyi's ReAct loop with tool calling
+- **`deepresearch_workflow.py`**: Wrapper that converts agent outputs to rLLM Episodes for trajectory tracking
+- **`deepresearch_tools.py`**: Tool implementations (Search, Scholar, Visit, FileParser, PythonInterpreter)
+- **`evaluate_hle.py`**: Evaluation script for HLE (Humanity's Last Exam) benchmark
+
+## Installation
+
+### Prerequisites
+
+```bash
+# Activate rLLM environment
+conda activate rllm
+
+# Install required dependencies
+pip install datasets  # For HLE dataset access
+pip install tiktoken  # Optional: for better token counting with OpenAI models
+```
+
+### Environment Setup
+
+Create a `.env` file with your API keys:
+
+```bash
+# For model inference (choose one)
+OPENAI_API_KEY=your_openai_key
+TOGETHER_AI_API_KEY=your_together_key
+
+# Optional: For web search tool
+SERPER_API_KEY=your_serper_key  # Get free key from serper.dev
+```
+
+## Usage
+
+### Running HLE Evaluation
+
+```bash
+# Evaluate on HLE dataset with default settings
+python evaluate_hle.py --hf-dataset cais/hle --max-samples 10 --parallel-tasks 4
+
+# Use specific model
+python evaluate_hle.py --model gpt-4o --max-samples 5
+
+# Use Together AI for evaluation
+python evaluate_hle.py --model Qwen/Qwen2.5-7B-Instruct-Turbo \
+    --base-url https://api.together.xyz/v1 \
+    --max-samples 20
+
+# Custom output directory
+python evaluate_hle.py --output-dir ./my_results --max-samples 20
+```
+
+### Using DeepResearch Agent Directly
+
+```python
+from rllm.engine.rollout import OpenAIEngine
+from deepresearch_agent import MultiTurnReactAgent
+from deepresearch_tools import get_all_tools
+
+# Setup rollout engine
+engine = OpenAIEngine(
+    model="gpt-4o",
+    api_key="your_key",
+    base_url="https://api.openai.com/v1"
+)
+
+# Create agent with tools
+agent = MultiTurnReactAgent(
+    rollout_engine=engine,
+    tools=get_all_tools()
+)
+
+# Run a research task
+result = await agent.run(
+    question="What is the reduced 12th dimensional Spin bordism of BG2?",
+    answer="Z/2"  # Optional ground truth for evaluation
+)
+
+print(f"Prediction: {result['prediction']}")
+print(f"Rounds: {result['rounds']}")
+print(f"Time taken: {result['time_taken']}s")
+```
+
+### Integrating with rLLM Workflows
+
+```python
+from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
+from deepresearch_workflow import DeepResearchWorkflow
+
+# Create workflow engine for parallel execution
+workflow_engine = AgentWorkflowEngine(
+    workflow_cls=DeepResearchWorkflow,
+    workflow_args={
+        "tools": get_all_tools(),
+        "max_prompt_length": 4096,
+        "max_response_length": 2048
+    },
+    rollout_engine=engine,
+    n_parallel_tasks=4  # Run 4 tasks in parallel
+)
+
+# Run evaluation on multiple tasks
+tasks = [
+    {"question": "Question 1", "answer": "Answer 1"},
+    {"question": "Question 2", "answer": "Answer 2"}
+]
+
+episodes = await workflow_engine.execute_tasks(tasks)
+
+# Episodes contain full trajectories for training
+for episode in episodes:
+    print(f"Task: {episode.task}")
+    print(f"Prediction: {episode.metrics.get('prediction')}")
+    print(f"Is correct: {episode.is_correct}")
+```
+
+## Tools
+
+The agent has access to the following research tools:
+
+| Tool                  | Description                 | Implementation Status               |
+| --------------------- | --------------------------- | ----------------------------------- |
+| **Search**            | Web search via Serper API   | ✅ Fully implemented (needs API key) |
+| **PythonInterpreter** | Execute Python code safely  | ✅ Fully implemented with security   |
+| **Scholar**           | Academic paper search       | ❌ Placeholder only                  |
+| **Visit**             | Visit and analyze web pages | ❌ Placeholder only                  |
+| **FileParser**        | Parse various file formats  | ⚠️ Basic text only (no PDF/DOCX)     |
+
+### Tool Implementation Notes
+
+- **Search**: Real web search with Serper API integration. Configure API key in `.env` file
+- **PythonInterpreter**: Enhanced security, 50s timeout, supports numpy/pandas when available
+- **Scholar**: Returns placeholder results. Needs integration with arXiv/Google Scholar APIs
+- **Visit**: Returns placeholder content. Needs requests/BeautifulSoup implementation
+- **FileParser**: Only reads text files up to 5000 chars. Original supports PDF/DOCX/media files
+
+## Key Improvements from Original
+
+### 1. Token Counting Fix
+
+- **Problem**: Original used mismatched tokenizers (GPT-2 for GPT-4o), causing incorrect context limits
+- **Solution**: Now uses the OpenAI API's actual token statistics from response.prompt_tokens and response.completion_tokens
+- **Impact**: No more false "context exceeded" errors at 13k tokens when the limit is 128k
+
+### 2. Context Management
+
+- **Problem**: System would incorrectly truncate messages based on wrong token counts
+- **Solution**: Track actual cumulative API token consumption for accurate context management
+- **Impact**: Model can use the full context window effectively
+
+### 3. System Prompt Optimization
+
+- **Problem**: Over-constrained prompt requiring specific tags caused unnatural responses
+- **Solution**: Simplified prompt matching the original Tongyi design, letting the model reason naturally
+- **Impact**: Better convergence, fewer infinite loops
+
+### 4. Parallel Execution
+
+- **Leverages AgentWorkflowEngine** for concurrent task processing
+- **Configurable parallelism** (n_parallel_tasks parameter)
+- **Automatic retry** on failures
+
+## Evaluation Results
+
+Evaluation results will be added after running benchmarks. The system is designed to evaluate on HLE and other academic benchmarks.
+
+## Known Issues and Limitations
+
+1. **Tool Placeholders**: Scholar and Visit tools need real implementations for research tasks
+2. **Model-Specific Behavior**:
+   - Some models may not consistently use the expected `<think>`/`<answer>` tags
+   - Tool calling format adherence varies by model
+3. **Long Context Tasks**: Very complex research may still hit token limits
+4. **Judge Accuracy**: LLM judge may not perfectly evaluate complex answers
+
+## Future Improvements
+
+- [ ] Implement real Scholar tool using arXiv/Semantic Scholar APIs
+- [ ] Implement real Visit tool using requests/BeautifulSoup
+- [ ] Add PDF/DOCX parsing to FileParser
+- [ ] Create unified evaluation framework for multiple benchmarks
+- [ ] Add more Tongyi agents (QwenCoder, etc.)
+- [ ] Improve judge accuracy with better prompts
+
+## Project Structure
+
+```
+examples/deepresearch/
+├── deepresearch_agent.py      # Core ReAct agent implementation
+├── deepresearch_workflow.py   # rLLM workflow wrapper
+├── deepresearch_tools.py      # Tool implementations
+├── evaluate_hle.py            # HLE evaluation script
+├── react_agent_original.py    # Original Tongyi reference
+├── tool_*_original.py         # Original tool references
+├── hle_outputs/               # Evaluation results (git ignored)
+└── README.md                  # This file
+```
+
+## Contributing
+
+To add new tools or improve existing ones:
+
+1. Implement the tool in `deepresearch_tools.py` following the pattern:
+
+   ```python
+   class YourTool(DeepResearchTool):
+       async def call(self, **kwargs) -> str:
+           # Your implementation
+           return result_string
+   ```
+
+2. Add it to the `DEEPRESEARCH_TOOLS` registry
+
+3. Test with the evaluation script
+
+4. Submit a PR with test results
+
+## Related Work
+
+This integration is part of the rLLM evaluation framework initiative. See also:
+
+- `examples/strands/` - Strands agent integration
+- `rllm/agents/` - Native rLLM agents
+- `rllm/workflows/` - Workflow base classes
+
+## Citation
+
+If you use this integration, please cite:
+
+```bibtex
+@misc{deepresearch2024,
+  title={DeepResearch: Multi-turn Research Agent},
+  author={Alibaba NLP Team},
+  year={2024},
+  url={https://github.com/Alibaba-NLP/DeepResearch}
+}
+```
+
+## License
+
+This integration follows rLLM's license. The original DeepResearch implementation is from Alibaba's Tongyi team.
diff --git a/examples/deepresearch/deepresearch_agent.py b/examples/deepresearch/deepresearch_agent.py
new file mode 100644
index 00000000..daade1ad
--- /dev/null
+++ b/examples/deepresearch/deepresearch_agent.py
@@ -0,0 +1,729 @@
+"""
+DeepResearch Agent - Adapted from Tongyi DeepResearch for rLLM
+
+This is the core ReAct agent that implements DeepResearch's reasoning and tool-calling logic,
+adapted to work with rLLM's OpenAI engine instead of the original server-based approach.
+
+Original: https://github.com/Alibaba-NLP/DeepResearch/blob/main/inference/react_agent.py
+"""
+
+import asyncio
+import json
+import time
+from datetime import datetime
+
+# rLLM imports
+from rllm.engine.rollout import RolloutEngine
+
+# Constants from original DeepResearch
+OBS_START = "<tool_response>"
+OBS_END = "\n</tool_response>"
+MAX_LLM_CALL_PER_RUN = 100
+
+# System prompt adapted from DeepResearch
+DEEPRESEARCH_SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You MUST use the provided tools to research and verify information before answering. Do NOT answer directly from memory - always use tools to gather current, accurate information.
+
+IMPORTANT: You are REQUIRED to use at least one tool before providing any answer. Even if you think you know the answer, you must verify it using the appropriate tools. Direct answers without tool use are not acceptable.
+
+When you have gathered sufficient information through tool use and are ready to provide the definitive response, you must enclose the entire final answer within <answer></answer> tags.
+ +# Tools + +You MUST use one or more of the following tools to research the query: + +You are provided with the following tools: +- Search: for web searches to find current information +- Scholar: for academic research and paper searches +- Visit: for visiting and analyzing web pages +- PythonInterpreter: for running Python code and calculations +- FileParser: for reading and analyzing files + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + + +For Python code execution, use: + +python + +# Your Python code here +print("Hello World") + + + +Current date: """ + + +def today_date(): + """Get today's date in YYYY-MM-DD format.""" + return datetime.now().date().strftime("%Y-%m-%d") + + +def build_text_completion_prompt(messages: list[dict], allow_special: bool = True) -> str: + """ + Build text completion prompt from messages list. + Adapted from qwen_agent.utils.utils.build_text_completion_prompt + + Args: + messages: List of message dictionaries with 'role' and 'content' keys + allow_special: Whether to allow special tokens (for compatibility) + + Returns: + Formatted prompt string + """ + im_start = "<|im_start|>" + im_end = "<|im_end|>" + + prompt_parts = [] + + # Handle system message + if messages and messages[0]["role"] == "system": + sys_content = messages[0]["content"] + prompt_parts.append(f"{im_start}system\n{sys_content}{im_end}") + messages = messages[1:] + + # Ensure chat completes with assistant + if messages and messages[-1]["role"] != "assistant": + messages = messages + [{"role": "assistant", "content": ""}] + + # Format each message + for msg in messages: + role = msg["role"] + content = msg["content"] + prompt_parts.append(f"{im_start}{role}\n{content}{im_end}") + + return "\n".join(prompt_parts) + + +class MultiTurnReactAgent: + """ + Multi-turn ReAct Agent adapted from Tongyi DeepResearch. + + This agent implements the core reasoning loop with tool calling capabilities, + using rLLM's OpenAI engine for model inference. + """ + + def __init__( + self, + rollout_engine: RolloutEngine, + tools: dict = None, + system_prompt: str | None = None, + use_native_function_calling: bool = False, + **kwargs, + ): + """ + Initialize the ReAct agent. + + Args: + rollout_engine: rLLM OpenAI engine for model inference + tools: Dictionary of available tools {tool_name: tool_instance} + system_prompt: Optional custom system prompt + use_native_function_calling: Whether to use OpenAI native function calling (supports o3) + """ + self.rollout_engine = rollout_engine + self.tools = tools or {} + self.system_prompt = system_prompt + self.use_native_function_calling = use_native_function_calling + + # Convert tools to OpenAI format if using native function calling + if use_native_function_calling and self.tools: + self.openai_tools = [tool.json for tool in self.tools.values()] + else: + self.openai_tools = None + + # Configuration from original DeepResearch + self.max_llm_calls = MAX_LLM_CALL_PER_RUN + self.max_time = 150 * 60 # 150 minutes timeout + + # Smart context management using actual API consumption + self.total_prompt_tokens = 0 + self.total_completion_tokens = 0 + + # Auto-detect context limit based on model capabilities + # This ensures we don't hit limits too early for capable models + self.max_context_tokens = self._get_model_context_limit(rollout_engine) + + def _get_model_context_limit(self, rollout_engine) -> int: + """ + Auto-detect context limit based on model capabilities. 
+ Uses LiteLLM's model info when available, falls back to conservative estimates. + Returns 90% of max to leave safety headroom. + """ + model_name = rollout_engine.model + + # Method 1: Try LiteLLM's get_model_info (most accurate) + try: + import litellm + + model_info = litellm.get_model_info(model_name) + if model_info and "max_input_tokens" in model_info: + max_tokens = model_info["max_input_tokens"] + conservative_limit = int(max_tokens * 0.90) # Use 90% for safety + if not hasattr(MultiTurnReactAgent, "_context_limit_reported"): + print(f" 📏 Detected context window: {max_tokens:,} tokens (using 90% = {conservative_limit:,})") + MultiTurnReactAgent._context_limit_reported = True + return conservative_limit + except Exception: + # LiteLLM might not have info for all models, that's ok + pass + + # Method 2: Try tiktoken to get model family info + try: + import tiktoken + + # tiktoken.encoding_for_model will throw if model unknown + encoding = tiktoken.encoding_for_model(model_name) + # Map known encodings to context limits + encoding_limits = { + "cl100k_base": 128 * 1024, # GPT-4, GPT-3.5-turbo-16k + "p50k_base": 4 * 1024, # text-davinci-002/003 + "r50k_base": 4 * 1024, # GPT-3 base models + } + if encoding.name in encoding_limits: + max_tokens = encoding_limits[encoding.name] + conservative_limit = int(max_tokens * 0.90) + if not hasattr(MultiTurnReactAgent, "_context_limit_reported"): + print(f" 📏 Inferred context from encoding '{encoding.name}': {conservative_limit:,} tokens") + MultiTurnReactAgent._context_limit_reported = True + return conservative_limit + except Exception: + pass + + # Method 3: Pattern matching fallback (least accurate but works) + model_lower = model_name.lower() + fallback_limits = { + # OpenAI reasoning models + ("o3", "o1"): 128 * 1024, + # GPT-4 family + ("gpt-4o", "gpt-4-turbo"): 128 * 1024, + ("gpt-4-32k",): 32 * 1024, + ("gpt-4",): 8 * 1024, + # Claude family + ("claude-3-5", "claude-3.5"): 200 * 1024, + ("claude-3",): 200 * 1024, + ("claude-2",): 100 * 1024, + # Gemini family + ("gemini-1.5", "gemini-2"): 1000 * 1024, + ("gemini",): 32 * 1024, + # Qwen + ("qwen2", "qwen-2"): 128 * 1024, + ("qwen",): 32 * 1024, + } + + for patterns, max_tokens in fallback_limits.items(): + if any(pattern in model_lower for pattern in patterns): + conservative_limit = int(max_tokens * 0.90) + if not hasattr(MultiTurnReactAgent, "_context_limit_reported"): + print(f" 📏 Pattern-matched context limit: {conservative_limit:,} tokens (90% of {max_tokens:,})") + MultiTurnReactAgent._context_limit_reported = True + return conservative_limit + + # Method 4: Ultimate fallback + default_limit = 100 * 1024 + if not hasattr(MultiTurnReactAgent, "_context_limit_reported"): + print(f" ⚠️ Unknown model '{model_name}', using conservative default: {default_limit:,} tokens") + MultiTurnReactAgent._context_limit_reported = True + return default_limit + + def sanity_check_output(self, content: str) -> bool: + """Check if the model output contains the expected thinking structure.""" + return "" in content and "" in content + + async def call_server(self, messages: list[dict], max_tries: int = 10): + """ + Call rLLM OpenAI engine with hybrid mode support. 
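+        Retries failed requests up to max_tries times with exponential backoff and
+        accumulates the prompt/completion token counts reported by each response.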
+ + Supports both: + - Native function calling (for o3, gpt-4-turbo) + - ReAct text format (for gpt-4o, Claude) + + Args: + messages: List of chat completion messages + max_tries: Maximum number of retry attempts + + Returns: + ModelOutput with text and tool_calls + """ + for attempt in range(max_tries): + try: + # Base parameters + api_params = {"messages": messages} + + # Model-specific parameter configuration + model_name = self.rollout_engine.model.lower() + + if "o3" in model_name or "o1" in model_name: + # O3/O1: Very limited parameter support + api_params["max_completion_tokens"] = 4096 + elif "gpt-4" in model_name: + # GPT-4: Full parameter support + api_params.update( + { + "stop": ["\n", ""], + "temperature": 0.6, + "top_p": 0.95, + "max_tokens": 4096, + "presence_penalty": 1.1, + } + ) + elif "qwen" in model_name: + # Qwen models + api_params.update( + { + "temperature": 0.6, + "top_p": 0.95, + "max_tokens": 4096, + } + ) + else: + # Fallback: Conservative params + api_params.update( + { + "temperature": 0.6, + "max_tokens": 4096, + } + ) + + # Add tools parameter for native function calling + if self.use_native_function_calling and self.openai_tools: + api_params["tools"] = self.openai_tools + api_params["tool_choice"] = "auto" + + # Call rLLM OpenAI Engine + response = await self.rollout_engine.get_model_response(**api_params) + + # Track actual token consumption from API + if hasattr(response, "prompt_tokens") and hasattr(response, "completion_tokens"): + self.total_prompt_tokens += response.prompt_tokens + self.total_completion_tokens += response.completion_tokens + + # Return full ModelOutput (contains both text and tool_calls) + return response + + except Exception as e: + print(f"Error: Attempt {attempt + 1} failed: {e}") + if attempt < max_tries - 1: + # Exponential backoff + sleep_time = 2**attempt + print(f"Waiting {sleep_time} seconds before retry...") + await asyncio.sleep(sleep_time) + + raise Exception(f"Failed to get response after {max_tries} attempts") + + def get_total_tokens_used(self) -> int: + """ + Get total tokens consumed so far from actual API usage. + This is much more accurate than any tokenizer estimation. + + Returns: + Total tokens used (prompt + completion) + """ + return self.total_prompt_tokens + self.total_completion_tokens + + async def _run(self, question: str, answer: str = None, images: list = None, **kwargs) -> dict: + """ + Main reasoning loop adapted from original DeepResearch. 
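+        Each iteration sends the accumulated message history to the model, executes any
+        requested tool call, and appends the tool output as a new observation message
+        before the next model call.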
+ + This is the core ReAct implementation that handles: + - Multi-turn conversation + - Tool calling and execution + - Context length management + - Termination conditions + + Args: + question: The research question to answer + answer: Ground truth answer (for evaluation) + images: List of image data URLs (base64 encoded) + + Returns: + Dictionary with results including messages, prediction, and termination reason + """ + start_time = time.time() + + # Setup system prompt with current date + system_prompt = (self.system_prompt or DEEPRESEARCH_SYSTEM_PROMPT) + today_date() + + # Construct initial user message (multimodal if images present) + if images: + # Build multimodal message with images + user_content = [{"type": "text", "text": question}] + for image_data in images: + user_content.append({"type": "image_url", "image_url": {"url": image_data}}) + user_message = {"role": "user", "content": user_content} + else: + # Plain text message + user_message = {"role": "user", "content": question} + + messages = [ + {"role": "system", "content": system_prompt}, + user_message, + ] + + num_llm_calls_available = self.max_llm_calls + round = 0 + termination = None + prediction = "" + + # Truncate question for display + q_display = str(question).replace("\n", " ").strip() + if len(q_display) > 200: + q_display = q_display[:200] + "..." + print(f"🔍 Starting DeepResearch for question: {q_display}") + + while num_llm_calls_available > 0: + # Check time limit (150 minutes) + if time.time() - start_time > self.max_time: + prediction = "No answer found after 2h30mins" + termination = "No answer found after 2h30mins" + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination, + } + return result + + round += 1 + num_llm_calls_available -= 1 + + # Get model response (ModelOutput with text and tool_calls) + response = await self.call_server(messages) + + # Extract text content (may be None for pure function calling) + content = response.text if hasattr(response, "text") and response.text else "" + + # Debug: Print raw model response to see format + if round == 1: + print(f"[DEBUG] Raw model response (first 500 chars): {content[:500]}") + if hasattr(response, "tool_calls") and response.tool_calls: + print(f"[DEBUG] Native tool_calls detected: {len(response.tool_calls)} call(s)") + + # Print concise round info with truncation + MAX_PRINT_LENGTH = 200 + + # Simple truncation for all prints + def truncate(text, max_len=MAX_PRINT_LENGTH): + text = str(text).replace("\n", " ").strip() + # Special handling for base64 images + if "data:image" in text or ";base64," in text: + # Find the base64 part and truncate it + if "base64," in text: + parts = text.split("base64,", 1) + return parts[0] + "base64,[truncated]" + return "[base64 image data]" + if len(text) > max_len: + return text[:max_len] + "..." 
+ return text + + # Print round info based on content type + if "" in content: + # Extract tool name for display + if "python" in content.lower() and "" in content: + print(f"Round {round}: 🐍 Executing Python code") + elif '"name":' in content: + try: + import json5 + + tool_text = content.split("")[1].split("")[0] + tool_text = tool_text[:1000] # Limit for parsing + tool_data = json5.loads(tool_text) + tool_name = tool_data.get("name", "Unknown") + if "arguments" in tool_data: + args_str = truncate(str(tool_data["arguments"]), 100) + print(f"Round {round}: 🔧 Calling {tool_name} with args: {args_str}") + else: + print(f"Round {round}: 🔧 Calling {tool_name}") + except Exception: + print(f"Round {round}: 🔧 Tool call") + else: + print(f"Round {round}: 🔧 Tool call") + elif "" in content: + # Final answer + answer_preview = content.split("")[1].split("")[0] + print(f"Round {round}: ✅ Final answer: {truncate(answer_preview, 100)}") + else: + # Show internal reasoning if available, otherwise show content + if hasattr(response, "reasoning") and response.reasoning: + reasoning_preview = truncate(response.reasoning, 300) + print(f"Round {round}: 💭 [Internal] {reasoning_preview}") + elif content: + print(f"Round {round}: 💭 Reasoning: {truncate(content)}") + + # Clean up content if it contains tool_response + if "" in content: + pos = content.find("") + content = content[:pos] + + # HYBRID MODE: Handle both native tool_calls and ReAct text format + + # Priority 1: Check for native function calling (o3, gpt-4-turbo) + if hasattr(response, "tool_calls") and response.tool_calls: + # Native function calling path - build ALL messages first, then append atomically + tool_calls_formatted = [] + tool_responses = [] + + for tool_call in response.tool_calls: + try: + # Extract tool info from OpenAI format + tool_id = tool_call.id if hasattr(tool_call, "id") else "unknown" + function = tool_call.function if hasattr(tool_call, "function") else tool_call.get("function", {}) + tool_name = function.name if hasattr(function, "name") else function.get("name", "") + arguments_str = function.arguments if hasattr(function, "arguments") else function.get("arguments", "{}") + + # Parse arguments + tool_args = json.loads(arguments_str) if isinstance(arguments_str, str) else arguments_str + + # Print tool call with arguments (for consistency with ReAct format) + def truncate(text, max_len=100): + text = str(text).replace("\n", " ").strip() + if len(text) > max_len: + return text[:max_len] + "..." 
+ return text + + args_str = truncate(str(tool_args), 100) + print(f"Round {round}: 🔧 [Native] Calling {tool_name} with args: {args_str}") + + # Execute tool + result = await self.custom_call_tool(tool_name, tool_args) + + # Collect tool call and response (don't append yet) + tool_calls_formatted.append( + { + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": arguments_str, + }, + } + ) + tool_responses.append({"role": "tool", "tool_call_id": tool_id, "content": result}) + + except Exception as e: + print(f"Error processing native tool call: {e}") + # On error, append error message and skip this tool call + messages.append({"role": "assistant", "content": content.strip()}) + messages.append({"role": "user", "content": f"Tool call error: {e}"}) + continue + + # Only append to messages if we have successful tool calls + if tool_calls_formatted: + # Add assistant message with ALL tool calls at once + messages.append( + { + "role": "assistant", + "content": content or "", # May be empty for pure function calling + "tool_calls": tool_calls_formatted, + } + ) + # Add all tool responses + messages.extend(tool_responses) + + # Priority 2: Check for ReAct text format (gpt-4o, Claude) + elif "" in content and "" in content: + # ReAct text format path + messages.append({"role": "assistant", "content": content.strip()}) + + tool_call_text = content.split("")[1].split("")[0] + try: + # Special handling for Python code (match original logic) + if "python" in tool_call_text.lower(): + try: + # Extract code from the original content (not just tool_call_text) + code_raw = content.split("")[1].split("")[0].split("")[1].split("")[0].strip() + result = await self.execute_python(code_raw) + except Exception: + result = "[Python Interpreter Error]: Formatting error." + else: + # Parse JSON tool call + tool_call = json5.loads(tool_call_text) + tool_name = tool_call.get("name", "") + tool_args = tool_call.get("arguments", {}) + result = await self.custom_call_tool(tool_name, tool_args) + + except Exception: + result = 'Error: Tool call is not a valid JSON. Tool call must contain a valid "name" and "arguments" field.' + + # Add tool response in ReAct format + tool_response = f"\n{result}\n" + messages.append({"role": "user", "content": tool_response}) + + # Priority 3: No tool call, just reasoning or answer + else: + messages.append({"role": "assistant", "content": content.strip()}) + + # Check for final answer AFTER processing tools + # This allows o3 to execute tools even when it includes answer in same message + if "" in content and "" in content: + prediction = content.split("")[1].split("")[0].strip() + termination = "answer" + break + + # Check if we've exceeded call limit + if num_llm_calls_available <= 0 and "" not in content: + # Handle both message formats + if isinstance(messages[-1], dict) and "content" in messages[-1]: + messages[-1]["content"] = "Sorry, the number of llm calls exceeds the limit." + + # Handle context length limit using actual API consumption + total_tokens_used = self.get_total_tokens_used() + + if total_tokens_used > self.max_context_tokens: + # Instead of replacing the last message, add a clear instruction + final_instruction = { + "role": "user", + "content": "You have reached the maximum context length. 
Based on all the information above, please provide your best answer now in the format: your final thinking\nyour answer", + } + + # Truncate conversation history to make room for final answer + # Keep system prompt, original question, and recent context + if len(messages) > 4: # system + user + at least 2 exchanges + # Keep first 2 messages (system + original question) and last 2 meaningful exchanges + truncated_messages = messages[:2] # system + original question + recent_messages = messages[-4:] # last 4 messages for context + truncated_messages.extend(recent_messages) + messages = truncated_messages + + messages.append(final_instruction) + + # Note: After truncation, we'll let the next API call handle any remaining limits + print(f"Round {round + 1}: ⚠️ Context limit reached, requesting final answer") + + response = await self.call_server(messages) + content = response.text if hasattr(response, "text") and response.text else "" + messages.append({"role": "assistant", "content": content.strip()}) + + if "" in content and "" in content: + prediction = content.split("")[1].split("")[0].strip() + termination = "answer generated due to token limit" + else: + prediction = content.strip() + termination = "response generated due to token limit (no answer format)" + + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination, + } + return result + + # Final validation logic from original Tongyi implementation + # Handle both native function calling and ReAct text format + last_message_content = messages[-1].get("content", "") if isinstance(messages[-1], dict) else "" + if last_message_content and "" in last_message_content: + prediction = last_message_content.split("")[1].split("")[0] + termination = "answer" + else: + prediction = "No answer found." + termination = "answer not found" + if num_llm_calls_available == 0: + termination = "exceed available llm calls" + + # Final result + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination, + "rounds": round, + "time_taken": time.time() - start_time, + } + + print("\n🏁 DeepResearch completed:") + print(f" Rounds: {round}") + print(f" Time: {result['time_taken']:.1f}s") + print(f" Termination: {termination}") + # Truncate prediction for display + pred_display = str(prediction).replace("\n", " ").strip() + if len(pred_display) > 200: + pred_display = pred_display[:200] + "..." + print(f" Prediction: {pred_display}") + + return result + + async def custom_call_tool(self, tool_name: str, tool_args: dict, **kwargs) -> str: + """ + Execute tool calls with the available tools. + + Args: + tool_name: Name of the tool to call + tool_args: Arguments to pass to the tool + + Returns: + Tool execution result as string + """ + if tool_name in self.tools: + try: + # Call the tool + if hasattr(self.tools[tool_name], "call"): + # Async tool + if asyncio.iscoroutinefunction(self.tools[tool_name].call): + result = await self.tools[tool_name].call(**tool_args) + else: + result = self.tools[tool_name].call(**tool_args) + elif callable(self.tools[tool_name]): + # Direct callable + result = self.tools[tool_name](**tool_args) + else: + result = f"Tool {tool_name} is not callable" + + return str(result) + + except Exception as e: + return f"Error calling tool {tool_name}: {e}" + else: + available_tools = list(self.tools.keys()) + return f"Tool {tool_name} not found. 
Available tools: {available_tools}" + + async def execute_python(self, code: str) -> str: + """ + Execute Python code using the PythonInterpreter tool. + + Args: + code: Python code to execute + + Returns: + Execution result as string + """ + if "PythonInterpreter" in self.tools: + try: + # Use the PythonInterpreter tool + tool = self.tools["PythonInterpreter"] + if hasattr(tool, "call"): + if asyncio.iscoroutinefunction(tool.call): + result = await tool.call(code=code) + else: + result = tool.call(code=code) + return str(result) + else: + return "PythonInterpreter tool is not callable" + except Exception as e: + return f"Python execution error: {e}" + else: + return "PythonInterpreter tool not available" + + def reset(self): + """Reset the agent state (for compatibility with rLLM workflow).""" + # Reset token counters for each new task + self.total_prompt_tokens = 0 + self.total_completion_tokens = 0 + + async def run(self, question: str, answer: str = None, **kwargs) -> dict: + """ + Public interface for running the agent. + + Args: + question: Research question to answer + answer: Ground truth answer (optional, for evaluation) + + Returns: + Result dictionary + """ + # Reset token counters for each new run + self.reset() + return await self._run(question, answer, **kwargs) diff --git a/examples/deepresearch/deepresearch_tools.py b/examples/deepresearch/deepresearch_tools.py new file mode 100644 index 00000000..fd10cbaf --- /dev/null +++ b/examples/deepresearch/deepresearch_tools.py @@ -0,0 +1,750 @@ +""" +DeepResearch Tools - Production-ready implementations + +This module provides tool implementations for the DeepResearch agent, with real +functionality ported from Tongyi's original implementations where possible. + +Now supports both: +- ReAct text format (for gpt-4o, Claude, etc.) +- OpenAI native function calling (for o3, o3-mini, etc.) +""" + +import http.client +import json +import os +from abc import ABC, abstractmethod + +from rllm.tools.tool_base import Tool as RLLMTool + + +class DeepResearchTool(RLLMTool, ABC): + """ + Base class for all DeepResearch tools. + + Inherits from rLLM's Tool to support OpenAI native function calling, + while maintaining compatibility with ReAct text format. + """ + + def __init__(self, name: str, description: str, parameters: dict | None = None): + """ + Initialize DeepResearch tool with OpenAI function calling support. 
+ + Args: + name: Tool name + description: Tool description + parameters: OpenAI-style parameter schema (optional) + """ + # Set _json BEFORE calling super().__init__ + # because the parent's __init__ may access self.json + self._json = { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": parameters or {"type": "object", "properties": {}, "required": []}, + }, + } + + super().__init__(name=name, description=description) + + @abstractmethod + async def call(self, **kwargs) -> str: + """Execute the tool with given arguments.""" + pass + + async def async_forward(self, **kwargs): + """rLLM Tool interface - delegates to call()""" + from rllm.tools.tool_base import ToolOutput + + try: + result = await self.call(**kwargs) + return ToolOutput(name=self.name, output=result) + except Exception as e: + return ToolOutput(name=self.name, error=f"{type(e).__name__} - {str(e)}") + + +class SearchTool(DeepResearchTool): + """Web search tool using Serper API (ported from Tongyi).""" + + def __init__(self): + super().__init__( + name="Search", + description="Performs web searches using Google via Serper API", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string", + } + }, + "required": ["query"], + }, + ) + + def contains_chinese(self, text: str) -> bool: + """Check if text contains Chinese characters.""" + return any("\u4e00" <= char <= "\u9fff" for char in text) + + def _google_search_fallback(self, query: str | list) -> str: + """Use Google Custom Search API as fallback.""" + try: + import requests + + google_key = os.getenv("GOOGLE_SEARCH_SECRET_KEY") + engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID") + + queries = [query] if isinstance(query, str) else query + all_results = [] + + for q in queries: + params = {"key": google_key, "cx": engine_id, "q": q, "num": 10} + + response = requests.get( + "https://customsearch.googleapis.com/customsearch/v1", + params=params, + timeout=5, + ) + + if response.status_code == 200: + data = response.json() + items = data.get("items", []) + + web_snippets = [] + for idx, item in enumerate(items[:10], 1): + title = item.get("title", "") + link = item.get("link", "") + snippet = item.get("snippet", "") + entry = f"{idx}. [{title}]({link})\n {snippet}" + web_snippets.append(entry) + + result = f"Google search for '{q}' found {len(web_snippets)} results:\n\n" + "\n\n".join(web_snippets) + all_results.append(result) + else: + all_results.append(f"Google search error for '{q}': {response.status_code}") + + return "\n=======\n".join(all_results) if len(all_results) > 1 else all_results[0] + + except Exception as e: + return f"Google search fallback error: {e}" + + async def call(self, query: str | list, **kwargs) -> str: + """ + Search the web using Serper API or Google Custom Search. + + Args: + query: Search query string or list of queries + + Returns: + Formatted search results + """ + api_key = os.getenv("SERPER_API_KEY") + + # Try Google Custom Search as fallback if no Serper key + if not api_key: + google_key = os.getenv("GOOGLE_SEARCH_SECRET_KEY") + google_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID") + + if google_key and google_engine_id: + return self._google_search_fallback(query) + + return f"""[Search - API Key Required] + +To enable real web search, use one of these options: + +Option 1 - Serper (Recommended, simpler): +1. Get a free API key from https://serper.dev (2500 searches/month free) +2. 
Add to .env: SERPER_API_KEY=your_key_here + +Option 2 - Google Custom Search: +1. Set up at https://developers.google.com/custom-search +2. Add to .env: + GOOGLE_SEARCH_SECRET_KEY=your_key + GOOGLE_SEARCH_ENGINE_ID=your_engine_id + +Placeholder results for '{query}'...""" + + # Handle single query or list + queries = [query] if isinstance(query, str) else query + all_results = [] + + for q in queries: + try: + conn = http.client.HTTPSConnection("google.serper.dev") + + # Localize for Chinese queries + if self.contains_chinese(q): + payload = json.dumps({"q": q, "location": "China", "gl": "cn", "hl": "zh-cn"}) + else: + payload = json.dumps({"q": q, "location": "United States", "gl": "us", "hl": "en"}) + + headers = {"X-API-KEY": api_key, "Content-Type": "application/json"} + + # Retry logic + for i in range(5): + try: + conn.request("POST", "/search", payload, headers) + res = conn.getresponse() + break + except Exception: + if i == 4: + all_results.append(f"Google search timeout for '{q}'") + continue + + data = res.read() + results = json.loads(data.decode("utf-8")) + + if "organic" not in results: + all_results.append(f"No results found for '{q}'") + continue + + web_snippets = [] + for idx, page in enumerate(results.get("organic", [])[:10], 1): + date_published = f"\nDate: {page['date']}" if "date" in page else "" + source = f"\nSource: {page['source']}" if "source" in page else "" + snippet = f"\n{page['snippet']}" if "snippet" in page else "" + + entry = f"{idx}. [{page.get('title', 'Untitled')}]({page.get('link', '')}){date_published}{source}{snippet}" + web_snippets.append(entry) + + content = f"Google search for '{q}' found {len(web_snippets)} results:\n\n" + "\n\n".join(web_snippets) + all_results.append(content) + + except Exception as e: + all_results.append(f"Search error for '{q}': {e}") + + return "\n=======\n".join(all_results) if len(all_results) > 1 else all_results[0] + + +class ScholarTool(DeepResearchTool): + """Google Scholar search using Serper API (ported from Tongyi).""" + + def __init__(self): + super().__init__( + name="Scholar", + description="Search Google Scholar for academic papers", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The academic search query", + } + }, + "required": ["query"], + }, + ) + + async def call(self, query: str | list, **kwargs) -> str: + """ + Search Google Scholar using Serper API. 
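+        Returns up to 10 results per query, including title, link, publication,
+        year, citation count, and snippet when available.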
+ + Args: + query: Search query string or list of queries + + Returns: + Academic search results + """ + api_key = os.getenv("SERPER_API_KEY") + if not api_key: + return """[Scholar - API Key Required] + +To enable Google Scholar search, configure SERPER_API_KEY in your .env file.""" + + queries = [query] if isinstance(query, str) else query + all_results = [] + + for q in queries: + try: + conn = http.client.HTTPSConnection("google.serper.dev") + payload = json.dumps({"q": q, "type": "scholar", "num": 10}) + headers = {"X-API-KEY": api_key, "Content-Type": "application/json"} + + conn.request("POST", "/scholar", payload, headers) + res = conn.getresponse() + data = res.read() + results = json.loads(data.decode("utf-8")) + + if "organic" not in results: + all_results.append(f"No scholar results for '{q}'") + continue + + papers = [] + for idx, paper in enumerate(results.get("organic", [])[:10], 1): + title = paper.get("title", "Untitled") + link = paper.get("link", "") + snippet = paper.get("snippet", "") + publication = paper.get("publication", "") + year = paper.get("year", "") + cited_by = paper.get("citedBy", {}).get("value", 0) + + entry = f"{idx}. [{title}]({link})" + if publication: + entry += f"\n Publication: {publication}" + if year: + entry += f" ({year})" + if cited_by: + entry += f"\n Cited by: {cited_by}" + if snippet: + entry += f"\n {snippet}" + + papers.append(entry) + + result_text = f"Google Scholar search for '{q}':\n\n" + "\n\n".join(papers) + all_results.append(result_text) + + except Exception as e: + all_results.append(f"Scholar search error for '{q}': {e}") + + return "\n=======\n".join(all_results) if len(all_results) > 1 else all_results[0] + + +class VisitTool(DeepResearchTool): + """Web page visiting with content extraction.""" + + def __init__(self): + super().__init__( + name="Visit", + description="Visit and extract content from web pages", + parameters={ + "type": "object", + "properties": { + "url": {"type": "string", "description": "The URL to visit"}, + "goal": { + "type": "string", + "description": "Optional goal for the visit", + }, + }, + "required": ["url"], + }, + ) + + async def call(self, url: str | list, goal: str = "", **kwargs) -> str: + """ + Visit web pages and extract content. 
+ + Args: + url: URL string or list of URLs + goal: Optional goal for the visit + + Returns: + Extracted webpage content + """ + try: + import requests + from bs4 import BeautifulSoup + except ImportError: + return """[Visit Tool - Dependencies Required] + +To enable webpage visiting: +pip install requests beautifulsoup4 + +Then the tool will fetch and parse webpage content.""" + + import re + from urllib.parse import urlparse + + urls = [url] if isinstance(url, str) else url + all_results = [] + + for target_url in urls[:5]: # Limit to 5 URLs + try: + # Validate and normalize URL + parsed = urlparse(target_url) + if not parsed.scheme: + target_url = f"https://{target_url}" + + # Fetch webpage + headers = {"User-Agent": "Mozilla/5.0 (compatible; DeepResearch/1.0)"} + response = requests.get(target_url, headers=headers, timeout=10) + response.raise_for_status() + + # Parse HTML + soup = BeautifulSoup(response.text, "html.parser") + + # Remove unwanted elements + for element in soup(["script", "style", "nav", "footer", "header", "aside"]): + element.decompose() + + # Extract title + title = soup.title.string if soup.title else "No Title" + + # Extract main content + content = "" + for selector in ["main", "article", ".content", "#content", ".post"]: + element = soup.select_one(selector) + if element: + content = element.get_text(separator="\n", strip=True) + break + + if not content: + body = soup.find("body") + if body: + content = body.get_text(separator="\n", strip=True) + + # Clean up text + content = re.sub(r"\n{3,}", "\n\n", content) + content = re.sub(r" {2,}", " ", content) + + # Limit length + if len(content) > 5000: + content = content[:5000] + "\n[Content truncated...]" + + # Format result + result = f"[Webpage: {target_url}]\nTitle: {title}" + if goal: + result += f"\nGoal: {goal}" + result += f"\n\nContent:\n{content}" + + all_results.append(result) + + except Exception as e: + all_results.append(f"[Error visiting {target_url}]: {e}") + + return "\n\n=======\n\n".join(all_results) + + +class FileParserTool(DeepResearchTool): + """Enhanced file parsing for multiple formats.""" + + def __init__(self): + super().__init__( + name="FileParser", + description="Parse files: TXT, JSON, CSV, PDF, DOCX, etc.", + parameters={ + "type": "object", + "properties": { + "files": { + "type": "string", + "description": "File path or list of file paths to parse", + } + }, + "required": ["files"], + }, + ) + + async def call(self, files: str | list, **kwargs) -> str: + """ + Parse files and extract content. 
+ + Args: + files: File path string or list of paths + + Returns: + Extracted file content + """ + import csv + from pathlib import Path + + file_paths = [files] if isinstance(files, str) else files + all_results = [] + + for file_path in file_paths[:10]: # Limit to 10 files + if not os.path.exists(file_path): + all_results.append(f"Error: File not found at {file_path}") + continue + + try: + file_ext = Path(file_path).suffix.lower() + file_name = os.path.basename(file_path) + file_size = os.path.getsize(file_path) + + content = "" + + # Text files + if file_ext in [ + ".txt", + ".md", + ".log", + ".py", + ".js", + ".java", + ".cpp", + ".c", + ".h", + ]: + with open(file_path, encoding="utf-8", errors="ignore") as f: + content = f.read() + + # JSON files + elif file_ext == ".json": + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + content = json.dumps(data, indent=2, ensure_ascii=False) + + # CSV files + elif file_ext == ".csv": + rows = [] + with open(file_path, encoding="utf-8", errors="ignore") as f: + reader = csv.reader(f) + for i, row in enumerate(reader): + if i >= 100: + rows.append("[... truncated ...]") + break + rows.append(", ".join(row)) + content = "\n".join(rows) + + # PDF files + elif file_ext == ".pdf": + try: + import PyPDF2 + + with open(file_path, "rb") as f: + pdf_reader = PyPDF2.PdfReader(f) + pages = [] + for i in range(min(len(pdf_reader.pages), 10)): + page = pdf_reader.pages[i] + pages.append(f"Page {i + 1}:\n{page.extract_text()}") + content = "\n\n".join(pages) + except ImportError: + content = "[PDF parsing requires: pip install PyPDF2]" + + # Word documents + elif file_ext in [".docx", ".doc"]: + try: + from docx import Document + + doc = Document(file_path) + paragraphs = [] + for i, para in enumerate(doc.paragraphs): + if i >= 100: + paragraphs.append("[... 
truncated ...]") + break + if para.text.strip(): + paragraphs.append(para.text) + content = "\n\n".join(paragraphs) + except ImportError: + content = "[DOCX parsing requires: pip install python-docx]" + + # Default: try as text + else: + try: + with open(file_path, encoding="utf-8", errors="ignore") as f: + content = f.read() + except Exception: + content = f"[Cannot parse file type: {file_ext}]" + + # Limit content + if len(content) > 10000: + content = content[:10000] + "\n[Content truncated...]" + + result = f"[File: {file_name}]\nType: {file_ext}\nSize: {file_size:,} bytes\n\nContent:\n{content}" + all_results.append(result) + + except Exception as e: + all_results.append(f"Error parsing {file_path}: {e}") + + return "\n\n=======\n\n".join(all_results) + + +class PythonInterpreterTool(DeepResearchTool): + """Safe Python code execution (from existing implementation).""" + + def __init__(self): + super().__init__( + name="PythonInterpreter", + description="Execute Python code for calculations and analysis", + parameters={ + "type": "object", + "properties": {"code": {"type": "string", "description": "Python code to execute"}}, + "required": ["code"], + }, + ) + self.timeout = 50 + + async def call(self, code: str, timeout: int = None, **kwargs) -> str: + """Execute Python code safely with timeout.""" + timeout = timeout or self.timeout + + # Security checks - check for dangerous imports/operations + dangerous_patterns = [ + "import os", + "import subprocess", + "import sys", + "from os import", + "from subprocess import", + "from sys import", + "exec(", + "eval(", + "compile(", + "open(", + "file(", + ] + + code_lower = code.lower() + for pattern in dangerous_patterns: + if pattern in code_lower: + return f"[Security Error] '{pattern}' not allowed for safety reasons" + + import io + import sys + from concurrent.futures import ThreadPoolExecutor, TimeoutError + + # Setup safe environment + allowed_modules = { + "math": __import__("math"), + "datetime": __import__("datetime"), + "json": __import__("json"), + "random": __import__("random"), + "re": __import__("re"), + "collections": __import__("collections"), + "itertools": __import__("itertools"), + "statistics": __import__("statistics"), + } + + # Add numpy/pandas if available + try: + import numpy as np + + allowed_modules["numpy"] = np + allowed_modules["np"] = np + except ImportError: + pass + + try: + import pandas as pd + + allowed_modules["pandas"] = pd + allowed_modules["pd"] = pd + except ImportError: + pass + + # Restricted builtins with safe import capability + def safe_import(name, *args, **kwargs): + """Allow importing only safe modules.""" + safe_modules = [ + "math", + "datetime", + "json", + "random", + "re", + "collections", + "itertools", + "statistics", + "numpy", + "pandas", + "scipy", + "scipy.linalg", # Add scipy submodules + "scipy.optimize", + "scipy.signal", + "scipy.special", + "matplotlib", + "matplotlib.pyplot", + ] + # Check if the module or its parent is allowed + if name in safe_modules or any(name.startswith(m + ".") for m in safe_modules): + return __import__(name, *args, **kwargs) + else: + raise ImportError(f"Module '{name}' is not allowed for safety reasons") + + restricted_builtins = { + "abs": abs, + "all": all, + "any": any, + "bin": bin, + "bool": bool, + "chr": chr, + "dict": dict, + "enumerate": enumerate, + "filter": filter, + "float": float, + "hex": hex, + "int": int, + "len": len, + "list": list, + "map": map, + "max": max, + "min": min, + "oct": oct, + "ord": ord, + "pow": pow, + "print": 
print, + "range": range, + "reversed": reversed, + "round": round, + "set": set, + "slice": slice, + "sorted": sorted, + "str": str, + "sum": sum, + "tuple": tuple, + "type": type, + "zip": zip, + "__import__": safe_import, # Allow safe imports + # Add exception classes for proper error handling + "Exception": Exception, + "ImportError": ImportError, + "ValueError": ValueError, + "TypeError": TypeError, + "KeyError": KeyError, + "IndexError": IndexError, + "AttributeError": AttributeError, + } + + global_vars = {"__builtins__": restricted_builtins} + global_vars.update(allowed_modules) + local_vars = {} + + # Capture output + old_stdout = sys.stdout + old_stderr = sys.stderr + stdout_buffer = io.StringIO() + stderr_buffer = io.StringIO() + + def execute_with_timeout(): + try: + sys.stdout = stdout_buffer + sys.stderr = stderr_buffer + exec(code, global_vars, local_vars) + return True + except Exception as e: + stderr_buffer.write(f"Execution error: {e}") + return False + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + # Execute with timeout + with ThreadPoolExecutor() as executor: + try: + future = executor.submit(execute_with_timeout) + future.result(timeout=timeout) + + stdout_content = stdout_buffer.getvalue() + stderr_content = stderr_buffer.getvalue() + + if stderr_content: + return f"[Error]\n{stderr_content}" + elif stdout_content: + return f"[Output]\n{stdout_content.rstrip()}" + else: + meaningful_vars = {k: v for k, v in local_vars.items() if not k.startswith("_") and k not in allowed_modules} + if meaningful_vars: + return f"[Variables]\n{meaningful_vars}" + else: + return "[Success] Code executed (no output)" + + except TimeoutError: + return f"[Timeout] Execution exceeded {timeout}s" + + return "[Error] Unexpected execution error" + + +# Tool registry +DEEPRESEARCH_TOOLS = { + "Search": SearchTool(), + "Scholar": ScholarTool(), + "Visit": VisitTool(), + "FileParser": FileParserTool(), + "PythonInterpreter": PythonInterpreterTool(), +} + + +def get_tool(name: str) -> DeepResearchTool: + """Get a tool by name.""" + return DEEPRESEARCH_TOOLS.get(name) + + +def get_all_tools() -> dict[str, DeepResearchTool]: + """Get all available tools.""" + return DEEPRESEARCH_TOOLS.copy() diff --git a/examples/deepresearch/deepresearch_workflow.py b/examples/deepresearch/deepresearch_workflow.py new file mode 100644 index 00000000..b461d785 --- /dev/null +++ b/examples/deepresearch/deepresearch_workflow.py @@ -0,0 +1,271 @@ +""" +DeepResearch Workflow for rLLM + +This workflow integrates the DeepResearch agent with rLLM's AgentWorkflowEngine, +enabling parallel execution and trajectory tracking while maintaining DeepResearch's +core reasoning capabilities. +""" + +from deepresearch_agent import MultiTurnReactAgent + +from rllm.agents.agent import Action, Episode, Step, Trajectory +from rllm.engine.rollout import RolloutEngine +from rllm.workflows.workflow import TerminationReason, Workflow + + +class DeepResearchWorkflow(Workflow): + """ + Workflow that wraps the DeepResearch MultiTurnReactAgent for use with AgentWorkflowEngine. + + This workflow: + 1. Creates a DeepResearch agent instance + 2. Executes the research task using the agent's ReAct loop + 3. Converts the results to rLLM Episode format for trajectory tracking + """ + + def __init__( + self, + rollout_engine: RolloutEngine, + executor, + tools: dict = None, + system_prompt: str = None, + **kwargs, + ): + """ + Initialize the DeepResearch workflow. 
+ + Args: + rollout_engine: rLLM rollout engine for model inference + executor: Thread pool executor for async operations + tools: Dictionary of available tools for research tasks + system_prompt: Custom system prompt (optional, uses default if None) + **kwargs: Additional arguments passed to parent Workflow + """ + super().__init__(rollout_engine, executor, **kwargs) + + self.tools = tools or {} + self.system_prompt = system_prompt + + # Auto-detect if we should use native function calling + # O3 models require native function calling, other models use XML format + model_name = rollout_engine.model.lower() + use_native_fc = "o3" in model_name or "o1" in model_name + + # Create the DeepResearch agent + self.agent = MultiTurnReactAgent( + rollout_engine=rollout_engine, + tools=self.tools, + system_prompt=self.system_prompt, + use_native_function_calling=use_native_fc, + ) + + # Note: We don't register the agent since DeepResearch handles its own trajectory + + async def run(self, task: dict, uid: str, **kwargs) -> Episode: + """ + Execute the DeepResearch workflow on a single task. + + Args: + task: Task dictionary containing: + - question: The research question to answer + - answer: Ground truth answer (optional, for evaluation) + - Any other task metadata + uid: Unique identifier for this episode + + Returns: + Episode object with trajectory and results + """ + # Reset workflow state for this task + self.reset(task=task, uid=uid) + + # Extract question and answer from task + question = task.get("question", task.get("query", "No question provided")) + answer = task.get("answer", "") + + print(f"🚀 Starting DeepResearch workflow for task {uid}") + print(f" Question: {question}") + + try: + # Run the DeepResearch agent + result = await self.agent.run(question=question, answer=answer, **kwargs) + + # Convert the result to rLLM Episode format + episode = self._convert_to_episode(result, task, uid) + + print(f"✅ DeepResearch workflow completed for task {uid}") + print(f" Prediction: {result.get('prediction', 'No prediction')}") + + return episode + + except Exception as e: + print(f"❌ DeepResearch workflow failed for task {uid}: {e}") + + # Create a failed episode + episode = Episode() + episode.id = uid + episode.task = task + episode.termination_reason = TerminationReason.UNKNOWN + episode.is_correct = False + episode.trajectories = [] + episode.metrics = {"error": str(e)} + return episode + + def _convert_to_episode(self, result: dict, task: dict, uid: str) -> Episode: + """ + Convert DeepResearch result to rLLM Episode format. 
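+        Each assistant message becomes one Step whose chat_completions field holds the
+        cumulative message history up to that turn.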
+
+        Args:
+            result: Result dictionary from DeepResearch agent
+            task: Original task dictionary
+            uid: Episode unique identifier
+
+        Returns:
+            Episode object with trajectory
+        """
+        # Create trajectory from the conversation messages
+        trajectory = Trajectory(task=task.get("question", ""))
+
+        # Convert conversation to steps
+        messages = result.get("messages", [])
+
+        i = 0
+        while i < len(messages):
+            # Look for assistant messages (model responses)
+            if messages[i]["role"] == "assistant":
+                # Build chat completion context up to this point
+                current_context = messages[: i + 1]
+
+                # Create step
+                step = Step(
+                    chat_completions=current_context.copy(),
+                    model_response=messages[i]["content"],
+                    action=self._extract_action_from_response(messages[i]["content"]),
+                    observation=self._get_next_observation(messages, i),
+                    reward=0.0,  # Will be computed later if needed
+                )
+
+                trajectory.steps.append(step)
+
+            i += 1
+
+        # Determine if the answer is correct (if ground truth available)
+        prediction = result.get("prediction", "")
+        ground_truth = task.get("answer", "")
+        is_correct = self._evaluate_answer(prediction, ground_truth) if ground_truth else False
+
+        # Map termination reason
+        termination_reason = self._map_termination_reason(result.get("termination", "unknown"))
+
+        # Create episode
+        episode = Episode()
+        episode.id = uid
+        episode.task = task
+        episode.termination_reason = termination_reason
+        episode.is_correct = is_correct
+        episode.trajectories = [("deepresearch_agent", trajectory)]
+        episode.metrics = {
+            "rounds": result.get("rounds", 0),
+            "time_taken": result.get("time_taken", 0),
+            "prediction": prediction,
+            "ground_truth": ground_truth,
+        }
+
+        return episode
+
+    def _extract_action_from_response(self, response: str) -> Action:
+        """
+        Extract action information from the model response (XML-style tags
+        used by the DeepResearch ReAct format).
+
+        Args:
+            response: Model response text
+
+        Returns:
+            Action object
+        """
+        # Check for tool calls
+        if "<tool_call>" in response and "</tool_call>" in response:
+            tool_call_text = response.split("<tool_call>")[1].split("</tool_call>")[0]
+            return Action(action={"type": "tool_call", "tool_call": tool_call_text.strip()})
+        # Check for final answer
+        elif "<answer>" in response and "</answer>" in response:
+            answer = response.split("<answer>")[1].split("</answer>")[0].strip()
+            return Action(action={"type": "final_answer", "answer": answer})
+        else:
+            # Just thinking/reasoning
+            return Action(action={"type": "reasoning", "content": response})
+
+    def _get_next_observation(self, messages: list, current_index: int) -> str:
+        """
+        Get the observation that follows the current assistant message.
+
+        Args:
+            messages: List of all messages
+            current_index: Index of current assistant message
+
+        Returns:
+            Next observation string (tool response or empty)
+        """
+        if current_index + 1 < len(messages):
+            next_msg = messages[current_index + 1]
+            if next_msg["role"] == "user" and "<tool_response>" in next_msg["content"]:
+                return next_msg["content"]
+
+        return ""
+
+    def _evaluate_answer(self, prediction: str, ground_truth: str) -> bool:
+        """
+        Simple answer evaluation (can be enhanced with specific metrics).
+
+        Args:
+            prediction: Model's predicted answer
+            ground_truth: Correct answer
+
+        Returns:
+            True if correct, False otherwise
+        """
+        if not prediction or not ground_truth:
+            return False
+
+        # Simple string matching (can be enhanced with fuzzy matching, etc.)
+        return prediction.strip().lower() == ground_truth.strip().lower()
+
+    def _map_termination_reason(self, termination: str) -> TerminationReason:
+        """
+        Map DeepResearch termination reasons to rLLM TerminationReason enum.
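+
+        For example:
+
+            _map_termination_reason("max_rounds_reached")  # -> TerminationReason.MAX_TURNS_EXCEEDED
+            _map_termination_reason("something_else")      # -> TerminationReason.UNKNOWN (fallback)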
+ + Args: + termination: DeepResearch termination string + + Returns: + Mapped TerminationReason + """ + mapping = { + "answer": TerminationReason.ENV_DONE, + "timeout": TerminationReason.TIMEOUT, + "max_rounds_reached": TerminationReason.MAX_TURNS_EXCEEDED, + "token_limit_no_answer": TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED, + "answer_token_limit": TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED, + } + + return mapping.get(termination, TerminationReason.UNKNOWN) + + def reset(self, task: dict = None, uid: str = None): + """ + Reset the workflow for a new task. + + Args: + task: New task dictionary + uid: New unique identifier + """ + # Skip parent reset since we don't use registered agents + # The DeepResearch agent manages its own state per run() + pass + + def is_multithread_safe(self) -> bool: + """ + Indicate whether this workflow is safe for multithreaded execution. + + Returns: + True, as each workflow instance manages its own state + """ + return True diff --git a/examples/deepresearch/evaluate_hle.py b/examples/deepresearch/evaluate_hle.py new file mode 100644 index 00000000..e61345b6 --- /dev/null +++ b/examples/deepresearch/evaluate_hle.py @@ -0,0 +1,517 @@ +""" +Humanity's Last Exam (HLE) Evaluation for DeepResearch + rLLM + +Adapted from original DeepResearch HLE evaluation to work with rLLM's +DeepResearch integration and AgentWorkflowEngine. + +Original: https://github.com/Alibaba-NLP/DeepResearch/blob/main/evaluation/evaluate_hle_official.py +""" + +import argparse +import asyncio +import json +import os +import statistics +from datetime import datetime +from typing import Any + +from datasets import load_dataset +from deepresearch_tools import get_all_tools +from deepresearch_workflow import DeepResearchWorkflow +from dotenv import find_dotenv, load_dotenv + +from rllm.engine.agent_workflow_engine import AgentWorkflowEngine +from rllm.engine.rollout import OpenAIEngine + + +class HLEJudge: + """Judge for evaluating HLE responses using OpenAI API.""" + + def __init__(self, judge_engine: OpenAIEngine): + self.judge_engine = judge_engine + # Binary yes/no judge prompt aligned with Tongyi DeepResearch + self.judge_prompt = """You are an impartial judge evaluating the correctness of an AI assistant's answer. + +[Question] +{question} + +[Correct Answer] +{reference_answer} + +[Assistant's Answer] +{assistant_answer} + +Task: Determine if the assistant's answer is correct by comparing it to the correct answer. + +Instructions: +1. Extract the final answer from the assistant's response +2. Compare it with the correct answer +3. Provide your reasoning +4. Answer with "yes" if correct, "no" if incorrect + +Output format: +correct: [yes/no] +reasoning: [your explanation]""" + + async def judge_response(self, question: str, reference_answer: str, assistant_answer: str) -> dict[str, Any]: + """ + Judge a single response. 
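+
+        A successful call returns a dict shaped like (example value only):
+
+            {"judgment": "correct: yes\nreasoning: ...", "is_correct": True}
+
+        On judge/API errors the same keys are returned with is_correct=False.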
+ + Args: + question: Original question + reference_answer: Ground truth answer + assistant_answer: Model's prediction + + Returns: + Dictionary with judgment results + """ + try: + prompt = self.judge_prompt.format( + question=question, + reference_answer=reference_answer, + assistant_answer=assistant_answer, + ) + + messages = [{"role": "user", "content": prompt}] + + # Use appropriate token parameter based on model + if "o3" in self.judge_engine.model.lower() or "o1" in self.judge_engine.model.lower(): + response = await self.judge_engine.get_model_response(messages=messages, max_completion_tokens=1000) + else: + response = await self.judge_engine.get_model_response(messages=messages, temperature=0.1, max_tokens=1000) + + judgment_text = response.text if hasattr(response, "text") else str(response) + + # Parse binary yes/no from judge output + is_correct = False + if "correct:" in judgment_text.lower(): + # Extract the yes/no after "correct:" + try: + correct_line = [line for line in judgment_text.lower().split("\n") if "correct:" in line][0] + is_correct = "yes" in correct_line + except (IndexError, ValueError): + is_correct = False + + return { + "judgment": judgment_text, + "is_correct": is_correct, + } + + except Exception as e: + print(f"Judge error: {e}") + return {"judgment": f"Judge error: {e}", "is_correct": False} + + +async def evaluate_hle_dataset(dataset_path: str, args) -> dict[str, Any]: + """ + Evaluate DeepResearch on HLE dataset. + + Args: + dataset_path: Path to HLE JSONL dataset + args: Command line arguments + + Returns: + Evaluation results dictionary + """ + print("📊 Starting HLE Evaluation") + print(f"Dataset: {dataset_path}") + print(f"Max samples: {args.max_samples}") + print("=" * 60) + + # Load dataset (HF only to align with examples pattern) + questions = [] + dataset_name = args.hf_dataset or "cais/hle" + split_name = args.hf_split or "test" + + print(f"🧰 Loading dataset from Hugging Face: {dataset_name} (split={split_name})") + try: + if args.hf_config: + ds = load_dataset(dataset_name, args.hf_config, split=split_name) + else: + ds = load_dataset(dataset_name, split=split_name) + + def extract_qa(example: dict[str, Any]) -> dict[str, str]: + q = "" + a = "" + if "question" in example: + q = example["question"] + elif "prompt" in example: + q = example["prompt"] + elif "input" in example: + q = example["input"] + + if "answer" in example: + a = example["answer"] + elif "target" in example: + a = example["target"] + elif "output" in example: + a = example["output"] + elif "correct_answer" in example: + a = example["correct_answer"] + + if "choices" in example and a: + try: + choices_text = "\n".join([f"{i + 1}. 
{choice}" for i, choice in enumerate(example["choices"])]) + q = f"{q}\n\nChoices:\n{choices_text}" + except Exception: + pass + + # Inject external contexts (urls/files/images/extra text) to help tools + try: + extras: list[str] = [] + # Text contexts + for key in [ + "context", + "contexts", + "extra", + "additional_context", + "background", + "passage", + "passages", + ]: + if key in example and example[key]: + val = example[key] + if isinstance(val, list | tuple): + val_str = "\n".join([str(v) for v in val][:5]) + else: + val_str = str(val) + if val_str.strip(): + extras.append(f"{key.title()}:\n{val_str}") + + # URLs + urls = [] + if "urls" in example and example["urls"]: + urls = example["urls"] if isinstance(example["urls"], list | tuple) else [example["urls"]] + elif "url" in example and example["url"]: + urls = [example["url"]] + if urls: + url_lines = "\n".join([f"- {u}" for u in urls[:10]]) + extras.append(f"URLs:\n{url_lines}") + + # File paths + file_paths = [] + for key in ["file_paths", "file_path", "files"]: + if key in example and example[key]: + vals = example[key] if isinstance(example[key], list | tuple) else [example[key]] + file_paths.extend([str(v) for v in vals]) + if file_paths: + file_lines = "\n".join([f"- {p}" for p in file_paths[:10]]) + extras.append(f"Files:\n{file_lines}") + + # Images + images = [] + for key in ["images", "image"]: + if key in example and example[key]: + vals = example[key] if isinstance(example[key], list | tuple) else [example[key]] + images.extend([str(v) for v in vals]) + if images: + img_lines = "\n".join([f"- {p}" for p in images[:10]]) + extras.append(f"Images:\n{img_lines}") + + if extras: + q = f"{q}\n\nAdditional context for tools:\n" + "\n\n".join(extras) + except Exception: + pass + + return { + "question": str(q) if q is not None else "", + "answer": str(a) if a is not None else "", + } + + total_len = len(ds) + limit = min(args.max_samples, total_len) if args.max_samples else total_len + for idx in range(limit): + ex = ds[idx] + qa = extract_qa(ex) + if qa["question"] and qa["answer"]: + questions.append( + { + "id": f"hle_{idx}", + "question": qa["question"], + "answer": qa["answer"], + } + ) + else: + print(f"Warning: Could not extract question/answer from example {idx}") + + except Exception as e: + print(f"❌ Failed to load dataset from Hugging Face: {e}") + raise + + print(f"📋 Loaded {len(questions)} questions from HLE dataset") + + # Setup rollout engine + load_dotenv(find_dotenv()) + + # Use GPT-4o for model evaluation + model_engine = setup_rollout_engine(args, model_role="evaluation") + + # Setup judge (can use same or different model) + judge_engine = setup_rollout_engine(args, model_role="judge") + judge = HLEJudge(judge_engine) + + # Setup tools + tools = get_all_tools() + + # Create AgentWorkflowEngine + workflow_engine = AgentWorkflowEngine( + workflow_cls=DeepResearchWorkflow, + workflow_args={ + "tools": tools, + "max_prompt_length": 4096, + "max_response_length": 2048, + }, + rollout_engine=model_engine, + n_parallel_tasks=args.parallel_tasks, + retry_limit=1, + ) + + print(f"⚙️ Created evaluation setup with {args.parallel_tasks} parallel tasks") + + # Run DeepResearch evaluation + print("\n🔬 Running DeepResearch evaluation...") + start_time = asyncio.get_event_loop().time() + + try: + episodes = await workflow_engine.execute_tasks(questions) + eval_time = asyncio.get_event_loop().time() - start_time + + print(f"\n✅ Evaluation completed in {eval_time:.1f}s") + + # Extract predictions + results = [] + for 
episode in episodes: + prediction = episode.metrics.get("prediction", "No prediction available") + results.append( + { + "question": episode.task.get("question", ""), + "reference_answer": episode.task.get("answer", ""), + "prediction": prediction, + "episode_id": episode.id, + "is_correct": episode.is_correct, + "rounds": episode.metrics.get("rounds", 0), + "termination_reason": episode.termination_reason.value if episode.termination_reason else "unknown", + } + ) + + # Judge responses + print(f"\n⚖️ Judging {len(results)} responses...") + + judge_results = [] + for result in results: + judgment = await judge.judge_response( + question=result["question"], + reference_answer=result["reference_answer"], + assistant_answer=result["prediction"], + ) + result.update(judgment) + judge_results.append(result) + + # Calculate metrics + metrics = calculate_hle_metrics(judge_results) + metrics["evaluation_time"] = eval_time + metrics["total_questions"] = len(questions) + + # Save results + save_hle_results(judge_results, metrics, args) + + return metrics + + except Exception as e: + print(f"❌ Evaluation failed: {e}") + raise + + +def setup_rollout_engine(args, model_role="evaluation") -> OpenAIEngine: + """Setup rollout engine for evaluation or judging.""" + + # Load environment variables + load_dotenv(find_dotenv()) + + # Provider selection + together_api_key = os.getenv("TOGETHER_AI_API_KEY") + openai_api_key = os.getenv("OPENAI_API_KEY") + + if args.api_key: + api_key = args.api_key + base_url = args.base_url or "https://api.openai.com/v1" + model_name = args.model or "gpt-4" + elif together_api_key and model_role == "evaluation": + api_key = together_api_key + base_url = args.base_url or "https://api.together.xyz/v1" + model_name = args.model or os.getenv("TOGETHER_AI_MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct-Turbo") + print(f"🔧 Using Together AI for {model_role}") + elif openai_api_key: + api_key = openai_api_key + base_url = args.base_url or "https://api.openai.com/v1" + model_name = args.model or "gpt-4o" + print(f"🔧 Using OpenAI for {model_role}") + else: + raise ValueError("❌ API key required. 
Please set OPENAI_API_KEY or TOGETHER_AI_API_KEY in .env file") + + # For evaluation, DeepResearch handles all sampling params internally + # For judge, we need basic params + if model_role == "judge": + # Check if model is O3/O1 (use model_name which is already determined above) + if "o3" in model_name.lower() or "o1" in model_name.lower(): + sampling_params = { + "max_completion_tokens": 1000, + } + else: + sampling_params = { + "temperature": 0.1, + "top_p": 0.95, + "max_tokens": 1000, + } + else: + # Don't set default sampling_params for evaluation + # DeepResearch will handle model-specific params + sampling_params = {} + + return OpenAIEngine( + model=model_name, + tokenizer=None, + base_url=base_url, + api_key=api_key, + sampling_params=sampling_params, + ) + + +def calculate_hle_metrics(results: list[dict[str, Any]]) -> dict[str, Any]: + """Calculate HLE evaluation metrics.""" + + total = len(results) + if total == 0: + return {"error": "No results to evaluate"} + + # Basic accuracy (judge-based binary yes/no) + judge_correct = sum(1 for r in results if r.get("is_correct", False)) + judge_accuracy = judge_correct / total + + # Termination analysis + termination_counts = {} + for result in results: + reason = result.get("termination_reason", "unknown") + termination_counts[reason] = termination_counts.get(reason, 0) + 1 + + # Round analysis + rounds = [r.get("rounds", 0) for r in results] + avg_rounds = statistics.mean(rounds) if rounds else 0 + + return { + "total_questions": total, + "judge_accuracy": judge_accuracy, + "judge_correct": judge_correct, + "average_rounds": avg_rounds, + "termination_distribution": termination_counts, + } + + +def save_hle_results(results: list[dict], metrics: dict, args): + """Save HLE evaluation results.""" + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Save detailed results + results_file = os.path.join(args.output_dir, f"hle_results_{timestamp}.json") + os.makedirs(args.output_dir, exist_ok=True) + + with open(results_file, "w", encoding="utf-8") as f: + json.dump( + { + "metadata": { + "timestamp": timestamp, + "dataset": "HLE", + "model": args.model, + "total_questions": len(results), + }, + "results": results, + "metrics": metrics, + }, + f, + indent=2, + ensure_ascii=False, + ) + + # Save metrics summary + metrics_file = os.path.join(args.output_dir, f"hle_metrics_{timestamp}.json") + with open(metrics_file, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=2, ensure_ascii=False) + + print(f"💾 Results saved to: {results_file}") + print(f"📊 Metrics saved to: {metrics_file}") + + +def print_hle_summary(metrics: dict[str, Any]): + """Print HLE evaluation summary.""" + + print("\n" + "=" * 60) + print("📊 HLE EVALUATION SUMMARY") + print("=" * 60) + print(f"Total Questions: {metrics.get('total_questions', 0)}") + print(f"Judge Accuracy: {metrics.get('judge_accuracy', 0):.2%}") + print(f"Correct Answers: {metrics.get('judge_correct', 0)}/{metrics.get('total_questions', 0)}") + print(f"Average Rounds: {metrics.get('average_rounds', 0):.1f}") + print(f"Evaluation Time: {metrics.get('evaluation_time', 0):.1f}s") + + print("\nTermination Reasons:") + term_dist = metrics.get("termination_distribution", {}) + for reason, count in term_dist.items(): + print(f" {reason}: {count}") + + print("=" * 60) + + +async def main(): + parser = argparse.ArgumentParser(description="Run HLE evaluation with DeepResearch + rLLM") + + # Dataset options (HF only) + parser.add_argument( + "--hf-dataset", + default="cais/hle", + help="Hugging Face 
dataset path (default: cais/hle)", + ) + parser.add_argument( + "--hf-config", + default=None, + help="Optional dataset configuration name for HF datasets that require it.", + ) + parser.add_argument( + "--hf-split", + default="test", + help="Dataset split to load from HF (default: test)", + ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Maximum number of samples to evaluate", + ) + + # Model options + parser.add_argument("--model", default=None, help="Model name to use") + parser.add_argument("--base-url", default=None, help="API base URL") + parser.add_argument("--api-key", default=None, help="API key (uses env vars if not provided)") + + # Execution options + parser.add_argument("--parallel-tasks", type=int, default=4, help="Number of parallel tasks") + parser.add_argument("--output-dir", default="./hle_outputs", help="Output directory for results") + + args = parser.parse_args() + + try: + metrics = await evaluate_hle_dataset(args.hf_dataset, args) + print_hle_summary(metrics) + + except Exception as e: + print(f"❌ HLE evaluation failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + # Set environment for tokenizers + os.environ["TOKENIZERS_PARALLELISM"] = "true" + + asyncio.run(main()) diff --git a/rllm/engine/rollout/openai_engine.py b/rllm/engine/rollout/openai_engine.py index fcff3ab1..6fb809c9 100644 --- a/rllm/engine/rollout/openai_engine.py +++ b/rllm/engine/rollout/openai_engine.py @@ -34,6 +34,22 @@ def __init__(self, model: str = "", tokenizer=None, max_prompt_length: int = 409 self.client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key) logging.getLogger("httpx").setLevel(logging.WARNING) + def _prepare_max_tokens_param(self, sampling_params: dict, prompt_length: int = None) -> dict: + """Prepare max tokens parameter for API call (supports O3's max_completion_tokens).""" + if "max_completion_tokens" in sampling_params: + return {"max_completion_tokens": sampling_params.pop("max_completion_tokens")} + + max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) + + # Adjust for prompt length if provided (completion method needs this) + if prompt_length and self.max_model_length: + remaining = self.max_model_length - prompt_length + if remaining <= max_tokens: + max_tokens = remaining + print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length") + + return {"max_tokens": max_tokens} + async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: kwargs.pop("application_id", None) kwargs.pop("validate", None) @@ -43,19 +59,22 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: sampling_params = self.sampling_params.copy() sampling_params.update(kwargs) - max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) + create_params = self._prepare_max_tokens_param(sampling_params) retries = self.api_retries while retries > 0: try: - response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, max_tokens=max_tokens, **sampling_params) + response = await self.client.chat.completions.create(model=self.model, messages=messages, timeout=3600, **create_params, **sampling_params) content = response.choices[0].message.content reasoning = response.choices[0].message.reasoning if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str) else 
"" tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") and isinstance(response.choices[0].message.tool_calls, list) else [] + # Build text with reasoning if available, otherwise use content if reasoning: - text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}" # best guess + text = f"{THOUGHT_DELIMITER_START}\n{reasoning}\n{THOUGHT_DELIMITER_END}\n\n{content}" + else: + text = content prompt_length = response.usage.prompt_tokens completion_length = response.usage.completion_tokens @@ -102,16 +121,12 @@ async def completion(self, prompt: str, **kwargs) -> ModelOutput: if enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length): raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) - remaining_tokens = self.max_model_length - prompt_length - if remaining_tokens <= max_tokens: - max_tokens = remaining_tokens - print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length") + create_params = self._prepare_max_tokens_param(sampling_params, prompt_length) retries = self.api_retries while retries > 0: try: - response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, max_tokens=max_tokens, **sampling_params) + response = await self.client.completions.create(model=self.model, prompt=prompt, timeout=3600, **create_params, **sampling_params) text = response.choices[0].text completion_ids = self.tokenizer.encode(text, add_special_tokens=False)