From a3ba907a2257ae0c754c67e54e21befc020d36cd Mon Sep 17 00:00:00 2001 From: Tyson Thomas Date: Sun, 15 Jun 2025 21:53:43 -0700 Subject: [PATCH] Refactor LLM SDK --- .gitignore | 1 + config/gni/devtools_grd_files.gni | 16 +- front_end/panels/ai_chat/BUILD.gn | 28 +- front_end/panels/ai_chat/LLM/LLMClient.ts | 300 +++++++++ .../panels/ai_chat/LLM/LLMErrorHandler.ts | 345 ++++++++++ front_end/panels/ai_chat/LLM/LLMProvider.ts | 94 +++ .../panels/ai_chat/LLM/LLMProviderRegistry.ts | 106 ++++ .../panels/ai_chat/LLM/LLMResponseParser.ts | 306 +++++++++ front_end/panels/ai_chat/LLM/LLMTypes.ts | 228 +++++++ .../panels/ai_chat/LLM/LiteLLMProvider.ts | 377 +++++++++++ .../panels/ai_chat/LLM/OpenAIProvider.ts | 427 +++++++++++++ .../ai_chat/agent_framework/AgentRunner.ts | 94 ++- front_end/panels/ai_chat/core/AgentNodes.ts | 217 +++++-- front_end/panels/ai_chat/core/AgentService.ts | 46 ++ front_end/panels/ai_chat/core/ChatLiteLLM.ts | 145 ----- front_end/panels/ai_chat/core/ChatOpenAI.ts | 173 ------ .../panels/ai_chat/core/ConfigurableGraph.ts | 21 +- front_end/panels/ai_chat/core/Graph.ts | 61 +- front_end/panels/ai_chat/core/GraphHelpers.ts | 57 -- .../panels/ai_chat/core/LiteLLMClient.ts | 482 -------------- front_end/panels/ai_chat/core/OpenAIClient.ts | 371 ----------- .../panels/ai_chat/core/UnifiedLLMClient.ts | 588 ------------------ .../framework/GenericToolEvaluator.ts | 54 +- .../framework/judges/LLMEvaluator.ts | 33 +- .../panels/ai_chat/tools/CritiqueTool.ts | 82 ++- ...FullPageAccessibilityTreeToMarkdownTool.ts | 29 +- .../ai_chat/tools/HTMLToMarkdownTool.ts | 33 +- .../ai_chat/tools/SchemaBasedExtractorTool.ts | 94 ++- .../tools/StreamlinedSchemaExtractorTool.ts | 47 +- front_end/panels/ai_chat/tools/Tools.ts | 83 ++- front_end/panels/ai_chat/ui/AIChatPanel.ts | 4 +- front_end/panels/ai_chat/ui/SettingsDialog.ts | 6 +- 32 files changed, 2826 insertions(+), 2122 deletions(-) create mode 100644 front_end/panels/ai_chat/LLM/LLMClient.ts create mode 100644 front_end/panels/ai_chat/LLM/LLMErrorHandler.ts create mode 100644 front_end/panels/ai_chat/LLM/LLMProvider.ts create mode 100644 front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts create mode 100644 front_end/panels/ai_chat/LLM/LLMResponseParser.ts create mode 100644 front_end/panels/ai_chat/LLM/LLMTypes.ts create mode 100644 front_end/panels/ai_chat/LLM/LiteLLMProvider.ts create mode 100644 front_end/panels/ai_chat/LLM/OpenAIProvider.ts delete mode 100644 front_end/panels/ai_chat/core/ChatLiteLLM.ts delete mode 100644 front_end/panels/ai_chat/core/ChatOpenAI.ts delete mode 100644 front_end/panels/ai_chat/core/LiteLLMClient.ts delete mode 100644 front_end/panels/ai_chat/core/OpenAIClient.ts delete mode 100644 front_end/panels/ai_chat/core/UnifiedLLMClient.ts diff --git a/.gitignore b/.gitignore index dc3f801d0b7..3dcfc570877 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,4 @@ test/perf/.generated # Dependencies node_modules/ +*/.idea/ \ No newline at end of file diff --git a/config/gni/devtools_grd_files.gni b/config/gni/devtools_grd_files.gni index d2793a46ca6..df946b4aba5 100644 --- a/config/gni/devtools_grd_files.gni +++ b/config/gni/devtools_grd_files.gni @@ -604,6 +604,7 @@ grd_files_release_sources = [ "front_end/panels/ai_chat/ui/HelpDialog.js", "front_end/panels/ai_chat/ui/PromptEditDialog.js", "front_end/panels/ai_chat/ui/SettingsDialog.js", + "front_end/panels/ai_chat/ui/EvaluationDialog.js", "front_end/panels/ai_chat/core/AgentService.js", "front_end/panels/ai_chat/core/State.js", 
"front_end/panels/ai_chat/core/Graph.js", @@ -611,16 +612,20 @@ grd_files_release_sources = [ "front_end/panels/ai_chat/core/Constants.js", "front_end/panels/ai_chat/core/ConfigurableGraph.js", "front_end/panels/ai_chat/core/GraphConfigs.js", - "front_end/panels/ai_chat/core/OpenAIClient.js", - "front_end/panels/ai_chat/core/LiteLLMClient.js", - "front_end/panels/ai_chat/core/UnifiedLLMClient.js", "front_end/panels/ai_chat/core/BaseOrchestratorAgent.js", "front_end/panels/ai_chat/core/PageInfoManager.js", "front_end/panels/ai_chat/core/AgentNodes.js", - "front_end/panels/ai_chat/core/ChatOpenAI.js", - "front_end/panels/ai_chat/core/ChatLiteLLM.js", "front_end/panels/ai_chat/core/GraphHelpers.js", "front_end/panels/ai_chat/core/StateGraph.js", + "front_end/panels/ai_chat/core/Logger.js", + "front_end/panels/ai_chat/LLM/LLMTypes.js", + "front_end/panels/ai_chat/LLM/LLMProvider.js", + "front_end/panels/ai_chat/LLM/LLMProviderRegistry.js", + "front_end/panels/ai_chat/LLM/LLMErrorHandler.js", + "front_end/panels/ai_chat/LLM/LLMResponseParser.js", + "front_end/panels/ai_chat/LLM/OpenAIProvider.js", + "front_end/panels/ai_chat/LLM/LiteLLMProvider.js", + "front_end/panels/ai_chat/LLM/LLMClient.js", "front_end/panels/ai_chat/tools/Tools.js", "front_end/panels/ai_chat/tools/CombinedExtractionTool.js", "front_end/panels/ai_chat/tools/CritiqueTool.js", @@ -628,6 +633,7 @@ grd_files_release_sources = [ "front_end/panels/ai_chat/tools/FinalizeWithCritiqueTool.js", "front_end/panels/ai_chat/tools/HTMLToMarkdownTool.js", "front_end/panels/ai_chat/tools/SchemaBasedExtractorTool.js", + "front_end/panels/ai_chat/tools/StreamlinedSchemaExtractorTool.js", "front_end/panels/ai_chat/tools/VisitHistoryManager.js", "front_end/panels/ai_chat/tools/FullPageAccessibilityTreeToMarkdownTool.js", "front_end/panels/ai_chat/common/utils.js", diff --git a/front_end/panels/ai_chat/BUILD.gn b/front_end/panels/ai_chat/BUILD.gn index 43599fed10f..53e7b8af75b 100644 --- a/front_end/panels/ai_chat/BUILD.gn +++ b/front_end/panels/ai_chat/BUILD.gn @@ -29,18 +29,21 @@ devtools_module("ai_chat") { "core/AgentService.ts", "core/Constants.ts", "core/GraphConfigs.ts", - "core/OpenAIClient.ts", - "core/LiteLLMClient.ts", - "core/UnifiedLLMClient.ts", "core/ConfigurableGraph.ts", "core/BaseOrchestratorAgent.ts", "core/PageInfoManager.ts", "core/AgentNodes.ts", - "core/ChatOpenAI.ts", - "core/ChatLiteLLM.ts", "core/GraphHelpers.ts", "core/StateGraph.ts", "core/Logger.ts", + "LLM/LLMTypes.ts", + "LLM/LLMProvider.ts", + "LLM/LLMProviderRegistry.ts", + "LLM/LLMErrorHandler.ts", + "LLM/LLMResponseParser.ts", + "LLM/OpenAIProvider.ts", + "LLM/LiteLLMProvider.ts", + "LLM/LLMClient.ts", "tools/Tools.ts", "tools/CritiqueTool.ts", "tools/FetcherTool.ts", @@ -107,17 +110,21 @@ _ai_chat_sources = [ "core/AgentService.ts", "core/Constants.ts", "core/GraphConfigs.ts", - "core/OpenAIClient.ts", - "core/LiteLLMClient.ts", - "core/UnifiedLLMClient.ts", "core/ConfigurableGraph.ts", "core/BaseOrchestratorAgent.ts", "core/PageInfoManager.ts", "core/AgentNodes.ts", - "core/ChatOpenAI.ts", - "core/ChatLiteLLM.ts", "core/GraphHelpers.ts", "core/StateGraph.ts", + "core/Logger.ts", + "LLM/LLMTypes.ts", + "LLM/LLMProvider.ts", + "LLM/LLMProviderRegistry.ts", + "LLM/LLMErrorHandler.ts", + "LLM/LLMResponseParser.ts", + "LLM/OpenAIProvider.ts", + "LLM/LiteLLMProvider.ts", + "LLM/LLMClient.ts", "tools/Tools.ts", "tools/CritiqueTool.ts", "tools/FetcherTool.ts", @@ -125,6 +132,7 @@ _ai_chat_sources = [ "tools/VisitHistoryManager.ts", "tools/HTMLToMarkdownTool.ts", 
"tools/SchemaBasedExtractorTool.ts", + "tools/StreamlinedSchemaExtractorTool.ts", "tools/CombinedExtractionTool.ts", "tools/FullPageAccessibilityTreeToMarkdownTool.ts", "agent_framework/ConfigurableAgentTool.ts", diff --git a/front_end/panels/ai_chat/LLM/LLMClient.ts b/front_end/panels/ai_chat/LLM/LLMClient.ts new file mode 100644 index 00000000000..71468993615 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMClient.ts @@ -0,0 +1,300 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { LLMMessage, LLMResponse, LLMCallOptions, LLMProvider, ModelInfo } from './LLMTypes.js'; +import { LLMProviderRegistry } from './LLMProviderRegistry.js'; +import { OpenAIProvider } from './OpenAIProvider.js'; +import { LiteLLMProvider } from './LiteLLMProvider.js'; +import { LLMResponseParser } from './LLMResponseParser.js'; +import { createLogger } from '../core/Logger.js'; + +const logger = createLogger('LLMClient'); + +/** + * Configuration for individual LLM providers + */ +export interface LLMProviderConfig { + provider: LLMProvider; + apiKey: string; + providerURL?: string; // Optional: for LiteLLM endpoint or custom OpenAI endpoint +} + +/** + * Configuration for the LLM client + */ +export interface LLMClientConfig { + providers: LLMProviderConfig[]; +} + +/** + * Request structure for LLM calls + */ +export interface LLMCallRequest { + provider: LLMProvider; + model: string; + messages: LLMMessage[]; + systemPrompt: string; + tools?: any[]; + temperature?: number; +} + +/** + * Main LLM client coordinator that provides a unified interface for agents + * Replaces UnifiedLLMClient with cleaner architecture + */ +export class LLMClient { + private static instance: LLMClient | null = null; + private initialized = false; + + private constructor() {} + + /** + * Get the singleton instance + */ + static getInstance(): LLMClient { + if (!LLMClient.instance) { + LLMClient.instance = new LLMClient(); + } + return LLMClient.instance; + } + + /** + * Initialize the LLM client with provider configurations + */ + async initialize(config: LLMClientConfig): Promise { + logger.info('Initializing LLM client with providers:', config.providers.map(p => p.provider)); + + // Clear existing providers + LLMProviderRegistry.clear(); + + // Register providers based on configuration + for (const providerConfig of config.providers) { + try { + let providerInstance; + + switch (providerConfig.provider) { + case 'openai': + providerInstance = new OpenAIProvider(providerConfig.apiKey); + break; + case 'litellm': + providerInstance = new LiteLLMProvider( + providerConfig.apiKey, + providerConfig.providerURL + ); + break; + default: + logger.warn(`Unknown provider type: ${providerConfig.provider}`); + continue; + } + + LLMProviderRegistry.registerProvider(providerConfig.provider, providerInstance); + logger.info(`Registered ${providerConfig.provider} provider`); + } catch (error) { + logger.error(`Failed to initialize ${providerConfig.provider} provider:`, error); + } + } + + this.initialized = true; + logger.info('LLM client initialization complete'); + } + + /** + * Check if the client is initialized + */ + private ensureInitialized(): void { + if (!this.initialized) { + throw new Error('LLMClient must be initialized before use. 
Call initialize() first.'); + } + } + + /** + * Main method for LLM calls with request object + */ + async call(request: LLMCallRequest): Promise { + this.ensureInitialized(); + + const provider = LLMProviderRegistry.getProvider(request.provider); + + if (!provider) { + throw new Error(`Provider ${request.provider} not available. Available providers: ${LLMProviderRegistry.getRegisteredProviders().join(', ')}`); + } + + logger.debug(`Using ${request.provider} provider for model ${request.model}`); + + // Build messages array with required system prompt + let messages = [...request.messages]; + + // Add system prompt - always required + const hasSystemMessage = messages.some(msg => msg.role === 'system'); + if (!hasSystemMessage) { + messages.unshift({ + role: 'system', + content: request.systemPrompt + }); + } + + // Build options + const options: LLMCallOptions = {}; + if (request.temperature !== undefined) { + options.temperature = request.temperature; + } + if (request.tools) { + options.tools = request.tools; + } + + return provider.callWithMessages(request.model, messages, options); + } + + + /** + * Parse response into standardized action structure + */ + parseResponse(response: LLMResponse): ReturnType { + return LLMResponseParser.parseResponse(response); + } + + /** + * Get all available models from all providers + */ + async getAvailableModels(): Promise { + this.ensureInitialized(); + return LLMProviderRegistry.getAllModels(); + } + + /** + * Get models for a specific provider + */ + async getModelsByProvider(provider: LLMProvider): Promise { + this.ensureInitialized(); + return LLMProviderRegistry.getModelsByProvider(provider); + } + + /** + * Test connection to a specific model + */ + async testConnection(provider: LLMProvider, modelId: string): Promise<{success: boolean, message: string}> { + this.ensureInitialized(); + + const providerInstance = LLMProviderRegistry.getProvider(provider); + + if (!providerInstance) { + return { + success: false, + message: `Provider ${provider} not available` + }; + } + + if (providerInstance.testConnection) { + return providerInstance.testConnection(modelId); + } + + // Fallback test: simple call + try { + const response = await this.call({ + provider, + model: modelId, + messages: [{ role: 'user', content: 'Please respond with "OK" to test the connection.' }], + systemPrompt: 'You are a helpful assistant for testing purposes.', + temperature: 0.1 + }); + + return { + success: true, + message: `Connected successfully. Response: ${response.text || 'No text response'}` + }; + } catch (error) { + return { + success: false, + message: error instanceof Error ? 
error.message : 'Unknown error occurred' + }; + } + } + + /** + * Refresh models for a specific provider or all providers + */ + async refreshProviderModels(provider?: LLMProvider): Promise { + this.ensureInitialized(); + + if (provider) { + const providerInstance = LLMProviderRegistry.getProvider(provider); + if (providerInstance) { + try { + await providerInstance.getModels(); + logger.info(`Refreshed models for ${provider} provider`); + } catch (error) { + logger.error(`Failed to refresh models for ${provider}:`, error); + } + } + } else { + // Refresh all providers + const providers = LLMProviderRegistry.getRegisteredProviders(); + for (const providerType of providers) { + await this.refreshProviderModels(providerType); + } + } + } + + /** + * Register a custom model with the LiteLLM provider + */ + registerCustomModel(modelId: string, name?: string): ModelInfo { + const modelInfo: ModelInfo = { + id: modelId, + name: name || modelId, + provider: 'litellm', + capabilities: { + functionCalling: true, + reasoning: false, + vision: false, + structured: true + } + }; + + // Save to localStorage for LiteLLM provider to pick up + try { + const existingModels = JSON.parse(localStorage.getItem('ai_chat_custom_models') || '[]'); + const updatedModels = [...existingModels, modelInfo]; + localStorage.setItem('ai_chat_custom_models', JSON.stringify(updatedModels)); + logger.info(`Registered custom model: ${modelId}`); + } catch (error) { + logger.error('Failed to save custom model to localStorage:', error); + } + + return modelInfo; + } + + /** + * Get registry statistics + */ + getStats(): { + initialized: boolean; + providersCount: number; + providers: LLMProvider[]; + } { + const registryStats = LLMProviderRegistry.getStats(); + return { + initialized: this.initialized, + ...registryStats + }; + } + + /** + * Static method to fetch models from LiteLLM endpoint (for UI use without initialization) + */ + static async fetchLiteLLMModels(apiKey: string | null, baseUrl?: string): Promise { + const provider = new LiteLLMProvider(apiKey, baseUrl); + const models = await provider.fetchModels(); + return models; + } + + /** + * Static method to test LiteLLM connection (for UI use without initialization) + */ + static async testLiteLLMConnection(apiKey: string | null, modelName: string, baseUrl?: string): Promise<{success: boolean, message: string}> { + const provider = new LiteLLMProvider(apiKey, baseUrl); + return provider.testConnection(modelName); + } + +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LLMErrorHandler.ts b/front_end/panels/ai_chat/LLM/LLMErrorHandler.ts new file mode 100644 index 00000000000..f384d9b0205 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMErrorHandler.ts @@ -0,0 +1,345 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
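+/*
+ * Example (illustrative sketch; the wrapped call and model name are placeholders):
+ *
+ *   const result = await LLMRetryManager.simpleRetry(
+ *     () => provider.call('some-model', prompt, systemPrompt),
+ *     { maxRetries: 3 },
+ *   );
+ *
+ * Rate-limit and network errors are retried with backoff and jitter; auth and
+ * quota errors are surfaced immediately (see LLMErrorClassifier.shouldRetry).
+ */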
+ +import { createLogger } from '../core/Logger.js'; +import type { + LLMErrorType, + RetryConfig, + ErrorRetryConfig, + ExtendedRetryConfig, + RetryCallback +} from './LLMTypes.js'; +import { LLMErrorType as ErrorType } from './LLMTypes.js'; + +const logger = createLogger('LLMErrorHandler'); + +/** + * Default retry configuration for all error types + */ +const DEFAULT_RETRY_CONFIG: RetryConfig = { + maxRetries: 2, + baseDelayMs: 1000, + maxDelayMs: 10000, + backoffMultiplier: 2, + jitterMs: 500, +}; + +/** + * Error-specific retry configurations (only for specific error types) + * All other error types will use DEFAULT_RETRY_CONFIG + */ +const ERROR_SPECIFIC_RETRY_CONFIGS: ErrorRetryConfig = { + [ErrorType.RATE_LIMIT_ERROR]: { + maxRetries: 3, + baseDelayMs: 60000, // 60 seconds for rate limits + maxDelayMs: 300000, // Max 5 minutes + backoffMultiplier: 1, // No exponential backoff for rate limits + jitterMs: 5000, // Small jitter to avoid thundering herd + }, + + [ErrorType.NETWORK_ERROR]: { + maxRetries: 3, + baseDelayMs: 2000, + maxDelayMs: 30000, + backoffMultiplier: 2, + jitterMs: 1000, + }, +}; + +/** + * Utility class for classifying errors that occur during LLM calls + */ +export class LLMErrorClassifier { + /** + * Classify an error based on its message and properties + */ + static classifyError(error: Error): LLMErrorType { + const message = error.message.toLowerCase(); + + // JSON parsing errors + if (message.includes('json parsing failed') || + message.includes('invalid json') || + message.includes('json parse') || + message.includes('unexpected token') || + message.includes('syntaxerror')) { + return ErrorType.JSON_PARSE_ERROR; + } + + // Rate limit detection + if (message.includes('rate limit') || + message.includes('too many requests') || + message.includes('quota exceeded') || + message.includes('429') || + message.includes('rate_limit_exceeded')) { + return ErrorType.RATE_LIMIT_ERROR; + } + + // Network errors + if (message.includes('fetch') || + message.includes('network') || + message.includes('connection') || + message.includes('timeout') || + message.includes('econnreset') || + message.includes('enotfound') || + message.includes('aborted') || + message.includes('socket')) { + return ErrorType.NETWORK_ERROR; + } + + // Server errors (5xx) + if (message.includes('internal server error') || + message.includes('502') || + message.includes('503') || + message.includes('504') || + message.includes('500') || + message.includes('server error') || + message.includes('service unavailable')) { + return ErrorType.SERVER_ERROR; + } + + // Authentication errors + if (message.includes('unauthorized') || + message.includes('invalid api key') || + message.includes('authentication') || + message.includes('401') || + message.includes('forbidden') || + message.includes('403')) { + return ErrorType.AUTH_ERROR; + } + + // Quota/billing errors + if (message.includes('insufficient quota') || + message.includes('billing') || + message.includes('usage limit') || + message.includes('quota_exceeded') || + message.includes('insufficient_quota')) { + return ErrorType.QUOTA_ERROR; + } + + return ErrorType.UNKNOWN_ERROR; + } + + /** + * Check if an error type should be retried + */ + static shouldRetry(errorType: LLMErrorType): boolean { + // Auth and quota errors should never be retried + return errorType !== ErrorType.AUTH_ERROR && errorType !== ErrorType.QUOTA_ERROR; + } + + /** + * Get the retry configuration for a specific error type + */ + static getRetryConfig(errorType: LLMErrorType, 
customConfig?: Partial): RetryConfig { + // Start with default config + let config = { ...DEFAULT_RETRY_CONFIG }; + + // Apply error-specific config if available + const errorSpecificConfig = ERROR_SPECIFIC_RETRY_CONFIGS[errorType]; + if (errorSpecificConfig) { + config = { ...config, ...errorSpecificConfig }; + } + + // Apply custom overrides + if (customConfig) { + config = { ...config, ...customConfig }; + } + + return config; + } +} + +/** + * Manages retry logic for LLM operations with exponential backoff and jitter + */ +export class LLMRetryManager { + private config: ExtendedRetryConfig; + private onRetry?: RetryCallback; + + constructor(config: ExtendedRetryConfig = {}) { + this.config = { + defaultConfig: DEFAULT_RETRY_CONFIG, + enableLogging: true, + ...config, + }; + this.onRetry = config.onRetry; + } + + /** + * Execute an operation with retry logic + */ + async executeWithRetry( + operation: () => Promise, + options: { + customRetryConfig?: Partial; + context?: string; + } = {} + ): Promise { + const startTime = Date.now(); + let lastError: Error; + let attempt = 1; + + while (true) { + try { + const result = await operation(); + + if (attempt > 1 && this.config.enableLogging) { + logger.info(`Operation succeeded on attempt ${attempt}`, { + context: options.context, + totalTime: Date.now() - startTime, + }); + } + + return result; + } catch (error) { + lastError = error instanceof Error ? error : new Error(String(error)); + const errorType = LLMErrorClassifier.classifyError(lastError); + + if (this.config.enableLogging) { + logger.error(`Operation failed on attempt ${attempt}:`, { + error: lastError.message, + errorType, + context: options.context, + }); + } + + // Check if we should retry this error type + if (!LLMErrorClassifier.shouldRetry(errorType)) { + if (this.config.enableLogging) { + logger.info(`Not retrying ${errorType} error`); + } + throw lastError; + } + + // Get retry configuration + const retryConfig = LLMErrorClassifier.getRetryConfig(errorType, options.customRetryConfig); + + // Check if we've exceeded max retries + if (attempt > retryConfig.maxRetries) { + if (this.config.enableLogging) { + logger.error(`Max retries (${retryConfig.maxRetries}) exceeded for ${errorType}`); + } + throw lastError; + } + + // Check total time limit + if (this.config.maxTotalTimeMs && (Date.now() - startTime) >= this.config.maxTotalTimeMs) { + if (this.config.enableLogging) { + logger.error(`Total retry time limit (${this.config.maxTotalTimeMs}ms) exceeded`); + } + throw lastError; + } + + // Calculate delay and wait + const delayMs = this.calculateDelay(retryConfig, attempt); + + if (this.config.enableLogging) { + logger.warn(`Retrying after ${delayMs}ms (attempt ${attempt + 1}/${retryConfig.maxRetries + 1}) for ${errorType}`); + } + + // Call retry callback if provided + if (this.onRetry) { + this.onRetry(attempt, lastError, errorType, delayMs); + } + + if (delayMs > 0) { + await this.sleep(delayMs); + } + + attempt++; + } + } + } + + /** + * Calculate retry delay with exponential backoff and jitter + */ + private calculateDelay(config: RetryConfig, attempt: number): number { + const baseDelay = config.baseDelayMs; + const multiplier = config.backoffMultiplier; + const maxDelay = config.maxDelayMs; + const jitter = config.jitterMs; + + // Calculate exponential backoff + const exponentialDelay = baseDelay * Math.pow(multiplier, attempt - 1); + + // Apply max delay cap + const cappedDelay = Math.min(exponentialDelay, maxDelay); + + // Add random jitter to avoid thundering herd 
problem + const randomJitter = jitter > 0 ? Math.random() * jitter : 0; + + return Math.max(0, cappedDelay + randomJitter); + } + + /** + * Sleep for specified milliseconds + */ + private async sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); + } + + /** + * Static convenience method for simple retry scenarios + */ + static async simpleRetry( + operation: () => Promise, + customConfig?: Partial + ): Promise { + const manager = new LLMRetryManager(); + return manager.executeWithRetry(operation, { customRetryConfig: customConfig }); + } +} + +/** + * Static utility functions for common error handling scenarios + */ +export class LLMErrorUtils { + /** + * Check if an error is retryable + */ + static isRetryable(error: Error): boolean { + const errorType = LLMErrorClassifier.classifyError(error); + return LLMErrorClassifier.shouldRetry(errorType); + } + + /** + * Get human-readable error message + */ + static getErrorMessage(error: Error): string { + const errorType = LLMErrorClassifier.classifyError(error); + + switch (errorType) { + case ErrorType.RATE_LIMIT_ERROR: + return 'Rate limit exceeded. Please wait before trying again.'; + case ErrorType.NETWORK_ERROR: + return 'Network connection error. Please check your internet connection.'; + case ErrorType.AUTH_ERROR: + return 'Authentication failed. Please check your API key.'; + case ErrorType.QUOTA_ERROR: + return 'API quota exceeded. Please check your usage limits.'; + case ErrorType.SERVER_ERROR: + return 'Server error. The service may be temporarily unavailable.'; + case ErrorType.JSON_PARSE_ERROR: + return 'Failed to parse response. The AI response was not in the expected format.'; + default: + return error.message || 'An unknown error occurred.'; + } + } + + /** + * Create enhanced error with additional context + */ + static enhanceError(error: Error, context: { operation?: string; attempt?: number }): Error { + const errorType = LLMErrorClassifier.classifyError(error); + const enhancedMessage = `${context.operation || 'LLM operation'} failed with ${errorType}${context.attempt ? ` (attempt ${context.attempt})` : ''}: ${error.message}`; + + const enhancedError = new Error(enhancedMessage); + (enhancedError as any).originalError = error; + (enhancedError as any).errorType = errorType; + (enhancedError as any).context = context; + + return enhancedError; + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LLMProvider.ts b/front_end/panels/ai_chat/LLM/LLMProvider.ts new file mode 100644 index 00000000000..6235d8f5f58 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMProvider.ts @@ -0,0 +1,94 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
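+/*
+ * Concrete providers extend LLMBaseProvider. A minimal sketch (names are
+ * illustrative; see OpenAIProvider and LiteLLMProvider for the real
+ * implementations registered by LLMClient):
+ *
+ *   class ExampleProvider extends LLMBaseProvider {
+ *     readonly name = 'litellm' as const;
+ *     async callWithMessages(model, messages, options) { ... }
+ *     async call(model, prompt, systemPrompt, options) { ... }
+ *     async getModels() { return []; }
+ *     parseResponse(response) { ... }
+ *   }
+ */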
+ +import type { LLMMessage, LLMResponse, LLMCallOptions, LLMProvider as LLMProviderType, ModelInfo } from './LLMTypes.js'; + +/** + * Base interface that all LLM providers must implement + */ +export interface LLMProviderInterface { + /** Provider name/type */ + readonly name: LLMProviderType; + + /** + * Execute a chat completion request with messages + */ + callWithMessages( + modelName: string, + messages: LLMMessage[], + options?: LLMCallOptions + ): Promise; + + /** + * Simple call method for backward compatibility + */ + call( + modelName: string, + prompt: string, + systemPrompt: string, + options?: LLMCallOptions + ): Promise; + + /** + * Get all models supported by this provider + */ + getModels(): Promise; + + /** + * Parse response into standardized action structure + */ + parseResponse(response: LLMResponse): any; + + /** + * Test connection to a specific model (optional) + */ + testConnection?(modelId: string): Promise<{success: boolean, message: string}>; +} + +/** + * Abstract base class providing common functionality for providers + */ +export abstract class LLMBaseProvider implements LLMProviderInterface { + abstract readonly name: LLMProviderType; + + constructor(protected config: any = {}) {} + + abstract callWithMessages( + modelName: string, + messages: LLMMessage[], + options?: LLMCallOptions + ): Promise; + + abstract call( + modelName: string, + prompt: string, + systemPrompt: string, + options?: LLMCallOptions + ): Promise; + + abstract getModels(): Promise; + + abstract parseResponse(response: LLMResponse): any; + + /** + * Helper method to handle provider-specific errors + */ + protected handleProviderError(error: any, context: string): Error { + if (error instanceof Error) { + return error; + } + + // Handle fetch errors + if (error.name === 'TypeError' && error.message.includes('fetch')) { + return new Error(`Network error in ${context}: ${error.message}`); + } + + // Handle HTTP errors + if (error.status) { + return new Error(`HTTP ${error.status} error in ${context}: ${error.message || 'Unknown error'}`); + } + + return new Error(`Unknown error in ${context}: ${String(error)}`); + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts b/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts new file mode 100644 index 00000000000..89d0f441067 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMProviderRegistry.ts @@ -0,0 +1,106 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
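+/*
+ * Example (illustrative; in practice LLMClient.initialize performs the
+ * registration, and apiKey is a placeholder):
+ *
+ *   LLMProviderRegistry.registerProvider('openai', new OpenAIProvider(apiKey));
+ *   const provider = LLMProviderRegistry.getProvider('openai');
+ *   const models = await LLMProviderRegistry.getAllModels();
+ */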
+ +import { createLogger } from '../core/Logger.js'; +import type { LLMProviderInterface } from './LLMProvider.js'; +import type { LLMProvider, ModelInfo } from './LLMTypes.js'; + +const logger = createLogger('LLMProviderRegistry'); + +/** + * Registry for managing LLM providers with distributed model ownership + */ +export class LLMProviderRegistry { + private static providers = new Map(); + + /** + * Register a provider instance + */ + static registerProvider(providerType: LLMProvider, providerInstance: LLMProviderInterface): void { + logger.info(`Registering provider: ${providerType}`); + this.providers.set(providerType, providerInstance); + } + + /** + * Get a provider by type + */ + static getProvider(providerType: LLMProvider): LLMProviderInterface | undefined { + return this.providers.get(providerType); + } + + /** + * Check if a provider is registered + */ + static hasProvider(providerType: LLMProvider): boolean { + return this.providers.has(providerType); + } + + /** + * Get all models from all registered providers + */ + static async getAllModels(): Promise { + const allModels: ModelInfo[] = []; + + for (const [providerType, provider] of this.providers.entries()) { + try { + const providerModels = await provider.getModels(); + allModels.push(...providerModels); + logger.debug(`Got ${providerModels.length} models from ${providerType}`); + } catch (error) { + logger.warn(`Failed to get models from ${providerType}:`, error); + } + } + + logger.info(`Total models available: ${allModels.length}`); + return allModels; + } + + /** + * Get models for a specific provider + */ + static async getModelsByProvider(providerType: LLMProvider): Promise { + const provider = this.getProvider(providerType); + if (!provider) { + logger.warn(`Provider ${providerType} not registered`); + return []; + } + + try { + const models = await provider.getModels(); + logger.debug(`Got ${models.length} models from ${providerType}`); + return models; + } catch (error) { + logger.error(`Failed to get models from ${providerType}:`, error); + return []; + } + } + + /** + * Get all registered provider types + */ + static getRegisteredProviders(): LLMProvider[] { + return Array.from(this.providers.keys()); + } + + /** + * Clear all registrations (useful for testing) + */ + static clear(): void { + this.providers.clear(); + logger.info('LLM Provider Registry cleared'); + } + + /** + * Get registry statistics + */ + static getStats(): { + providersCount: number; + providers: LLMProvider[]; + } { + return { + providersCount: this.providers.size, + providers: Array.from(this.providers.keys()), + }; + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LLMResponseParser.ts b/front_end/panels/ai_chat/LLM/LLMResponseParser.ts new file mode 100644 index 00000000000..08f49924511 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMResponseParser.ts @@ -0,0 +1,306 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
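+/*
+ * Example (illustrative sketch over a provider response):
+ *
+ *   const action = LLMResponseParser.parseResponse(response);
+ *   if (action.type === 'tool_call') {
+ *     // dispatch action.name with action.args
+ *   } else if (action.type === 'final_answer') {
+ *     // render action.answer
+ *   } else {
+ *     // surface action.error
+ *   }
+ */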
+ +import { createLogger } from '../core/Logger.js'; +import type { UnifiedLLMResponse, ParsedLLMAction } from './LLMTypes.js'; + +const logger = createLogger('LLMResponseParser'); + +/** + * Utility class for parsing and processing LLM responses + */ +export class LLMResponseParser { + /** + * Parse strict JSON from LLM response, handling common formatting issues + */ + static parseStrictJSON(text: string): any { + // Trim whitespace + let jsonText = text.trim(); + + // Remove markdown code blocks if present + if (jsonText.startsWith('```json')) { + jsonText = jsonText.replace(/^```json\s*/, '').replace(/\s*```$/, ''); + } else if (jsonText.startsWith('```')) { + jsonText = jsonText.replace(/^```\s*/, '').replace(/\s*```$/, ''); + } + + // Remove any leading/trailing text that's not part of JSON + const jsonMatch = jsonText.match(/\{.*\}/s) || jsonText.match(/\[.*\]/s); + if (jsonMatch) { + jsonText = jsonMatch[0]; + } + + // Try to parse + try { + return JSON.parse(jsonText); + } catch (error) { + // Log the problematic text for debugging + logger.error('Failed to parse JSON after cleanup:', { + original: text, + cleaned: jsonText, + error: error instanceof Error ? error.message : String(error), + }); + throw new Error(`Unable to parse JSON: ${error instanceof Error ? error.message : String(error)}`); + } + } + + /** + * Parse unified response to determine action type + * Equivalent to OpenAIClient.parseOpenAIResponse + */ + static parseResponse(response: UnifiedLLMResponse): ParsedLLMAction { + // Check for function calls first + if (response.functionCall) { + return { + type: 'tool_call', + name: response.functionCall.name, + args: response.functionCall.arguments, + }; + } + + // Process text response + if (response.text) { + const rawContent = response.text; + + // Attempt to parse text as JSON tool call (fallback for some models) + if (rawContent.trim().startsWith('{') && rawContent.includes('"action":"tool"')) { + try { + const contentJson = JSON.parse(rawContent); + if (contentJson.action === 'tool' && contentJson.toolName) { + return { + type: 'tool_call', + name: contentJson.toolName, + args: contentJson.toolArgs || {}, + }; + } + // Fallback to treating it as text if JSON structure is not a valid tool call + return { type: 'final_answer', answer: rawContent }; + } catch (e) { + // If JSON parsing fails, treat it as plain text + return { type: 'final_answer', answer: rawContent }; + } + } else { + // Treat as plain text final answer + return { type: 'final_answer', answer: rawContent }; + } + } + + return { + type: 'error', + error: 'No valid response from LLM', + }; + } + + /** + * Enhanced JSON parsing with multiple fallback strategies + */ + static parseJSONWithFallbacks(text: string): any { + const strategies = [ + // Strategy 1: Direct parsing + () => JSON.parse(text), + + // Strategy 2: Trim and parse + () => JSON.parse(text.trim()), + + // Strategy 3: Remove markdown code blocks + () => { + let cleaned = text.trim(); + if (cleaned.startsWith('```json')) { + cleaned = cleaned.replace(/^```json\s*/, '').replace(/\s*```$/, ''); + } else if (cleaned.startsWith('```')) { + cleaned = cleaned.replace(/^```\s*/, '').replace(/\s*```$/, ''); + } + return JSON.parse(cleaned); + }, + + // Strategy 4: Extract JSON from text + () => { + const jsonMatch = text.match(/\{.*\}/s) || text.match(/\[.*\]/s); + if (jsonMatch) { + return JSON.parse(jsonMatch[0]); + } + throw new Error('No JSON found in text'); + }, + + // Strategy 5: Fix common JSON issues + () => { + let fixed = text.trim(); + 
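+        // Note: these regex repairs are heuristic and can corrupt strings that
+        // legitimately contain single quotes, colons, or trailing commas, which
+        // is why this strategy is attempted last.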
+ // Fix single quotes to double quotes + fixed = fixed.replace(/'/g, '"'); + + // Fix trailing commas + fixed = fixed.replace(/,(\s*[}\]])/g, '$1'); + + // Fix unquoted keys (basic attempt) + fixed = fixed.replace(/(\w+):/g, '"$1":'); + + return JSON.parse(fixed); + }, + ]; + + let lastError: Error | undefined; + + for (let i = 0; i < strategies.length; i++) { + try { + const result = strategies[i](); + if (i > 0) { + logger.warn(`JSON parsed using fallback strategy ${i + 1}`, { + originalText: text.substring(0, 100) + (text.length > 100 ? '...' : ''), + strategy: i + 1, + }); + } + return result; + } catch (error) { + lastError = error instanceof Error ? error : new Error(String(error)); + continue; + } + } + + // All strategies failed + logger.error('All JSON parsing strategies failed:', { + text: text.substring(0, 200) + (text.length > 200 ? '...' : ''), + lastError: lastError?.message, + }); + + throw new Error(`JSON parsing failed: ${lastError?.message || 'Unknown error'}`); + } + + /** + * Validate and clean JSON response for strict mode + */ + static validateStrictJSON(text: string): { isValid: boolean; cleaned?: string; error?: string } { + try { + // Try direct parsing first + JSON.parse(text.trim()); + return { isValid: true, cleaned: text.trim() }; + } catch (directError) { + try { + // Try with fallback strategies + const parsed = this.parseJSONWithFallbacks(text); + const cleaned = JSON.stringify(parsed); + return { isValid: true, cleaned }; + } catch (fallbackError) { + return { + isValid: false, + error: fallbackError instanceof Error ? fallbackError.message : String(fallbackError), + }; + } + } + } + + /** + * Extract structured data from free-form text response + */ + static extractStructuredData(text: string, expectedFields: string[]): Record { + const result: Record = {}; + + // Try JSON parsing first + try { + const parsed = this.parseJSONWithFallbacks(text); + if (typeof parsed === 'object' && parsed !== null) { + return parsed; + } + } catch { + // Fall back to text extraction + } + + // Extract fields using pattern matching + for (const field of expectedFields) { + const patterns = [ + new RegExp(`"${field}"\\s*:\\s*"([^"]*)"`, 'i'), + new RegExp(`${field}\\s*:\\s*"([^"]*)"`, 'i'), + new RegExp(`${field}\\s*:\\s*([^,}\\n]*)`, 'i'), + ]; + + for (const pattern of patterns) { + const match = text.match(pattern); + if (match) { + result[field] = match[1].trim(); + break; + } + } + } + + return result; + } + + /** + * Enhance response with parsed structured data + */ + static enhanceResponse(response: UnifiedLLMResponse, options: { + strictJsonMode?: boolean; + expectedFields?: string[]; + } = {}): UnifiedLLMResponse { + const enhanced = { ...response }; + + if (options.strictJsonMode && response.text) { + try { + enhanced.parsedJson = this.parseStrictJSON(response.text); + } catch (error) { + logger.error('Strict JSON parsing failed:', { + error: error instanceof Error ? error.message : String(error), + responseText: response.text, + }); + // Don't throw here, just log the error + } + } + + if (options.expectedFields && response.text) { + try { + const structuredData = this.extractStructuredData(response.text, options.expectedFields); + if (Object.keys(structuredData).length > 0) { + enhanced.parsedJson = { ...enhanced.parsedJson, ...structuredData }; + } + } catch (error) { + logger.warn('Structured data extraction failed:', { + error: error instanceof Error ? 
error.message : String(error), + }); + } + } + + return enhanced; + } + + /** + * Check if response appears to be valid JSON + */ + static isValidJSON(text: string): boolean { + try { + JSON.parse(text.trim()); + return true; + } catch { + return false; + } + } + + /** + * Get JSON parsing suggestions for failed responses + */ + static getJSONParsingSuggestions(text: string): string[] { + const suggestions: string[] = []; + + if (!text.trim().startsWith('{') && !text.trim().startsWith('[')) { + suggestions.push('Response should start with { or ['); + } + + if (!text.trim().endsWith('}') && !text.trim().endsWith(']')) { + suggestions.push('Response should end with } or ]'); + } + + if (text.includes("'")) { + suggestions.push('Use double quotes (") instead of single quotes (\')'); + } + + if (text.match(/,(\s*[}\]])/)) { + suggestions.push('Remove trailing commas before } or ]'); + } + + if (text.match(/\w+:/)) { + suggestions.push('Ensure all object keys are quoted'); + } + + return suggestions; + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LLMTypes.ts b/front_end/panels/ai_chat/LLM/LLMTypes.ts new file mode 100644 index 00000000000..6763a8b7ef9 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LLMTypes.ts @@ -0,0 +1,228 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** + * Core type definitions and interfaces for the unified LLM client system. + * This file contains all shared types used across the LLM infrastructure. + */ + +/** + * Error types that can occur during LLM calls + */ +export enum LLMErrorType { + JSON_PARSE_ERROR = 'JSON_PARSE_ERROR', + RATE_LIMIT_ERROR = 'RATE_LIMIT_ERROR', + NETWORK_ERROR = 'NETWORK_ERROR', + SERVER_ERROR = 'SERVER_ERROR', + AUTH_ERROR = 'AUTH_ERROR', + QUOTA_ERROR = 'QUOTA_ERROR', + UNKNOWN_ERROR = 'UNKNOWN_ERROR', +} + +/** + * Retry configuration for specific error types + */ +export interface RetryConfig { + maxRetries: number; + baseDelayMs: number; + maxDelayMs: number; + backoffMultiplier: number; + jitterMs: number; +} + +/** + * Unified options for LLM calls that work across different providers + */ +export interface UnifiedLLMOptions { + // Core LLM parameters + systemPrompt: string; // Required - all calls must have context + maxTokens?: number; + temperature?: number; + topP?: number; + frequencyPenalty?: number; + presencePenalty?: number; + responseFormat?: any; + n?: number; + stream?: boolean; + + // Connection and timeout settings + endpoint?: string; + timeout?: number; + signal?: AbortSignal; + + // Tool usage (for function calling) + tools?: any[]; + tool_choice?: any; + + // Feature flags + strictJsonMode?: boolean; // Enables strict JSON parsing with retries + + // Retry configuration override + customRetryConfig?: Partial; + + // Legacy compatibility (deprecated - use customRetryConfig instead) + maxRetries?: number; +} + +/** + * Unified response that includes function calls and parsed data + */ +export interface UnifiedLLMResponse { + text?: string; + functionCall?: { + name: string; + arguments: any; + }; + rawResponse?: any; + reasoning?: { + summary?: string[] | null; + effort?: string; + }; + parsedJson?: any; // Parsed JSON when strictJsonMode is enabled +} + +/** + * Model configuration from localStorage + */ +export interface ModelOption { + value: string; + type: 'openai' | 'litellm'; + label?: string; +} + +/** + * Standardized structure for parsed LLM actions + */ 
+export type ParsedLLMAction = + | { type: 'tool_call'; name: string; args: Record } + | { type: 'final_answer'; answer: string } + | { type: 'error'; error: string }; + +/** + * Configuration for error-specific retry behavior + */ +export interface ErrorRetryConfig { + [LLMErrorType.RATE_LIMIT_ERROR]?: RetryConfig; + [LLMErrorType.NETWORK_ERROR]?: RetryConfig; + [LLMErrorType.JSON_PARSE_ERROR]?: RetryConfig; + [LLMErrorType.SERVER_ERROR]?: RetryConfig; + [LLMErrorType.AUTH_ERROR]?: RetryConfig; + [LLMErrorType.QUOTA_ERROR]?: RetryConfig; + [LLMErrorType.UNKNOWN_ERROR]?: RetryConfig; +} + +/** + * Callback for retry events (useful for logging and monitoring) + */ +export interface RetryCallback { + (attempt: number, error: Error, errorType: LLMErrorType, delayMs: number): void; +} + +/** + * Extended retry configuration with callbacks and custom settings + */ +export interface ExtendedRetryConfig extends ErrorRetryConfig { + // Default configuration for unspecified error types + defaultConfig?: RetryConfig; + + // Global callback for all retry events + onRetry?: RetryCallback; + + // Maximum total time to spend on retries (across all attempts) + maxTotalTimeMs?: number; + + // Whether to enable retry logging + enableLogging?: boolean; +} + +/** + * LLM Provider types + */ +export type LLMProvider = 'openai' | 'litellm'; + +/** + * Content types for multimodal messages (text + images) + */ +export type MessageContent = + | string + | Array; + +export interface TextContent { + type: 'text'; + text: string; +} + +export interface ImageContent { + type: 'image_url'; + image_url: { + url: string; // Can be URL or base64 data URL + detail?: 'low' | 'high' | 'auto'; + }; +} + +/** + * Message format compatible with OpenAI and LiteLLM APIs + * Supports both text-only and multimodal (text + images) content + */ +export interface LLMMessage { + role: 'system' | 'user' | 'assistant' | 'tool'; + content?: MessageContent; + tool_calls?: Array<{ + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; + }>; + tool_call_id?: string; + name?: string; +} + +/** + * Options for LLM calls + */ +export interface LLMCallOptions { + tools?: any[]; + tool_choice?: any; + temperature?: number; + reasoningLevel?: 'low' | 'medium' | 'high'; // For O-series models + retryConfig?: Partial; +} + +/** + * Unified LLM response format + */ +export interface LLMResponse { + text?: string; + functionCall?: { + name: string; + arguments: any; + }; + rawResponse: any; + reasoning?: { + summary?: string[] | null; + effort?: string; + }; +} + +/** + * Model capabilities + */ +export interface ModelCapabilities { + functionCalling: boolean; + reasoning: boolean; + vision: boolean; + structured: boolean; +} + +/** + * Model information with provider and capabilities + */ +export interface ModelInfo { + id: string; + name: string; + provider: LLMProvider; + capabilities?: ModelCapabilities; +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/LiteLLMProvider.ts b/front_end/panels/ai_chat/LLM/LiteLLMProvider.ts new file mode 100644 index 00000000000..a86fc1246e4 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/LiteLLMProvider.ts @@ -0,0 +1,377 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
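+/*
+ * Example (illustrative; the API key, endpoint, and model name are placeholders):
+ *
+ *   const provider = new LiteLLMProvider(apiKey, 'http://localhost:4000');
+ *   const response = await provider.call('my-model', 'Hello', 'You are a helpful assistant.');
+ *   const models = await provider.fetchModels();
+ */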
+ +import type { LLMMessage, LLMResponse, LLMCallOptions, LLMProvider, ModelInfo } from './LLMTypes.js'; +import { LLMBaseProvider } from './LLMProvider.js'; +import { LLMRetryManager } from './LLMErrorHandler.js'; +import { LLMResponseParser } from './LLMResponseParser.js'; +import { createLogger } from '../core/Logger.js'; + +const logger = createLogger('LiteLLMProvider'); + +/** + * LiteLLM model information from /v1/models endpoint + */ +export interface LiteLLMModel { + id: string; + object: string; + created?: number; + owned_by?: string; +} + +export interface LiteLLMModelsResponse { + object: string; + data: LiteLLMModel[]; +} + +/** + * LiteLLM provider implementation using OpenAI-compatible format + */ +export class LiteLLMProvider extends LLMBaseProvider { + private static readonly DEFAULT_BASE_URL = 'http://localhost:4000'; + private static readonly CHAT_COMPLETIONS_PATH = '/v1/chat/completions'; + private static readonly MODELS_PATH = '/v1/models'; + + readonly name: LLMProvider = 'litellm'; + + constructor( + private readonly apiKey: string | null, + private readonly baseUrl?: string + ) { + super(); + } + + /** + * Constructs the full endpoint URL based on configuration + */ + private getEndpoint(): string { + // Check if we have a valid baseUrl + if (!this.baseUrl) { + // Check localStorage as a fallback for endpoint + const localStorageEndpoint = localStorage.getItem('ai_chat_litellm_endpoint'); + if (!localStorageEndpoint) { + throw new Error('LiteLLM endpoint not configured. Please set endpoint in settings.'); + } + logger.debug(`Using endpoint from localStorage: ${localStorageEndpoint}`); + const baseUrl = localStorageEndpoint.replace(/\/$/, ''); + return `${baseUrl}${LiteLLMProvider.CHAT_COMPLETIONS_PATH}`; + } + + // Remove trailing slash from base URL if present + const cleanBaseUrl = this.baseUrl.replace(/\/$/, ''); + return `${cleanBaseUrl}${LiteLLMProvider.CHAT_COMPLETIONS_PATH}`; + } + + /** + * Gets the models endpoint URL + */ + private getModelsEndpoint(): string { + const baseEndpoint = this.baseUrl || LiteLLMProvider.DEFAULT_BASE_URL; + return `${baseEndpoint.replace(/\/$/, '')}${LiteLLMProvider.MODELS_PATH}`; + } + + /** + * Converts LLMMessage format to OpenAI format + */ + private convertMessagesToOpenAI(messages: LLMMessage[]): any[] { + return messages.map(msg => ({ + role: msg.role, + content: msg.content, + ...(msg.tool_calls && { tool_calls: msg.tool_calls }), + ...(msg.tool_call_id && { tool_call_id: msg.tool_call_id }), + ...(msg.name && { name: msg.name }) + })); + } + + /** + * Makes a request to the LiteLLM API + */ + private async makeAPIRequest(payloadBody: any): Promise { + try { + const endpoint = this.getEndpoint(); + logger.debug('Using endpoint:', endpoint); + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(this.apiKey ? 
{ Authorization: `Bearer ${this.apiKey}` } : {}), + }, + body: JSON.stringify(payloadBody), + }); + + if (!response.ok) { + const errorData = await response.json(); + logger.error('LiteLLM API error:', errorData); + throw new Error(`LiteLLM API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); + } + + const data = await response.json(); + logger.info('LiteLLM Response:', data); + + if (data.usage) { + logger.info('LiteLLM Usage:', { inputTokens: data.usage.prompt_tokens, outputTokens: data.usage.completion_tokens }); + } + + return data; + } catch (error) { + logger.error('LiteLLM API request failed:', error); + throw error; + } + } + + /** + * Processes the LiteLLM response and converts to LLMResponse format + */ + private processLiteLLMResponse(data: any): LLMResponse { + const result: LLMResponse = { + rawResponse: data + }; + + if (!data?.choices || data.choices.length === 0) { + throw new Error('No choices in LiteLLM response'); + } + + const choice = data.choices[0]; + const message = choice.message; + + if (!message) { + throw new Error('No message in LiteLLM choice'); + } + + // Check for tool calls + if (message.tool_calls && message.tool_calls.length > 0) { + const toolCall = message.tool_calls[0]; + if (toolCall.function) { + try { + result.functionCall = { + name: toolCall.function.name, + arguments: JSON.parse(toolCall.function.arguments) + }; + } catch (error) { + logger.error('Error parsing function arguments:', error); + result.functionCall = { + name: toolCall.function.name, + arguments: toolCall.function.arguments // Keep as string if parsing fails + }; + } + } + } else if (message.content) { + // Plain text response + result.text = message.content.trim(); + } + + return result; + } + + /** + * Call the LiteLLM API with messages + */ + async callWithMessages( + modelName: string, + messages: LLMMessage[], + options?: LLMCallOptions + ): Promise { + return LLMRetryManager.simpleRetry(async () => { + logger.debug('Calling LiteLLM with messages...', { model: modelName, messageCount: messages.length }); + + // Construct payload body in OpenAI format (LiteLLM is OpenAI-compatible) + const payloadBody: any = { + model: modelName, + messages: this.convertMessagesToOpenAI(messages), // Direct OpenAI format - no conversion needed! 
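+        // (convertMessagesToOpenAI only drops unset optional fields; roles and
+        // content pass through unchanged.)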
+ }; + + // Add temperature if provided + if (options?.temperature !== undefined) { + payloadBody.temperature = options.temperature; + } + + // Add tools if provided + if (options?.tools) { + // Ensure all tools have valid parameters + payloadBody.tools = options.tools.map(tool => { + if (tool.type === 'function' && tool.function) { + return { + ...tool, + function: { + ...tool.function, + parameters: tool.function.parameters || { type: 'object', properties: {} } + } + }; + } + return tool; + }); + } + + // Add tool_choice if provided + if (options?.tool_choice) { + payloadBody.tool_choice = options.tool_choice; + } + + logger.info('Request payload:', payloadBody); + + const data = await this.makeAPIRequest(payloadBody); + return this.processLiteLLMResponse(data); + }, options?.retryConfig); + } + + /** + * Simple call method for backward compatibility + */ + async call( + modelName: string, + prompt: string, + systemPrompt: string, + options?: LLMCallOptions + ): Promise { + const messages: LLMMessage[] = []; + + if (systemPrompt) { + messages.push({ + role: 'system', + content: systemPrompt + }); + } + + messages.push({ + role: 'user', + content: prompt + }); + + return this.callWithMessages(modelName, messages, options); + } + + /** + * Parse response into standardized action structure + */ + parseResponse(response: LLMResponse): ReturnType { + return LLMResponseParser.parseResponse(response); + } + + /** + * Fetch available models from LiteLLM endpoint + */ + async fetchModels(): Promise { + logger.debug('Fetching available models...'); + + try { + const modelsUrl = this.getModelsEndpoint(); + logger.debug('Using models endpoint:', modelsUrl); + + const response = await fetch(modelsUrl, { + method: 'GET', + headers: { + ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}), + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + logger.error('LiteLLM models API error:', errorData); + throw new Error(`LiteLLM models API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); + } + + const data: LiteLLMModelsResponse = await response.json(); + logger.debug('LiteLLM Models Response:', data); + + if (!data?.data || !Array.isArray(data.data)) { + throw new Error('Invalid models response format'); + } + + return data.data; + } catch (error) { + logger.error('Failed to fetch LiteLLM models:', error); + throw error; + } + } + + /** + * Get all models supported by this provider + */ + async getModels(): Promise { + const models: ModelInfo[] = []; + + try { + // Fetch models from LiteLLM API + const fetchedModels = await this.fetchModels(); + for (const model of fetchedModels) { + models.push({ + id: model.id, + name: model.id, // Use ID as name for LiteLLM models + provider: 'litellm', + capabilities: { + functionCalling: true, + reasoning: false, + vision: false, + structured: true + } + }); + } + } catch (error) { + logger.warn('Failed to fetch models from LiteLLM API:', error); + } + + // Add custom models from localStorage + try { + const customModelsJson = localStorage.getItem('ai_chat_custom_models'); + if (customModelsJson) { + const customModels = JSON.parse(customModelsJson); + if (Array.isArray(customModels)) { + for (const customModel of customModels) { + if (customModel.id && customModel.name) { + models.push({ + id: customModel.id, + name: customModel.name, + provider: 'litellm', + capabilities: { + functionCalling: true, + reasoning: false, + vision: false, + structured: true + } + }); + } + } + } + } + } catch 
(error) { + logger.warn('Failed to load custom models from localStorage:', error); + } + + logger.debug(`LiteLLM Provider returning ${models.length} models`); + return models; + } + + /** + * Test the LiteLLM connection with a simple completion request + */ + async testConnection(modelName: string): Promise<{success: boolean, message: string}> { + logger.debug('Testing connection...'); + + try { + const testPrompt = 'Please respond with "Connection successful!" to confirm the connection is working.'; + + const response = await this.call(modelName, testPrompt, '', { + temperature: 0.1, + }); + + if (response.text?.toLowerCase().includes('connection')) { + return { + success: true, + message: `Successfully connected to LiteLLM with model ${modelName}`, + }; + } + return { + success: true, + message: `Connected to LiteLLM, but received unexpected response: ${response.text || 'No response'}`, + }; + } catch (error) { + logger.error('LiteLLM connection test failed:', error); + return { + success: false, + message: error instanceof Error ? error.message : 'Unknown error occurred', + }; + } + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/LLM/OpenAIProvider.ts b/front_end/panels/ai_chat/LLM/OpenAIProvider.ts new file mode 100644 index 00000000000..9546c34c242 --- /dev/null +++ b/front_end/panels/ai_chat/LLM/OpenAIProvider.ts @@ -0,0 +1,427 @@ +// Copyright 2025 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import type { LLMMessage, LLMResponse, LLMCallOptions, LLMProvider, ModelInfo, MessageContent } from './LLMTypes.js'; +import { LLMBaseProvider } from './LLMProvider.js'; +import { LLMRetryManager } from './LLMErrorHandler.js'; +import { LLMResponseParser } from './LLMResponseParser.js'; +import { createLogger } from '../core/Logger.js'; + +const logger = createLogger('OpenAIProvider'); + +/** + * Enum to distinguish between model families with different request/response formats + */ +enum ModelFamily { + GPT = 'gpt', + O = 'o' +} + +/** + * Responses API message format for tool calls and results + */ +interface ResponsesAPIFunctionCall { + type: 'function_call'; + name: string; + arguments: string; + call_id: string; +} + +interface ResponsesAPIFunctionOutput { + type: 'function_call_output'; + call_id: string; + output: string; +} + +/** + * OpenAI provider implementation using the Responses API + */ +export class OpenAIProvider extends LLMBaseProvider { + private static readonly API_ENDPOINT = 'https://api.openai.com/v1/responses'; + + readonly name: LLMProvider = 'openai'; + + constructor(private readonly apiKey: string) { + super(); + } + + /** + * Determines the model family based on the model name + */ + private getModelFamily(modelName: string): ModelFamily { + // Check if model name starts with 'o' to identify O series models + if (modelName.startsWith('o')) { + return ModelFamily.O; + } + // Otherwise, assume it's a GPT model (gpt-3.5-turbo, gpt-4, etc.) 
+ return ModelFamily.GPT; + } + + /** + * Converts tools from standard format to responses API format + */ + private convertToolsFormat(tools: any[]): any[] { + return tools.map(tool => { + if (tool.type === 'function' && tool.function) { + // Convert from standard format to responses API format + return { + type: 'function', + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters || { type: 'object', properties: {} } + }; + } + return tool; // Return as-is if already in correct format + }); + } + + /** + * Convert MessageContent to Responses API format based on model family + * Throws error if conversion fails + */ + private convertContentToResponsesAPI(content: MessageContent | undefined, modelFamily: ModelFamily): any { + // For GPT models, return simple string content + if (modelFamily === ModelFamily.GPT) { + if (!content) { + return ''; + } + if (typeof content === 'string') { + return content; + } + // For multimodal content on GPT models, we need to return the structured format + if (Array.isArray(content)) { + // Return as OpenAI Chat API format for GPT models + return content.map((item, index) => { + if (item.type === 'text') { + return { type: 'text', text: item.text }; + } else if (item.type === 'image_url') { + if (!item.image_url?.url) { + throw new Error(`Invalid image content at index ${index}: missing image_url.url`); + } + return { type: 'image_url', image_url: item.image_url }; + } else { + throw new Error(`Unknown content type at index ${index}: ${(item as any).type}`); + } + }); + } + return String(content); + } + + // For O-series models, use structured responses API format + if (!content) { + return [{ type: 'input_text', text: '' }]; + } + + if (typeof content === 'string') { + return [{ type: 'input_text', text: content }]; + } + + if (Array.isArray(content)) { + return content.map((item, index) => { + if (item.type === 'text') { + return { type: 'input_text', text: item.text }; + } else if (item.type === 'image_url') { + if (!item.image_url?.url) { + throw new Error(`Invalid image content at index ${index}: missing image_url.url`); + } + // O-series uses different image format + return { type: 'input_image', image_url: item.image_url.url }; + } else { + throw new Error(`Unknown content type at index ${index}: ${(item as any).type}`); + } + }); + } + + throw new Error(`Invalid content type: expected string or array, got ${typeof content}`); + } + + /** + * Converts messages to responses API format based on model family + */ + private convertMessagesToResponsesAPI(messages: LLMMessage[], modelFamily: ModelFamily): any[] { + try { + return messages.map((msg, index) => { + if (msg.role === 'system' || msg.role === 'user') { + return { + role: msg.role, + content: this.convertContentToResponsesAPI(msg.content, modelFamily) + }; + } else if (msg.role === 'assistant') { + if (msg.tool_calls && msg.tool_calls.length > 0) { + // Convert tool calls to responses API format + const toolCall = msg.tool_calls[0]; // Take first tool call + let argsString: string; + + // Ensure arguments are in string format for responses API + if (typeof toolCall.function.arguments === 'string') { + argsString = toolCall.function.arguments; + } else { + argsString = JSON.stringify(toolCall.function.arguments); + } + + return { + type: 'function_call', + name: toolCall.function.name, + arguments: argsString, + call_id: toolCall.id + } as ResponsesAPIFunctionCall; + } else { + // Regular assistant message with content + // For O-series models, 
assistant content uses 'output_text' + if (modelFamily === ModelFamily.O) { + const content = typeof msg.content === 'string' ? msg.content : + Array.isArray(msg.content) ? msg.content.map(c => c.type === 'text' ? c.text : '').join('') : + String(msg.content || ''); + return { + role: 'assistant', + content: [{ type: 'output_text', text: content }] + }; + } else { + // For GPT models, use simple content format + return { + role: 'assistant', + content: this.convertContentToResponsesAPI(msg.content, modelFamily) + }; + } + } + } else if (msg.role === 'tool') { + // Convert tool result to responses API format + // Tool responses are always text in the current implementation + return { + type: 'function_call_output', + call_id: msg.tool_call_id, + output: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) + } as ResponsesAPIFunctionOutput; + } + + throw new Error(`Unknown message role at index ${index}: ${msg.role}`); + }); + } catch (error) { + logger.error('Failed to convert messages to Responses API format:', error); + throw new Error(`Message conversion failed: ${error instanceof Error ? error.message : String(error)}`); + } + } + + /** + * Processes the responses API output and extracts relevant information + */ + private processResponsesAPIOutput(data: any): LLMResponse { + const result: LLMResponse = { + rawResponse: data + }; + + // Extract reasoning info if available (O models) + if (data.reasoning) { + result.reasoning = { + summary: data.reasoning.summary, + effort: data.reasoning.effort + }; + } + + if (!data?.output) { + throw new Error('No output from OpenAI'); + } + + if (data.output && data.output.length > 0) { + // Find function call or message by type instead of assuming position + const functionCallOutput = data.output.find((item: any) => item.type === 'function_call'); + const messageOutput = data.output.find((item: any) => item.type === 'message'); + + if (functionCallOutput) { + // Process function call + try { + result.functionCall = { + name: functionCallOutput.name, + arguments: JSON.parse(functionCallOutput.arguments) + }; + } catch (error) { + logger.error('Error parsing function arguments:', error); + result.functionCall = { + name: functionCallOutput.name, + arguments: functionCallOutput.arguments // Keep as string if parsing fails + }; + } + } + else if (messageOutput?.content && messageOutput.content.length > 0 && messageOutput.content[0].type === 'output_text') { + // Process text response + result.text = messageOutput.content[0].text.trim(); + } + } + + return result; + } + + /** + * Makes a request to the OpenAI Responses API + */ + private async makeAPIRequest(payloadBody: any): Promise { + try { + const response = await fetch(OpenAIProvider.API_ENDPOINT, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify(payloadBody), + }); + + if (!response.ok) { + const errorData = await response.json(); + logger.error('OpenAI API error:', errorData); + throw new Error(`OpenAI API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); + } + + const data = await response.json(); + logger.info('OpenAI Response:', data); + + if (data.usage) { + logger.info('OpenAI Usage:', { inputTokens: data.usage.input_tokens, outputTokens: data.usage.output_tokens }); + } + + return data; + } catch (error) { + logger.error('OpenAI API request failed:', error); + throw error; + } + } + + /** + * Call the OpenAI API with messages + */ + async 
callWithMessages( + modelName: string, + messages: LLMMessage[], + options?: LLMCallOptions + ): Promise { + return LLMRetryManager.simpleRetry(async () => { + logger.debug('Calling OpenAI responses API...', { model: modelName, messageCount: messages.length }); + + // Determine model family + const modelFamily = this.getModelFamily(modelName); + logger.debug('Model Family:', modelFamily); + + // Construct payload body for responses API format + const payloadBody: any = { + model: modelName, + }; + + // Convert messages to responses API format + const convertedMessages = this.convertMessagesToResponsesAPI(messages, modelFamily); + payloadBody.input = convertedMessages; + + // Add temperature if provided, but not for O models (they don't support it) + if (options?.temperature !== undefined && modelFamily !== ModelFamily.O) { + payloadBody.temperature = options.temperature; + } + + // Add tools if provided - convert from standard format to responses API format + if (options?.tools) { + payloadBody.tools = this.convertToolsFormat(options.tools); + } + + // Add tool_choice if provided + if (options?.tool_choice) { + payloadBody.tool_choice = options.tool_choice; + } + + // Add reasoning level for O-series model if provided + if (options?.reasoningLevel && modelFamily === ModelFamily.O) { + payloadBody.reasoning = { + effort: options.reasoningLevel + }; + } + + logger.info('Request payload:', payloadBody); + + const data = await this.makeAPIRequest(payloadBody); + return this.processResponsesAPIOutput(data); + }, options?.retryConfig); + } + + /** + * Simple call method for backward compatibility + */ + async call( + modelName: string, + prompt: string, + systemPrompt: string, + options?: LLMCallOptions + ): Promise { + const messages: LLMMessage[] = []; + + if (systemPrompt) { + messages.push({ + role: 'system', + content: systemPrompt + }); + } + + messages.push({ + role: 'user', + content: prompt + }); + + return this.callWithMessages(modelName, messages, options); + } + + /** + * Get all OpenAI models supported by this provider + */ + async getModels(): Promise { + // Return hardcoded OpenAI models with their capabilities + return [ + { + id: 'gpt-4.1-2025-04-14', + name: 'GPT-4.1', + provider: 'openai', + capabilities: { + functionCalling: true, + reasoning: false, + vision: true, + structured: true + } + }, + { + id: 'gpt-4.1-mini-2025-04-14', + name: 'GPT-4.1 Mini', + provider: 'openai', + capabilities: { + functionCalling: true, + reasoning: false, + vision: true, + structured: true + } + }, + { + id: 'gpt-4.1-nano-2025-04-14', + name: 'GPT-4.1 Nano', + provider: 'openai', + capabilities: { + functionCalling: true, + reasoning: false, + vision: true, + structured: true + } + }, + { + id: 'o4-mini-2025-04-16', + name: 'O4 Mini', + provider: 'openai', + capabilities: { + functionCalling: true, + reasoning: true, + vision: true, + structured: true + } + } + ]; + } + + /** + * Parse response into standardized action structure + */ + parseResponse(response: LLMResponse): ReturnType { + return LLMResponseParser.parseResponse(response); + } +} \ No newline at end of file diff --git a/front_end/panels/ai_chat/agent_framework/AgentRunner.ts b/front_end/panels/ai_chat/agent_framework/AgentRunner.ts index 3999e256445..a6a17fdd5b4 100644 --- a/front_end/panels/ai_chat/agent_framework/AgentRunner.ts +++ b/front_end/panels/ai_chat/agent_framework/AgentRunner.ts @@ -3,7 +3,8 @@ // found in the LICENSE file. 
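Before the AgentRunner changes, a minimal usage sketch of the OpenAIProvider added above (illustrative only; in this patch callers are expected to go through LLMClient rather than instantiating providers directly, and apiKey is assumed to come from settings):

const provider = new OpenAIProvider(apiKey);
// O-family model: temperature is dropped and reasoning effort is sent, per callWithMessages above
const response = await provider.call(
  'o4-mini-2025-04-16',
  'Summarize the current page.',
  'You are a DevTools assistant.',
  { reasoningLevel: 'high' },
);
console.log(response.text ?? response.functionCall);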
import { enhancePromptWithPageContext } from '../core/PageInfoManager.js'; -import { UnifiedLLMClient, type UnifiedLLMResponse, type ParsedLLMAction } from '../core/UnifiedLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; +import type { LLMResponse, ParsedLLMAction, LLMMessage, LLMProvider } from '../LLM/LLMTypes.js'; import type { Tool } from '../tools/Tools.js'; import { ChatMessageEntity, type ChatMessage, type ModelChatMessage, type ToolResultMessage } from '../ui/ChatView.js'; import { createLogger } from '../core/Logger.js'; @@ -47,6 +48,69 @@ function isConfigurableAgentResult(obj: any): obj is ConfigurableAgentResult { * Runs the core agent execution loop */ export class AgentRunner { + /** + * Helper function to convert ChatMessage[] to LLMMessage[] + */ + private static convertToLLMMessages(messages: ChatMessage[]): LLMMessage[] { + const llmMessages: LLMMessage[] = []; + + for (const msg of messages) { + if (msg.entity === ChatMessageEntity.USER) { + // User message + if ('text' in msg) { + llmMessages.push({ + role: 'user', + content: msg.text, + }); + } + } else if (msg.entity === ChatMessageEntity.MODEL) { + // Model message + const modelMsg = msg as ModelChatMessage; + if (modelMsg.action === 'final' && modelMsg.answer) { + llmMessages.push({ + role: 'assistant', + content: modelMsg.answer, + }); + } else if (modelMsg.action === 'tool' && modelMsg.toolCallId) { + // Tool call message + llmMessages.push({ + role: 'assistant', + content: undefined, + tool_calls: [{ + id: modelMsg.toolCallId, + type: 'function', + function: { + name: modelMsg.toolName || '', + arguments: JSON.stringify(modelMsg.toolArgs || {}), + } + }], + }); + } + } else if (msg.entity === ChatMessageEntity.TOOL_RESULT) { + // Tool result message + const toolResult = msg as ToolResultMessage; + if (toolResult.toolCallId && toolResult.resultText) { + llmMessages.push({ + role: 'tool', + content: toolResult.resultText, + tool_call_id: toolResult.toolCallId, + }); + } + } + } + + return llmMessages; + } + + /** + * Helper function to detect provider from user's settings + */ + private static detectProvider(modelName: string): LLMProvider { + // Respect user's provider selection from settings + const selectedProvider = localStorage.getItem('ai_chat_provider') || 'openai'; + return selectedProvider as LLMProvider; + } + // Helper function to execute the handoff logic (to avoid duplication) private static async executeHandoff( currentMessages: ChatMessage[], @@ -246,19 +310,22 @@ export class AgentRunner { // This includes updating the accessibility tree inside enhancePromptWithPageContext const currentSystemPrompt = await enhancePromptWithPageContext(systemPrompt + iterationInfo); - let llmResponse: UnifiedLLMResponse; + let llmResponse: LLMResponse; try { logger.info('${agentName} Calling LLM with ${messages.length} messages'); - llmResponse = await UnifiedLLMClient.callLLMWithMessages( - apiKey, - modelName, - messages, - { - tools: toolSchemas, - systemPrompt: currentSystemPrompt, - temperature: temperature ?? 0, - } - ); + + const llm = LLMClient.getInstance(); + const provider = AgentRunner.detectProvider(modelName); + const llmMessages = AgentRunner.convertToLLMMessages(messages); + + llmResponse = await llm.call({ + provider, + model: modelName, + messages: llmMessages, + systemPrompt: currentSystemPrompt, + tools: toolSchemas, + temperature: temperature ?? 
0, + }); } catch (error: any) { logger.error(`${agentName} LLM call failed:`, error); const errorMsg = `LLM call failed: ${error.message || String(error)}`; @@ -276,7 +343,8 @@ export class AgentRunner { } // Parse LLM response - const parsedAction = UnifiedLLMClient.parseResponse(llmResponse); + const llm = LLMClient.getInstance(); + const parsedAction = llm.parseResponse(llmResponse); // Process parsed action try { diff --git a/front_end/panels/ai_chat/core/AgentNodes.ts b/front_end/panels/ai_chat/core/AgentNodes.ts index e24a4a6813b..b1042e1e79b 100644 --- a/front_end/panels/ai_chat/core/AgentNodes.ts +++ b/front_end/panels/ai_chat/core/AgentNodes.ts @@ -3,9 +3,10 @@ // found in the LICENSE file. import type { getTools } from '../tools/Tools.js'; -import { ChatMessageEntity, type ModelChatMessage, type ToolResultMessage } from '../ui/ChatView.js'; +import { ChatMessageEntity, type ModelChatMessage, type ToolResultMessage, type ChatMessage } from '../ui/ChatView.js'; -import type { Model } from './ChatOpenAI.js'; // Import Model interface +import { LLMClient } from '../LLM/LLMClient.js'; +import type { LLMMessage, LLMProvider } from '../LLM/LLMTypes.js'; import { createSystemPromptAsync, getAgentToolsFromState } from './GraphHelpers.js'; import { createLogger } from './Logger.js'; import type { AgentState } from './State.js'; @@ -13,12 +14,16 @@ import type { Runnable } from './Types.js'; const logger = createLogger('AgentNodes'); -export function createAgentNode(model: Model): Runnable { +export function createAgentNode(modelName: string, temperature: number): Runnable { const agentNode = new class AgentNode implements Runnable { - private model: Model; - - constructor(model: Model) { - this.model = model; + private modelName: string; + private temperature: number; + private callCount = 0; + private readonly MAX_CALLS_PER_INTERACTION = 50; + + constructor(modelName: string, temperature: number) { + this.modelName = modelName; + this.temperature = temperature; } async invoke(state: AgentState): Promise { @@ -28,7 +33,7 @@ export function createAgentNode(model: Model): Runnable // Reset call count on new user message const lastMessage = state.messages[state.messages.length - 1]; if (lastMessage?.entity === ChatMessageEntity.USER) { - this.model.resetCallCount(); + this.resetCallCount(); } if (lastMessage?.entity === ChatMessageEntity.TOOL_RESULT && lastMessage?.toolName === 'finalize_with_critique') { @@ -80,50 +85,172 @@ export function createAgentNode(model: Model): Runnable // 1. Create the enhanced system prompt based on the current state (including selected type) const systemPrompt = await createSystemPromptAsync(state); - // 2. Call the model with the message array directly instead of using ChatPromptFormatter - const response = await this.model.generateWithMessages(state.messages, systemPrompt, state); - logger.debug('AgentNode Response:', response); - const parsedAction = response.parsedAction!; - - // Directly create the ModelChatMessage object - let newModelMessage: ModelChatMessage; - if (parsedAction.action === 'tool') { - const toolCallId = crypto.randomUUID(); // Generate unique ID for OpenAI format - newModelMessage = { - entity: ChatMessageEntity.MODEL, - action: 'tool', - toolName: parsedAction.toolName, - toolArgs: parsedAction.toolArgs, - toolCallId, // Add for linking with tool response - isFinalAnswer: false, - reasoning: response.openAIReasoning?.summary, - }; + // 2. 
Call the LLM with the message array + this.callCount++; + + if (this.callCount > this.MAX_CALLS_PER_INTERACTION) { + logger.warn('Max calls per interaction reached:', this.callCount); + throw new Error(`Maximum calls (${this.MAX_CALLS_PER_INTERACTION}) per interaction exceeded. This might be an infinite loop.`); + } - logger.debug('AgentNode: Created tool message', { toolName: parsedAction.toolName, toolCallId }); - if (parsedAction.toolName === 'finalize_with_critique') { - logger.debug('AgentNode: finalize_with_critique call with args:', JSON.stringify(parsedAction.toolArgs)); + logger.debug('Generating response with LLMClient:', { + modelName: this.modelName, + callCount: this.callCount, + messageCount: state.messages.length, + }); + + try { + const llm = LLMClient.getInstance(); + + // Detect provider from model name + const provider = this.detectProvider(this.modelName); + + // Get tools for the current agent type + const tools = getAgentToolsFromState(state); + + // Convert ChatMessage[] to LLMMessage[] + const llmMessages = this.convertChatMessagesToLLMMessages(state.messages); + + // Call LLM with the new API + const response = await llm.call({ + provider, + model: this.modelName, + messages: llmMessages, + systemPrompt, + tools: tools.map(tool => ({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: tool.schema, + } + })), + temperature: this.temperature, + }); + + // Parse the response + const parsedAction = llm.parseResponse(response); + + // Directly create the ModelChatMessage object + let newModelMessage: ModelChatMessage; + if (parsedAction.type === 'tool_call') { + const toolCallId = crypto.randomUUID(); // Generate unique ID for OpenAI format + newModelMessage = { + entity: ChatMessageEntity.MODEL, + action: 'tool', + toolName: parsedAction.name, + toolArgs: parsedAction.args, + toolCallId, // Add for linking with tool response + isFinalAnswer: false, + reasoning: response.reasoning?.summary, + }; + + logger.debug('AgentNode: Created tool message', { toolName: parsedAction.name, toolCallId }); + if (parsedAction.name === 'finalize_with_critique') { + logger.debug('AgentNode: finalize_with_critique call with args:', JSON.stringify(parsedAction.args)); + } + } else if (parsedAction.type === 'final_answer') { + newModelMessage = { + entity: ChatMessageEntity.MODEL, + action: 'final', + answer: parsedAction.answer, + isFinalAnswer: true, + reasoning: response.reasoning?.summary, + }; + + logger.debug('AgentNode: Created final answer message'); + } else { + // Error case + newModelMessage = { + entity: ChatMessageEntity.MODEL, + action: 'final', + answer: parsedAction.error || 'An error occurred', + isFinalAnswer: true, + reasoning: response.reasoning?.summary, + }; + + logger.debug('AgentNode: Created error message'); } - } else { - newModelMessage = { - entity: ChatMessageEntity.MODEL, - action: 'final', - answer: parsedAction.answer, - isFinalAnswer: true, - reasoning: response.openAIReasoning?.summary, - }; - logger.debug('AgentNode: Created final answer message'); + logger.debug('New Model Message:', newModelMessage); + + return { + ...state, + messages: [...state.messages, newModelMessage], + error: undefined, + }; + } catch (error) { + logger.error('Error generating response:', error); + throw error; } + } - logger.debug('New Model Message:', newModelMessage); + resetCallCount(): void { + logger.debug(`Resetting call count from ${this.callCount} to 0`); + this.callCount = 0; + } - return { - ...state, - messages: 
[...state.messages, newModelMessage], - error: undefined, - }; + /** + * Detect provider from user's settings, not just model name + */ + private detectProvider(modelName: string): LLMProvider { + // Respect user's provider selection from settings + const selectedProvider = localStorage.getItem('ai_chat_provider') || 'openai'; + return selectedProvider as LLMProvider; + } + + /** + * Convert ChatMessage[] to LLMMessage[] + */ + private convertChatMessagesToLLMMessages(messages: ChatMessage[]): LLMMessage[] { + const llmMessages: LLMMessage[] = []; + + for (const msg of messages) { + if (msg.entity === ChatMessageEntity.USER) { + // User message + if ('text' in msg) { + llmMessages.push({ + role: 'user', + content: msg.text, + }); + } + } else if (msg.entity === ChatMessageEntity.MODEL) { + // Model message + if ('answer' in msg && msg.answer) { + llmMessages.push({ + role: 'assistant', + content: msg.answer, + }); + } else if ('action' in msg && msg.action === 'tool' && 'toolName' in msg && 'toolArgs' in msg && 'toolCallId' in msg) { + // Tool call message - convert from ModelChatMessage structure + llmMessages.push({ + role: 'assistant', + content: undefined, + tool_calls: [{ + id: msg.toolCallId!, + type: 'function' as const, + function: { + name: msg.toolName!, + arguments: JSON.stringify(msg.toolArgs), + } + }], + }); + } + } else if (msg.entity === ChatMessageEntity.TOOL_RESULT) { + // Tool result message + if ('toolCallId' in msg && 'resultText' in msg) { + llmMessages.push({ + role: 'tool', + content: String(msg.resultText), + tool_call_id: msg.toolCallId, + }); + } + } + } + + return llmMessages; } - }(model); + }(modelName, temperature); return agentNode; } diff --git a/front_end/panels/ai_chat/core/AgentService.ts b/front_end/panels/ai_chat/core/AgentService.ts index 48b450b1764..6d5dda5fd87 100644 --- a/front_end/panels/ai_chat/core/AgentService.ts +++ b/front_end/panels/ai_chat/core/AgentService.ts @@ -17,6 +17,7 @@ import {createAgentGraph} from './Graph.js'; import { createLogger } from './Logger.js'; import {type AgentState, createInitialState, createUserMessage} from './State.js'; import type {CompiledGraph} from './Types.js'; +import { LLMClient } from '../LLM/LLMClient.js'; const logger = createLogger('AgentService'); @@ -71,6 +72,48 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ return this.#apiKey; } + /** + * Initializes the LLM client with provider configurations + */ + async #initializeLLMClient(): Promise { + const llm = LLMClient.getInstance(); + + // Get configuration from localStorage + const provider = localStorage.getItem('ai_chat_provider') || 'openai'; + const openaiKey = localStorage.getItem('ai_chat_api_key') || ''; + const litellmKey = localStorage.getItem('ai_chat_litellm_api_key') || ''; + const litellmEndpoint = localStorage.getItem('ai_chat_litellm_endpoint') || ''; + + const providers = []; + + // Add OpenAI if it's the selected provider and has an API key + if (provider === 'openai' && openaiKey) { + providers.push({ + provider: 'openai' as const, + apiKey: openaiKey + }); + } + + // Add LiteLLM if it's the selected provider and has configuration + if (provider === 'litellm' && litellmEndpoint) { + providers.push({ + provider: 'litellm' as const, + apiKey: litellmKey, // Can be empty for some LiteLLM endpoints + providerURL: litellmEndpoint + }); + } + + if (providers.length === 0) { + const errorMessage = provider === 'openai' + ? 
'OpenAI API key is required for this configuration' + : 'LiteLLM endpoint is required for this configuration'; + throw new Error(errorMessage); + } + + await llm.initialize({ providers }); + logger.info('LLM client initialized successfully'); + } + /** * Initializes the agent with the given API key */ @@ -82,6 +125,9 @@ export class AgentService extends Common.ObjectWrapper.ObjectWrapper<{ throw new Error('Model name is required for initialization'); } + // Initialize LLM client first + await this.#initializeLLMClient(); + // Check if the configuration requires an API key const requiresApiKey = this.#doesCurrentConfigRequireApiKey(); diff --git a/front_end/panels/ai_chat/core/ChatLiteLLM.ts b/front_end/panels/ai_chat/core/ChatLiteLLM.ts deleted file mode 100644 index f5199678a5f..00000000000 --- a/front_end/panels/ai_chat/core/ChatLiteLLM.ts +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2025 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -import type { getTools } from '../tools/Tools.js'; - -import * as BaseOrchestratorAgent from './BaseOrchestratorAgent.js'; -import { createLogger } from './Logger.js'; -import { enhancePromptWithPageContext } from './PageInfoManager.js'; -import type { AgentState } from './State.js'; -import { UnifiedLLMClient } from './UnifiedLLMClient.js'; -import { ChatMessageEntity, type ChatMessage } from '../ui/ChatView.js'; - -const logger = createLogger('ChatLiteLLM'); - -// Define interfaces for our custom implementation -interface ModelResponse { - parsedAction: { - action: 'tool' | 'final', // Discriminator - toolName?: string, // Defined if action is 'tool' - toolArgs?: Record, // Defined if action is 'tool' - answer?: string, // Defined if action is 'final'. This is the user-facing message or error. - }; -} - -interface Model { - generate(prompt: string, systemPrompt: string, state: AgentState): Promise; - generateWithMessages(messages: ChatMessage[], systemPrompt: string, state: AgentState): Promise; - resetCallCount(): void; -} - -// Create the appropriate tools for the agent based on agent type -function getAgentToolsFromState(state: AgentState): ReturnType { - // Use the helper from BaseOrchestratorAgent to get the pre-filtered list - return BaseOrchestratorAgent.getAgentTools(state.selectedAgentType ?? ''); // Pass agentType or empty string -} - -// Ensure ChatLiteLLM tracks interaction state -export class ChatLiteLLM implements Model { - private apiKey: string | null; - private modelName: string; - private temperature: number; - // Add a counter to track how many times generate has been called per interaction - private callCount = 0; - // Maximum number of calls per interaction - private maxCallsPerInteraction = 25; - - constructor(options: { - liteLLMApiKey: string | null, - modelName: string, - temperature?: number, - }) { - this.apiKey = options.liteLLMApiKey; - this.modelName = options.modelName; - this.temperature = options.temperature ?? 
1.0; - } - - // Method to reset the call counter when a new user message is received - resetCallCount(): void { - this.callCount = 0; - } - - // Method to check if we've exceeded the maximum number of calls - hasExceededMaxCalls(): boolean { - return this.callCount >= this.maxCallsPerInteraction; - } - - async generate(prompt: string, systemPrompt: string, state: AgentState): Promise { - // Convert single prompt to message format for backward compatibility - const messages: ChatMessage[] = [{ - entity: ChatMessageEntity.USER, - text: prompt - }]; - - return this.generateWithMessages(messages, systemPrompt, state); - } - - async generateWithMessages(messages: ChatMessage[], systemPrompt: string, state: AgentState): Promise { - // Increment the call counter - this.callCount++; - - // Check if we've exceeded the maximum number of calls - if (this.hasExceededMaxCalls()) { - // Return a forced final response when limit is exceeded - return { - parsedAction: { - action: 'final', - answer: 'I reached the maximum number of tool calls. I need to provide a direct answer based on what I know so far. Let me know if you need more clarification.', - }, - }; - } - - logger.debug('Generating response from LiteLLM', { messageCount: messages.length }); - try { - // Get agent-specific tools to include in the request - const tools = getAgentToolsFromState(state).map(tool => ({ - type: 'function', - function: { - name: tool.name, - description: tool.description, - parameters: tool.schema - } - })); - - // Get the enhanced system prompt - use the async version - const enhancedSystemPrompt = await enhancePromptWithPageContext(systemPrompt); - - // Use UnifiedLLMClient for consistent message handling - const unifiedResponse = await UnifiedLLMClient.callLLMWithMessages( - this.apiKey || '', - this.modelName, - messages, - { - tools, - systemPrompt: enhancedSystemPrompt, - temperature: this.temperature, - } - ); - - // Process the response using UnifiedLLMClient's parser - const parsedAction = UnifiedLLMClient.parseResponse(unifiedResponse); - - // Convert to the expected ModelResponse format - const modelResponse: ModelResponse = { - parsedAction: { - action: parsedAction.type === 'tool_call' ? 'tool' : 'final', - toolName: parsedAction.type === 'tool_call' ? parsedAction.name : undefined, - toolArgs: parsedAction.type === 'tool_call' ? parsedAction.args : undefined, - answer: parsedAction.type === 'final_answer' ? parsedAction.answer : (parsedAction.type === 'error' ? parsedAction.error : undefined), - }, - }; - - return modelResponse; - } catch (error) { - logger.error('Error during UnifiedLLMClient call:', error); - // Return error as final answer - return { - parsedAction: { - action: 'final', - answer: `Error calling LiteLLM: ${error instanceof Error ? error.message : 'Unknown error'}`, - }, - }; - } - } -} diff --git a/front_end/panels/ai_chat/core/ChatOpenAI.ts b/front_end/panels/ai_chat/core/ChatOpenAI.ts deleted file mode 100644 index f8699180eb0..00000000000 --- a/front_end/panels/ai_chat/core/ChatOpenAI.ts +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2025 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
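The provider configuration assembled by AgentService's #initializeLLMClient above reduces to roughly the following sketch (only the branch matching the 'ai_chat_provider' setting is actually pushed):

const llm = LLMClient.getInstance();
await llm.initialize({
  providers: [
    // when 'ai_chat_provider' === 'openai'
    { provider: 'openai', apiKey: localStorage.getItem('ai_chat_api_key') ?? '' },
    // when 'ai_chat_provider' === 'litellm' (apiKey may be empty for some endpoints)
    { provider: 'litellm', apiKey: localStorage.getItem('ai_chat_litellm_api_key') ?? '', providerURL: localStorage.getItem('ai_chat_litellm_endpoint') ?? '' },
  ],
});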
- -import type { getTools } from '../tools/Tools.js'; - -import * as BaseOrchestratorAgent from './BaseOrchestratorAgent.js'; -import { createLogger } from './Logger.js'; -import { OpenAIClient, type OpenAIResponse, type ParsedLLMAction, type OpenAICallOptions } from './OpenAIClient.js'; -import { enhancePromptWithPageContext } from './PageInfoManager.js'; -import type { AgentState } from './State.js'; -import { UnifiedLLMClient } from './UnifiedLLMClient.js'; -import { ChatMessageEntity, type ChatMessage } from '../ui/ChatView.js'; - -const logger = createLogger('ChatOpenAI'); - -// Define interfaces for our custom implementation -interface ModelResponse { - parsedAction: { - action: 'tool' | 'final', // Discriminator - toolName?: string, // Defined if action is 'tool' - toolArgs?: Record, // Defined if action is 'tool' - answer?: string, // Defined if action is 'final'. This is the user-facing message or error. - }; - openAIReasoning?: { - summary?: string[] | null, - effort?: string, - }; -} - -export interface Model { - generate(prompt: string, systemPrompt: string, state: AgentState): Promise; - generateWithMessages(messages: ChatMessage[], systemPrompt: string, state: AgentState): Promise; - resetCallCount(): void; -} - -// Create the appropriate tools for the agent based on agent type -function getAgentToolsFromState(state: AgentState): ReturnType { - // Use the helper from BaseOrchestratorAgent to get the pre-filtered list - return BaseOrchestratorAgent.getAgentTools(state.selectedAgentType ?? ''); // Pass agentType or empty string -} - -// Ensure ChatOpenAI tracks interaction state -export class ChatOpenAI implements Model { - private apiKey: string; - private modelName: string; - private temperature: number; - // Add a counter to track how many times generate has been called per interaction - private callCount = 0; - // Maximum number of calls per interaction - private maxCallsPerInteraction = 25; - - constructor(options: { openAIApiKey: string, modelName: string, temperature?: number }) { - this.apiKey = options.openAIApiKey; - this.modelName = options.modelName; - this.temperature = options.temperature ?? 1.0; - } - - // Method to reset the call counter when a new user message is received - resetCallCount(): void { - this.callCount = 0; - } - - // Method to check if we\'ve exceeded the maximum number of calls - hasExceededMaxCalls(): boolean { - return this.callCount >= this.maxCallsPerInteraction; - } - - async generate(prompt: string, systemPrompt: string, state: AgentState): Promise { - // Convert single prompt to message format for backward compatibility - const messages: ChatMessage[] = [{ - entity: ChatMessageEntity.USER, - text: prompt - }]; - - return this.generateWithMessages(messages, systemPrompt, state); - } - - async generateWithMessages(messages: ChatMessage[], systemPrompt: string, state: AgentState): Promise { - // Increment the call counter - this.callCount++; - - // Check if we've exceeded the maximum number of calls - if (this.hasExceededMaxCalls()) { - // Return a forced final response when limit is exceeded - return { - parsedAction: { - action: 'final', - answer: 'I reached the maximum number of tool calls. I need to provide a direct answer based on what I know so far. 
Let me know if you need more clarification.', - }, - }; - } - - logger.debug('Generating response from OpenAI', { messageCount: messages.length }); - try { - // Get agent-specific tools to include in the request - const tools = getAgentToolsFromState(state).map(tool => ({ - type: 'function', - function: { - name: tool.name, - description: tool.description, - parameters: tool.schema - } - })); - - // Get the enhanced system prompt - use the async version - const enhancedSystemPrompt = await enhancePromptWithPageContext(systemPrompt); - - // Use UnifiedLLMClient for consistent message handling - const unifiedResponse = await UnifiedLLMClient.callLLMWithMessages( - this.apiKey, - this.modelName, - messages, - { - tools, - systemPrompt: enhancedSystemPrompt, - temperature: this.temperature, - } - ); - - // Process the response using UnifiedLLMClient's parser - const parsedLlmAction: ParsedLLMAction = UnifiedLLMClient.parseResponse(unifiedResponse); - - let parsedActionData: ModelResponse['parsedAction']; - let openAIReasoning: ModelResponse['openAIReasoning'] = undefined; - - switch (parsedLlmAction.type) { - case 'tool_call': - parsedActionData = { - action: 'tool', - toolName: parsedLlmAction.name, - toolArgs: parsedLlmAction.args, - }; - break; - case 'final_answer': - parsedActionData = { - action: 'final', - answer: parsedLlmAction.answer, - }; - break; - case 'error': - const errorMessage = `LLM response processing error: ${parsedLlmAction.error}`; - parsedActionData = { - action: 'final', - answer: errorMessage, - }; - break; - } - - // Extract reasoning information if available - if (unifiedResponse.reasoning) { - openAIReasoning = { - summary: unifiedResponse.reasoning.summary, - effort: unifiedResponse.reasoning.effort, - }; - } - - return { parsedAction: parsedActionData, openAIReasoning }; - } catch (error) { - // Error logging is handled within UnifiedLLMClient, but re-throw if needed - logger.error('Error in ChatOpenAI.generateWithMessages after calling UnifiedLLMClient:', error); - return { - parsedAction: { - action: 'final', - answer: `error in calling API client: ${error}`, - }, - }; - } - } -} - -// Export the interfaces and class -export type { ModelResponse }; diff --git a/front_end/panels/ai_chat/core/ConfigurableGraph.ts b/front_end/panels/ai_chat/core/ConfigurableGraph.ts index be19c245910..28ef0b0ad5d 100644 --- a/front_end/panels/ai_chat/core/ConfigurableGraph.ts +++ b/front_end/panels/ai_chat/core/ConfigurableGraph.ts @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import type { Model } from './ChatOpenAI.js'; import { createAgentNode, createFinalNode, createToolExecutorNode, routeNextNode } from './Graph.js'; import { createLogger } from './Logger.js'; import type { AgentState } from './State.js'; @@ -34,27 +33,27 @@ export interface GraphConfig { entryPoint: string; nodes: GraphNodeConfig[]; edges: GraphEdgeConfig[]; + modelName?: string; + temperature?: number; } /** * Creates a compiled agent graph from a configuration object. * - * @param config The graph configuration. - * @param model The Model instance (ChatOpenAI or ChatLiteLLM), already initialized. + * @param config The graph configuration with model information. * @returns A compiled StateGraph. 
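 * @example
 * // as Graph.ts does below: createAgentGraphFromConfig({ ...defaultAgentGraphConfig, modelName: 'gpt-4.1-mini-2025-04-14', temperature: 0 })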
*/ export function createAgentGraphFromConfig( config: GraphConfig, - model: Model, ): CompiledGraph { - logger.info(`Creating graph from config: ${config.name}`); + logger.info(`Creating graph from config: ${config.name} with model: ${config.modelName}`); const graph = new StateGraph({ name: config.name }); - const nodeFactories: Record) => Runnable> = { - agent: model => createAgentNode(model), + const nodeFactories: Record) => Runnable> = { + agent: () => createAgentNode(config.modelName!, config.temperature || 0), final: () => createFinalNode(), - toolExecutor: (_model, nodeCfg) => { + toolExecutor: (nodeCfg) => { return { invoke: async (state: AgentState) => { logger.warn(`ToolExecutorNode "${nodeCfg.name}" invoked without being dynamically replaced. This indicates an issue.`); @@ -67,7 +66,7 @@ export function createAgentGraphFromConfig( for (const nodeConfig of config.nodes) { const factory = nodeFactories[nodeConfig.type]; if (factory) { - const nodeInstance = factory(model, nodeConfig, graph); + const nodeInstance = factory(nodeConfig, graph); graph.addNode(nodeConfig.name, nodeInstance); logger.debug(`Added node: ${nodeConfig.name} (type: ${nodeConfig.type})`); } else { @@ -81,7 +80,7 @@ export function createAgentGraphFromConfig( } } - type ConditionFunctionGenerator = (state: AgentState, graphInstance: StateGraph, edgeConfig: GraphEdgeConfig, model: Model) => string; + type ConditionFunctionGenerator = (state: AgentState, graphInstance: StateGraph, edgeConfig: GraphEdgeConfig) => string; const conditionFactories: Record = { routeBasedOnLastMessage: state => routeNextNode(state), @@ -107,7 +106,7 @@ export function createAgentGraphFromConfig( const conditionFactory = conditionFactories[edgeConfig.conditionType]; if (conditionFactory) { const conditionFn = (state: AgentState) => { - return conditionFactory(state, graph, edgeConfig, model); + return conditionFactory(state, graph, edgeConfig); }; graph.addConditionalEdges(edgeConfig.source, conditionFn, edgeConfig.targetMap); logger.debug(`Added edge from ${edgeConfig.source} via ${edgeConfig.conditionType}`); diff --git a/front_end/panels/ai_chat/core/Graph.ts b/front_end/panels/ai_chat/core/Graph.ts index df0db7aea23..ce01871c188 100644 --- a/front_end/panels/ai_chat/core/Graph.ts +++ b/front_end/panels/ai_chat/core/Graph.ts @@ -7,75 +7,34 @@ import { createToolExecutorNode, createFinalNode, } from './AgentNodes.js'; -import { ChatLiteLLM } from './ChatLiteLLM.js'; -import { ChatOpenAI } from './ChatOpenAI.js'; import { createAgentGraphFromConfig } from './ConfigurableGraph.js'; import { defaultAgentGraphConfig } from './GraphConfigs.js'; import { createLogger } from './Logger.js'; -import { AIChatPanel } from '../ui/AIChatPanel.js'; import { createSystemPrompt, getAgentToolsFromState, routeNextNode, } from './GraphHelpers.js'; -import type { AgentState } from './State.js'; import { type CompiledGraph, NodeType } from './Types.js'; const logger = createLogger('Graph'); -// createAgentGraph now uses the imported typed configuration object -export function createAgentGraph(apiKey: string | null, modelName: string): CompiledGraph { +// createAgentGraph now uses the LLM SDK directly +export function createAgentGraph(_apiKey: string | null, modelName: string): CompiledGraph { if (!modelName) { throw new Error('Model name is required'); } - let model; - // Get model options using the centralized method - const modelOptions = AIChatPanel.getModelOptions(); - - const modelOption = modelOptions.find((opt: {value: string, type: string}) => 
opt.value === modelName); - const isLiteLLMModel = modelOption?.type === 'litellm' || modelName.startsWith('litellm/'); + logger.debug('Creating graph for model:', modelName); - if (isLiteLLMModel) { - // Get LiteLLM configuration from localStorage - const liteLLMEndpoint = localStorage.getItem('ai_chat_litellm_endpoint'); - - // Check if endpoint is configured - if (!liteLLMEndpoint) { - throw new Error('LiteLLM endpoint is required for LiteLLM models'); - } + // Create graph configuration with model name - nodes will use LLMClient directly + const graphConfigWithModel = { + ...defaultAgentGraphConfig, + modelName: modelName, + temperature: 0, + }; - // Handle both cases: models with and without 'litellm/' prefix - const actualModelName = modelName.startsWith('litellm/') ? - modelName.substring('litellm/'.length) : - modelName; - - logger.debug('Creating ChatLiteLLM model:', { - modelName: actualModelName, - endpoint: liteLLMEndpoint, - hasApiKey: Boolean(apiKey) - }); - - model = new ChatLiteLLM({ - liteLLMApiKey: apiKey, - modelName: actualModelName, - temperature: 0, - }); - } else { - // Standard OpenAI model - requires API key - if (!apiKey) { - throw new Error('OpenAI API key is required for OpenAI models'); - } - model = new ChatOpenAI({ - openAIApiKey: apiKey, - modelName, - temperature: 0, - }); - } - - // Use the imported configuration object directly - logger.debug('Using defaultAgentGraphConfig to create graph.'); - return createAgentGraphFromConfig(defaultAgentGraphConfig, model); + return createAgentGraphFromConfig(graphConfigWithModel); } export { createAgentNode, createToolExecutorNode, createFinalNode, routeNextNode, createSystemPrompt, getAgentToolsFromState, NodeType }; diff --git a/front_end/panels/ai_chat/core/GraphHelpers.ts b/front_end/panels/ai_chat/core/GraphHelpers.ts index fb95181045c..e4ae5adb671 100644 --- a/front_end/panels/ai_chat/core/GraphHelpers.ts +++ b/front_end/panels/ai_chat/core/GraphHelpers.ts @@ -13,63 +13,6 @@ import { NodeType } from './Types.js'; const logger = createLogger('GraphHelpers'); -// DEPRECATED: ChatPromptFormatter -// This class was used to concatenate messages into a single string for older LLM APIs. -// Now we use the message-based approach with UnifiedLLMClient for proper conversation handling. -// Kept for backward compatibility but should not be used in new code. -export class ChatPromptFormatter { - format(values: { messages: ChatMessage[] }): string { - logger.warn('ChatPromptFormatter is deprecated. 
Use message-based approach with UnifiedLLMClient instead.'); - const messageHistory = values.messages || []; - const formattedParts: string[] = []; - - // Find the last GetAccessibilityTreeTool result, if any - let lastAccessibilityTreeIndex = -1; - for (let i = messageHistory.length - 1; i >= 0; i--) { - const message = messageHistory[i]; - if (message.entity === ChatMessageEntity.TOOL_RESULT && - message.toolName === 'get_accessibility_tree') { - lastAccessibilityTreeIndex = i; - break; - } - } - - for (let i = 0; i < messageHistory.length; i++) { - const message = messageHistory[i]; - switch (message.entity) { - case ChatMessageEntity.USER: - if (message.text) { - formattedParts.push(`user: ${message.text}`); - } - break; - case ChatMessageEntity.MODEL: - // Format model message based on its action - if (message.action === 'tool') { - // Represent tool call in natural language instead of JSON to avoid LLM confusion - formattedParts.push(`assistant: Used tool "${message.toolName || 'unknown'}" to ${message.reasoning || 'perform an action'}.`); - } else if (message.answer) { - // Represent final answer - now just using plain markdown text - formattedParts.push(`assistant: ${message.answer}`); - } - break; - case ChatMessageEntity.TOOL_RESULT: - // For GetAccessibilityTreeTool, only include the most recent result - if (message.toolName === 'get_accessibility_tree' && i !== lastAccessibilityTreeIndex) { - // Skip older accessibility tree results - const resultPrefix = message.isError ? `tool_error[${message.toolName}]` : `tool_result[${message.toolName}]`; - formattedParts.push(`${resultPrefix}: Hidden to save on tokens`); - continue; - } - - // Represent tool results clearly - const resultPrefix = message.isError ? `tool_error[${message.toolName}]` : `tool_result[${message.toolName}]`; - formattedParts.push(`${resultPrefix}: ${message.resultText}`); - break; - } - } - return formattedParts.join('\n\n'); - } -} // Replace createSystemPrompt with this version export function createSystemPrompt(state: AgentState): string { diff --git a/front_end/panels/ai_chat/core/LiteLLMClient.ts b/front_end/panels/ai_chat/core/LiteLLMClient.ts deleted file mode 100644 index 4eb5e8b9315..00000000000 --- a/front_end/panels/ai_chat/core/LiteLLMClient.ts +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2025 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
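The old client modules being deleted here are replaced by the LLM SDK used above; the call-site migration looks roughly like this (provider shown as 'openai' for illustration — the patch reads it from the 'ai_chat_provider' setting):

// before (UnifiedLLMClient, deleted in this patch):
//   const res = await UnifiedLLMClient.callLLMWithMessages(apiKey, modelName, messages, { tools, systemPrompt, temperature });
//   const action = UnifiedLLMClient.parseResponse(res);
// after (LLM SDK):
const llm = LLMClient.getInstance();
const res = await llm.call({ provider: 'openai', model: modelName, messages, systemPrompt, tools, temperature: 0 });
const action = llm.parseResponse(res);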
- -import { createLogger } from './Logger.js'; - -const logger = createLogger('LiteLLMClient'); - -/** - * OpenAI-compatible message format (LiteLLM uses OpenAI standard) - */ -export interface OpenAIMessage { - role: 'system' | 'user' | 'assistant' | 'tool'; - content?: string | null; - tool_calls?: Array<{ - id: string; - type: 'function'; - function: { - name: string; - arguments: string; - }; - }>; - tool_call_id?: string; - name?: string; -} - -/** - * Types for LiteLLM API request and response - */ -export interface LiteLLMCallOptions { - tools?: any[]; - tool_choice?: any; - systemPrompt?: string; // Kept for backward compatibility with old methods - temperature?: number; - endpoint?: string; // Full endpoint URL (e.g., http://localhost:4000/v1/chat/completions) - baseUrl?: string; // Base URL only (e.g., http://localhost:4000 or https://your-cloud-litellm.com) -} - -export interface LiteLLMResponse { - text?: string; - functionCall?: { - name: string, - arguments: any, - }; - rawResponse: any; -} - -/** - * Types for LiteLLM models endpoint - */ -export interface LiteLLMModel { - id: string; - object: string; - created?: number; - owned_by?: string; -} - -export interface LiteLLMModelsResponse { - object: string; - data: LiteLLMModel[]; -} - -/** - * Standardized structure for parsed LLM action - */ -export type ParsedLLMAction = - | { type: 'tool_call', name: string, args: Record } - | { type: 'final_answer', answer: string } - | { type: 'error', error: string }; - -/** - * LiteLLMClient class for making requests to LiteLLM API - */ -export class LiteLLMClient { - /** - * Default base URL for local LiteLLM proxy - */ - private static DEFAULT_BASE_URL = 'http://localhost:4000'; - - /** - * Default endpoint path for chat completions - */ - private static CHAT_COMPLETIONS_PATH = '/v1/chat/completions'; - - /** - * Endpoint path for models list - */ - private static MODELS_PATH = '/v1/models'; - - /** - * Constructs the full endpoint URL based on provided options - */ - private static getEndpoint(options?: LiteLLMCallOptions): string { - // Check if we have a valid endpoint or baseUrl - if (!options?.endpoint && !options?.baseUrl) { - // Check localStorage as a fallback for endpoint - const localStorageEndpoint = localStorage.getItem('ai_chat_litellm_endpoint'); - if (!localStorageEndpoint) { - throw new Error('LiteLLM endpoint not configured. 
Please set endpoint in settings.'); - } - logger.debug(`Using endpoint from localStorage: ${localStorageEndpoint}`); - const baseUrl = localStorageEndpoint.replace(/\/$/, ''); - return `${baseUrl}${this.CHAT_COMPLETIONS_PATH}`; - } - - // If full endpoint is provided, check if it includes the chat completions path - if (options?.endpoint) { - // Check if the endpoint already includes the chat completions path - if (options.endpoint.includes('/v1/chat/completions')) { - return options.endpoint; - } - // If not, treat it as a base URL and append the path - const baseUrl = options.endpoint.replace(/\/$/, ''); - logger.debug(`Endpoint missing chat completions path, appending: ${baseUrl}${this.CHAT_COMPLETIONS_PATH}`); - return `${baseUrl}${this.CHAT_COMPLETIONS_PATH}`; - } - - // If base URL is provided, append the path - if (options?.baseUrl) { - // Remove trailing slash from base URL if present - const baseUrl = options.baseUrl.replace(/\/$/, ''); - return `${baseUrl}${this.CHAT_COMPLETIONS_PATH}`; - } - - // Default to local LiteLLM (should not reach here due to the check at the top) - return `${this.DEFAULT_BASE_URL}${this.CHAT_COMPLETIONS_PATH}`; - } - - /** - * Call the LiteLLM API with the provided parameters - */ - static async callLiteLLM( - apiKey: string | null, - modelName: string, - prompt: string, - options?: LiteLLMCallOptions - ): Promise { - logger.debug('Calling LiteLLM...', { model: modelName, prompt }); - - // Use standard OpenAI chat completions format - const messages = []; - - // Add system prompt if provided (backward compatibility) - if (options?.systemPrompt) { - messages.push({ - role: 'system', - content: options.systemPrompt - }); - } - - // Add user message - messages.push({ - role: 'user', - content: prompt - }); - - // Construct payload body in standard OpenAI format - const payloadBody: any = { - model: modelName, - messages, - }; - - // Add temperature if provided - if (options?.temperature !== undefined) { - payloadBody.temperature = options.temperature; - } - - // Add tools if provided - if (options?.tools) { - // Ensure all tools have valid parameters - payloadBody.tools = options.tools.map(tool => { - if (tool.type === 'function' && tool.function) { - return { - ...tool, - function: { - ...tool.function, - parameters: tool.function.parameters || { type: 'object', properties: {} } - } - }; - } - return tool; - }); - } - - // Add tool_choice if provided - if (options?.tool_choice) { - payloadBody.tool_choice = options.tool_choice; - } - - logger.info('Request payload:', payloadBody); - - try { - const endpoint = this.getEndpoint(options); - logger.debug('Using endpoint:', endpoint); - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? 
{ Authorization: `Bearer ${apiKey}` } : {}), - }, - body: JSON.stringify(payloadBody), - }); - - if (!response.ok) { - const errorData = await response.json(); - logger.error('LiteLLM API error:', errorData); - throw new Error(`LiteLLM API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); - } - - const data = await response.json(); - logger.info('LiteLLM Response:', data); - - if (data.usage) { - logger.info('LiteLLM Usage:', { inputTokens: data.usage.prompt_tokens, outputTokens: data.usage.completion_tokens }); - } - - // Process the response in standard OpenAI format - const result: LiteLLMResponse = { - rawResponse: data - }; - - if (!data?.choices || data.choices.length === 0) { - throw new Error('No choices in LiteLLM response'); - } - - const choice = data.choices[0]; - const message = choice.message; - - if (!message) { - throw new Error('No message in LiteLLM choice'); - } - - // Check for tool calls - if (message.tool_calls && message.tool_calls.length > 0) { - const toolCall = message.tool_calls[0]; - if (toolCall.function) { - try { - result.functionCall = { - name: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments) - }; - } catch (error) { - logger.error('Error parsing function arguments:', error); - result.functionCall = { - name: toolCall.function.name, - arguments: toolCall.function.arguments // Keep as string if parsing fails - }; - } - } - } else if (message.content) { - // Plain text response - result.text = message.content.trim(); - } - - return result; - } catch (error) { - logger.error('LiteLLM API request failed:', error); - throw error; - } - } - - /** - * Call LiteLLM API with OpenAI-compatible messages array (simplified approach) - */ - static async callLiteLLMWithMessages( - apiKey: string | null, - modelName: string, - messages: OpenAIMessage[], - options?: LiteLLMCallOptions - ): Promise { - logger.debug('Calling LiteLLM with messages...', { model: modelName, messageCount: messages.length }); - - // Construct payload body in OpenAI format (LiteLLM is OpenAI-compatible) - const payloadBody: any = { - model: modelName, - messages, // Direct OpenAI format - no conversion needed! - }; - - // Add temperature if provided - if (options?.temperature !== undefined) { - payloadBody.temperature = options.temperature; - } - - // Add tools if provided - if (options?.tools) { - payloadBody.tools = options.tools; - } - - // Add tool_choice if provided - if (options?.tool_choice) { - payloadBody.tool_choice = options.tool_choice; - } - - logger.info('Request payload:', payloadBody); - - try { - const endpoint = this.getEndpoint(options); - logger.debug('Using endpoint:', endpoint); - - const response = await fetch(endpoint, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(apiKey ? 
{ Authorization: `Bearer ${apiKey}` } : {}), - }, - body: JSON.stringify(payloadBody), - }); - - if (!response.ok) { - const errorData = await response.json(); - logger.error('LiteLLM API error:', errorData); - throw new Error(`LiteLLM API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); - } - - const data = await response.json(); - logger.info('LiteLLM Response:', data); - - if (data.usage) { - logger.info('LiteLLM Usage:', { inputTokens: data.usage.prompt_tokens, outputTokens: data.usage.completion_tokens }); - } - - // Process the response in standard OpenAI format (same as before) - const result: LiteLLMResponse = { - rawResponse: data - }; - - if (!data?.choices || data.choices.length === 0) { - throw new Error('No choices in LiteLLM response'); - } - - const choice = data.choices[0]; - const message = choice.message; - - if (!message) { - throw new Error('No message in LiteLLM choice'); - } - - // Check for tool calls - if (message.tool_calls && message.tool_calls.length > 0) { - const toolCall = message.tool_calls[0]; - if (toolCall.function) { - try { - result.functionCall = { - name: toolCall.function.name, - arguments: JSON.parse(toolCall.function.arguments) - }; - } catch (error) { - logger.error('Error parsing function arguments:', error); - result.functionCall = { - name: toolCall.function.name, - arguments: toolCall.function.arguments // Keep as string if parsing fails - }; - } - } - } else if (message.content) { - // Plain text response - result.text = message.content.trim(); - } - - return result; - } catch (error) { - logger.error('LiteLLM API request failed:', error); - throw error; - } - } - - /** - * Fetch available models from LiteLLM endpoint - */ - static async fetchModels(apiKey: string | null, baseUrl?: string): Promise { - logger.debug('Fetching available models...'); - - try { - // Construct models endpoint URL - const baseEndpoint = baseUrl || this.DEFAULT_BASE_URL; - const modelsUrl = `${baseEndpoint.replace(/\/$/, '')}${this.MODELS_PATH}`; - logger.debug('Using models endpoint:', modelsUrl); - - const response = await fetch(modelsUrl, { - method: 'GET', - headers: { - ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}), - }, - }); - - if (!response.ok) { - const errorData = await response.json(); - logger.error('LiteLLM models API error:', errorData); - throw new Error(`LiteLLM models API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); - } - - const data: LiteLLMModelsResponse = await response.json(); - logger.debug('LiteLLM Models Response:', data); - - if (!data?.data || !Array.isArray(data.data)) { - throw new Error('Invalid models response format'); - } - - return data.data; - } catch (error) { - logger.error('Failed to fetch LiteLLM models:', error); - throw error; - } - } - - /** - * Test the LiteLLM connection with a simple completion request - */ - static async testConnection(apiKey: string | null, modelName: string, baseUrl?: string): Promise<{success: boolean, message: string}> { - logger.debug('Testing connection...'); - - try { - const testPrompt = 'Please respond with "Connection successful!" 
to confirm the connection is working.'; - - const options: LiteLLMCallOptions = { - temperature: 0.1, - baseUrl, - }; - - const response = await this.callLiteLLM(apiKey, modelName, testPrompt, options); - - if (response.text?.toLowerCase().includes('connection')) { - return { - success: true, - message: `Successfully connected to LiteLLM with model ${modelName}`, - }; - } - return { - success: true, - message: `Connected to LiteLLM, but received unexpected response: ${response.text || 'No response'}`, - }; - } catch (error) { - logger.error('LiteLLM connection test failed:', error); - return { - success: false, - message: error instanceof Error ? error.message : 'Unknown error occurred', - }; - } - } - - /** - * Parses the raw LiteLLM response into a standardized action structure - */ - static parseLiteLLMResponse(response: LiteLLMResponse): ParsedLLMAction { - if (response.functionCall) { - return { - type: 'tool_call', - name: response.functionCall.name, - args: response.functionCall.arguments || {}, - }; - } if (response.text) { - const rawContent = response.text; - // Attempt to parse text as JSON tool call (fallback for some models) - if (rawContent.trim().startsWith('{') && rawContent.includes('"action":"tool"')) { // Heuristic - try { - const contentJson = JSON.parse(rawContent); - if (contentJson.action === 'tool' && contentJson.toolName) { - return { - type: 'tool_call', - name: contentJson.toolName, - args: contentJson.toolArgs || {}, - }; - } - // Fallback to treating it as text if JSON structure is not a valid tool call - return { type: 'final_answer', answer: rawContent }; - - } catch (e) { - // If JSON parsing fails, treat it as plain text - return { type: 'final_answer', answer: rawContent }; - } - } else { - // Treat as plain text final answer - return { type: 'final_answer', answer: rawContent }; - } - } else { - // No function call or text found - logger.error('LLM response had no function call or text.'); - return { type: 'error', error: 'LLM returned empty response.' }; - } - } -} diff --git a/front_end/panels/ai_chat/core/OpenAIClient.ts b/front_end/panels/ai_chat/core/OpenAIClient.ts deleted file mode 100644 index ddee4544722..00000000000 --- a/front_end/panels/ai_chat/core/OpenAIClient.ts +++ /dev/null @@ -1,371 +0,0 @@ -// Copyright 2025 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -import { createLogger } from './Logger.js'; - -const logger = createLogger('OpenAIClient'); - -/** - * OpenAI-compatible message format - */ -export interface OpenAIMessage { - role: 'system' | 'user' | 'assistant' | 'tool'; - content?: string | null; - tool_calls?: Array<{ - id: string; - type: 'function'; - function: { - name: string; - arguments: string; - }; - }>; - tool_call_id?: string; - name?: string; -} - -/** - * Types for OpenAI API request and response - */ -export interface OpenAICallOptions { - tools?: any[]; - tool_choice?: any; - systemPrompt?: string; // For backward compatibility with simple prompt methods - temperature?: number; - reasoningLevel?: 'low' | 'medium' | 'high'; -} - -export interface OpenAIResponse { - text?: string; - functionCall?: { - name: string, - arguments: any, - }; - rawResponse: any; - reasoning?: { - summary?: string[] | null, - effort?: string, - }; -} - -/** - * Standardized structure for parsed LLM action - */ -export type ParsedLLMAction = - | { type: 'tool_call', name: string, args: Record } - | { type: 'final_answer', answer: string } - | { type: 'error', error: string }; - -/** - * Enum to distinguish between model families with different request/response formats - */ -export enum ModelFamily { - GPT = 'gpt', - O = 'o' -} - -/** - * Responses API message format for tool calls and results - */ -interface ResponsesAPIFunctionCall { - type: 'function_call'; - name: string; - arguments: string; - call_id: string; -} - -interface ResponsesAPIFunctionOutput { - type: 'function_call_output'; - call_id: string; - output: string; -} - -/** - * OpenAIClient class for making requests to OpenAI Responses API - */ -export class OpenAIClient { - private static readonly API_ENDPOINT = 'https://api.openai.com/v1/responses'; - - /** - * Determines the model family based on the model name - */ - private static getModelFamily(modelName: string): ModelFamily { - // Check if model name starts with 'o' to identify O series models - if (modelName.startsWith('o')) { - return ModelFamily.O; - } - // Otherwise, assume it's a GPT model (gpt-3.5-turbo, gpt-4, etc.) 
- return ModelFamily.GPT; - } - - /** - * Converts tools from chat/completions format to responses API format - */ - private static convertToolsFormat(tools: any[]): any[] { - return tools.map(tool => { - if (tool.type === 'function' && tool.function) { - // Convert from chat/completions format to responses API format - return { - type: 'function', - name: tool.function.name, - description: tool.function.description, - parameters: tool.function.parameters || { type: 'object', properties: {} } - }; - } - return tool; // Return as-is if already in correct format - }); - } - - /** - * Converts messages to responses API format based on model family - */ - private static convertMessagesToResponsesAPI( - messages: OpenAIMessage[], - modelFamily: ModelFamily - ): any[] { - return messages.map(msg => { - if (msg.role === 'system') { - if (modelFamily === ModelFamily.O) { - return { - role: 'system', - content: [{ type: 'input_text', text: msg.content || '' }] - }; - } else { - return { - role: 'system', - content: msg.content || '' - }; - } - } else if (msg.role === 'user') { - if (modelFamily === ModelFamily.O) { - return { - role: 'user', - content: [{ type: 'input_text', text: msg.content || '' }] - }; - } else { - return { - role: 'user', - content: msg.content || '' - }; - } - } else if (msg.role === 'assistant') { - if (msg.tool_calls && msg.tool_calls.length > 0) { - // Convert tool calls to responses API format - const toolCall = msg.tool_calls[0]; // Take first tool call - let argsString: string; - - // Ensure arguments are in string format for responses API - if (typeof toolCall.function.arguments === 'string') { - argsString = toolCall.function.arguments; - } else { - argsString = JSON.stringify(toolCall.function.arguments); - } - - return { - type: 'function_call', - name: toolCall.function.name, - arguments: argsString, - call_id: toolCall.id - } as ResponsesAPIFunctionCall; - } else { - if (modelFamily === ModelFamily.O) { - return { - role: 'assistant', - content: [{ type: 'output_text', text: msg.content || '' }] - }; - } else { - return { - role: 'assistant', - content: msg.content || '' - }; - } - } - } else if (msg.role === 'tool') { - // Convert tool result to responses API format - return { - type: 'function_call_output', - call_id: msg.tool_call_id, - output: msg.content || '' - } as ResponsesAPIFunctionOutput; - } - return msg; - }); - } - - /** - * Processes the responses API output and extracts relevant information - */ - private static processResponsesAPIOutput(data: any): OpenAIResponse { - const result: OpenAIResponse = { - rawResponse: data - }; - - // Extract reasoning info if available (O models) - if (data.reasoning) { - result.reasoning = { - summary: data.reasoning.summary, - effort: data.reasoning.effort - }; - } - - if (!data?.output) { - throw new Error('No output from OpenAI'); - } - - if (data.output && data.output.length > 0) { - // Find function call or message by type instead of assuming position - const functionCallOutput = data.output.find((item: any) => item.type === 'function_call'); - const messageOutput = data.output.find((item: any) => item.type === 'message'); - - if (functionCallOutput) { - // Process function call - try { - result.functionCall = { - name: functionCallOutput.name, - arguments: JSON.parse(functionCallOutput.arguments) - }; - } catch (error) { - logger.error('Error parsing function arguments:', error); - result.functionCall = { - name: functionCallOutput.name, - arguments: functionCallOutput.arguments // Keep as string if parsing 
fails - }; - } - } - else if (messageOutput?.content && messageOutput.content.length > 0 && messageOutput.content[0].type === 'output_text') { - // Process text response - result.text = messageOutput.content[0].text.trim(); - } - } - - return result; - } - - /** - * Makes a request to the OpenAI Responses API - */ - private static async makeAPIRequest(apiKey: string, payloadBody: any): Promise { - try { - const response = await fetch(this.API_ENDPOINT, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, - }, - body: JSON.stringify(payloadBody), - }); - - if (!response.ok) { - const errorData = await response.json(); - logger.error('OpenAI API error:', errorData); - throw new Error(`OpenAI API error: ${response.statusText} - ${errorData?.error?.message || 'Unknown error'}`); - } - - const data = await response.json(); - logger.info('OpenAI Response:', data); - - if (data.usage) { - logger.info('OpenAI Usage:', { inputTokens: data.usage.input_tokens, outputTokens: data.usage.output_tokens }); - } - - return data; - } catch (error) { - logger.error('OpenAI API request failed:', error); - throw error; - } - } - - /** - * Main method to call OpenAI API with messages using the responses API - */ - static async callOpenAIWithMessages( - apiKey: string, - modelName: string, - messages: OpenAIMessage[], - options?: OpenAICallOptions - ): Promise { - logger.debug('Calling OpenAI responses API...', { model: modelName, messageCount: messages.length }); - - // Determine model family - const modelFamily = this.getModelFamily(modelName); - logger.debug('Model Family:', modelFamily); - - // Construct payload body for responses API format - const payloadBody: any = { - model: modelName, - }; - - // Convert messages to responses API format - const convertedMessages = this.convertMessagesToResponsesAPI(messages, modelFamily); - payloadBody.input = convertedMessages; - - // Add temperature if provided, but not for O models (they don't support it) - if (options?.temperature !== undefined && modelFamily !== ModelFamily.O) { - payloadBody.temperature = options.temperature; - } - - // Add tools if provided - convert from chat/completions format to responses API format - if (options?.tools) { - payloadBody.tools = this.convertToolsFormat(options.tools); - } - - // Add tool_choice if provided - if (options?.tool_choice) { - payloadBody.tool_choice = options.tool_choice; - } - - // Add reasoning level for O-series model if provided - if (options?.reasoningLevel && modelFamily === ModelFamily.O) { - payloadBody.reasoning = { - effort: options.reasoningLevel - }; - } - - logger.info('Request payload:', payloadBody); - - const data = await this.makeAPIRequest(apiKey, payloadBody); - return this.processResponsesAPIOutput(data); - } - - /** - * Parses the raw OpenAI response into a standardized action structure - */ - static parseOpenAIResponse(response: OpenAIResponse): ParsedLLMAction { - if (response.functionCall) { - return { - type: 'tool_call', - name: response.functionCall.name, - args: response.functionCall.arguments || {}, - }; - } - - if (response.text) { - const rawContent = response.text; - // Attempt to parse text as JSON tool call (fallback for some models) - if (rawContent.trim().startsWith('{') && rawContent.includes('"action":"tool"')) { - try { - const contentJson = JSON.parse(rawContent); - if (contentJson.action === 'tool' && contentJson.toolName) { - return { - type: 'tool_call', - name: contentJson.toolName, - args: contentJson.toolArgs || {}, 
- }; - } - // Fallback to treating it as text if JSON structure is not a valid tool call - return { type: 'final_answer', answer: rawContent }; - } catch (e) { - // If JSON parsing fails, treat it as plain text - return { type: 'final_answer', answer: rawContent }; - } - } else { - // Treat as plain text final answer - return { type: 'final_answer', answer: rawContent }; - } - } - - // No function call or text found - logger.error('LLM response had no function call or text.'); - return { type: 'error', error: 'LLM returned empty response.' }; - } -} \ No newline at end of file diff --git a/front_end/panels/ai_chat/core/UnifiedLLMClient.ts b/front_end/panels/ai_chat/core/UnifiedLLMClient.ts deleted file mode 100644 index 6fed96f2478..00000000000 --- a/front_end/panels/ai_chat/core/UnifiedLLMClient.ts +++ /dev/null @@ -1,588 +0,0 @@ -// Copyright 2025 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -import {LiteLLMClient, type LiteLLMResponse, type OpenAIMessage} from './LiteLLMClient.js'; -import {OpenAIClient, type OpenAIResponse} from './OpenAIClient.js'; -import { createLogger } from './Logger.js'; -import { ChatMessageEntity, type ChatMessage } from '../ui/ChatView.js'; - -const logger = createLogger('UnifiedLLMClient'); - -/** - * Error types that can occur during LLM calls - */ -export enum LLMErrorType { - JSON_PARSE_ERROR = 'JSON_PARSE_ERROR', - RATE_LIMIT_ERROR = 'RATE_LIMIT_ERROR', - NETWORK_ERROR = 'NETWORK_ERROR', - SERVER_ERROR = 'SERVER_ERROR', - AUTH_ERROR = 'AUTH_ERROR', - QUOTA_ERROR = 'QUOTA_ERROR', - UNKNOWN_ERROR = 'UNKNOWN_ERROR', -} - -/** - * Retry configuration for a specific error type - */ -export interface RetryConfig { - maxRetries?: number; - baseDelayMs?: number; - maxDelayMs?: number; - backoffMultiplier?: number; - jitterMs?: number; -} - -/** - * Default retry configuration - */ -const DEFAULT_RETRY_CONFIG: RetryConfig = { - maxRetries: 2, - baseDelayMs: 1000, - maxDelayMs: 10000, - backoffMultiplier: 2, - jitterMs: 500, -}; - -/** - * Error-specific retry configurations (only for specific error types) - */ -const ERROR_SPECIFIC_RETRY_CONFIGS: Partial> = { - [LLMErrorType.RATE_LIMIT_ERROR]: { - maxRetries: 3, - baseDelayMs: 60000, // 60 seconds for rate limits - maxDelayMs: 300000, // Max 5 minutes - backoffMultiplier: 1, // No exponential backoff for rate limits - jitterMs: 5000, // Small jitter to avoid thundering herd - }, - - [LLMErrorType.NETWORK_ERROR]: { - maxRetries: 3, - baseDelayMs: 2000, - maxDelayMs: 30000, - backoffMultiplier: 2, - jitterMs: 1000, - }, -}; - -/** - * Unified options for LLM calls that work across different providers - */ -export interface UnifiedLLMOptions { - endpoint?: string; - timeout?: number; - maxTokens?: number; - temperature?: number; - topP?: number; - frequencyPenalty?: number; - presencePenalty?: number; - responseFormat?: any; - n?: number; - stream?: boolean; - maxRetries?: number; - signal?: AbortSignal; - systemPrompt: string; // Made required - tools?: any[]; - tool_choice?: any; - strictJsonMode?: boolean; // New flag for strict JSON parsing - customRetryConfig?: RetryConfig; // Override default retry configuration -} - -/** - * Unified response that includes function calls - */ -export interface UnifiedLLMResponse { - text?: string; - functionCall?: { - name: string, - arguments: any, - }; - rawResponse?: any; - reasoning?: { - summary?: string[] | null, - effort?: string, - }; - parsedJson?: any; // Parsed 
JSON when strictJsonMode is enabled -} - -/** - * Model configuration from localStorage - */ -interface ModelOption { - value: string; - type: 'openai' | 'litellm'; - label?: string; -} - -/** - * UnifiedLLMClient provides a single interface for calling different LLM providers - * (OpenAI, LiteLLM) based on the model type configuration. - */ -export class UnifiedLLMClient { - private static readonly MODEL_OPTIONS_KEY = 'ai_chat_model_options'; - private static readonly LITELLM_ENDPOINT_KEY = 'ai_chat_litellm_endpoint'; - private static readonly LITELLM_API_KEY_KEY = 'ai_chat_litellm_api_key'; - - /** - * Main unified method to call any LLM based on model configuration - * Returns string for backward compatibility, or parsed JSON if strictJsonMode is enabled - */ - static async callLLM( - apiKey: string, - modelName: string, - userPrompt: string, - options: UnifiedLLMOptions - ): Promise { - // Convert simple prompt to message format - const messages = [{ - entity: ChatMessageEntity.USER as const, - text: userPrompt - }]; - - let systemPrompt = options.systemPrompt; - let enhancedOptions = options; - - // If strict JSON mode is enabled, enhance the system prompt and options - if (options.strictJsonMode) { - systemPrompt = `${systemPrompt}\n\nIMPORTANT: You must respond with valid JSON only. Do not include any text before or after the JSON object.`; - enhancedOptions = { - ...options, - responseFormat: { type: 'json_object' }, // Enable JSON mode for compatible models - systemPrompt - }; - } - - const response = await this.callLLMWithMessages(apiKey, modelName, messages, enhancedOptions); - - // If strict JSON mode is enabled, return parsed JSON or throw error - if (options.strictJsonMode) { - if (response.parsedJson) { - return response.parsedJson; - } - - // Fallback: try to parse the text response - if (response.text) { - try { - return JSON.parse(response.text); - } catch (parseError) { - throw new Error(`Invalid JSON response from LLM: ${parseError instanceof Error ? parseError.message : String(parseError)}. 
Response: ${response.text}`); - } - } - - throw new Error('No response from LLM for JSON parsing'); - } - - // Default behavior: return text - return response.text || ''; - } - - /** - * Call LLM and get full response including function calls using message array format - */ - static async callLLMWithMessages( - apiKey: string, - modelName: string, - messages: ChatMessage[], - options: UnifiedLLMOptions - ): Promise { - - const modelType = this.getModelType(modelName); - - logger.info('Calling LLM with messages:', { - modelName, - modelType, - messageCount: messages.length, - hasOptions: Boolean(options), - }); - - // Convert to OpenAI format with system prompt - const openaiMessages = this.convertToOpenAIMessages(messages, options.systemPrompt); - - logger.info(`Converted to OpenAI messages:\n${JSON.stringify(openaiMessages, null, 2)}`); - - try { - let response: any; - if (modelType === 'litellm') { - response = await this.callLiteLLMWithMessages(apiKey, modelName, openaiMessages, options); - } else { - response = await this.callOpenAIWithMessages(apiKey, modelName, openaiMessages, options); - } - - const result: UnifiedLLMResponse = { - text: response.text, - functionCall: response.functionCall, - rawResponse: response.rawResponse, - reasoning: response.reasoning || (response as any).reasoning, - }; - - // Handle strict JSON mode parsing - if (options.strictJsonMode && result.text) { - try { - result.parsedJson = this.parseStrictJSON(result.text); - } catch (parseError) { - logger.error('JSON parsing failed in strict mode:', { - error: parseError instanceof Error ? parseError.message : String(parseError), - responseText: result.text, - }); - // Don't throw here, let the caller handle it - } - } - - return result; - - } catch (error) { - logger.error('Error calling LLM with messages:', { - modelName, - modelType, - error: error instanceof Error ? error.message : String(error), - }); - throw error; - } - } - - - /** - * Parse strict JSON from LLM response, handling common formatting issues - */ - private static parseStrictJSON(text: string): any { - // Trim whitespace - let jsonText = text.trim(); - - // Remove markdown code blocks if present - if (jsonText.startsWith('```json')) { - jsonText = jsonText.replace(/^```json\s*/, '').replace(/\s*```$/, ''); - } else if (jsonText.startsWith('```')) { - jsonText = jsonText.replace(/^```\s*/, '').replace(/\s*```$/, ''); - } - - // Remove any leading/trailing text that's not part of JSON - const jsonMatch = jsonText.match(/\{.*\}/s) || jsonText.match(/\[.*\]/s); - if (jsonMatch) { - jsonText = jsonMatch[0]; - } - - // Try to parse - try { - return JSON.parse(jsonText); - } catch (error) { - // Log the problematic text for debugging - logger.error('Failed to parse JSON after cleanup:', { - original: text, - cleaned: jsonText, - error: error instanceof Error ? error.message : String(error), - }); - throw new Error(`Unable to parse JSON: ${error instanceof Error ? 
error.message : String(error)}`); - } - } - - /** - * Converts internal ChatMessage array to OpenAI-compatible messages array - */ - private static convertToOpenAIMessages( - messages: ChatMessage[], - systemPrompt: string - ): OpenAIMessage[] { - const result: OpenAIMessage[] = []; - - // Always add system prompt first - result.push({ - role: 'system', - content: systemPrompt - }); - - for (const msg of messages) { - switch (msg.entity) { - case ChatMessageEntity.USER: - result.push({ - role: 'user', - content: msg.text - }); - break; - - case ChatMessageEntity.MODEL: - if (msg.action === 'tool' && msg.toolName) { - result.push({ - role: 'assistant', - content: msg.reasoning ? msg.reasoning.join('\n') : null, - tool_calls: [{ - id: msg.toolCallId || crypto.randomUUID(), - type: 'function', - function: { - name: msg.toolName, - arguments: JSON.stringify(msg.toolArgs || {}) - } - }] - }); - } else if (msg.action === 'final' && msg.answer) { - result.push({ - role: 'assistant', - content: msg.answer - }); - } - break; - - case ChatMessageEntity.TOOL_RESULT: - if (msg.toolCallId) { - result.push({ - role: 'tool', - content: msg.resultText, - tool_call_id: msg.toolCallId, - name: msg.toolName - }); - } - break; - } - } - - return result; - } - - /** - * Determine the model type from localStorage configuration - */ - private static getModelType(modelName: string): 'openai' | 'litellm' { - try { - const modelOptions = JSON.parse(localStorage.getItem(this.MODEL_OPTIONS_KEY) || '[]') as ModelOption[]; - const modelOption = modelOptions.find(opt => opt.value === modelName); - return modelOption?.type || 'openai'; - } catch (e) { - logger.error('Error parsing model options:', e); - return 'openai'; - } - } - - /** - * Call OpenAI models with message array format - */ - private static async callOpenAIWithMessages( - apiKey: string, - modelName: string, - openaiMessages: OpenAIMessage[], - options?: UnifiedLLMOptions - ) { - try { - // Convert UnifiedLLMOptions to OpenAI-specific options (excluding tools and systemPrompt) - const openAIOptions = this.convertToOpenAIOptions(options); - - return await OpenAIClient.callOpenAIWithMessages( - apiKey, - modelName, - openaiMessages, - openAIOptions - ); - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error); - throw new Error(`OpenAI call failed for model ${modelName}: ${errorMessage}`); - } - } - - /** - * Call LiteLLM models with message array format - */ - private static async callLiteLLMWithMessages( - apiKey: string, - modelName: string, - openaiMessages: OpenAIMessage[], - options?: UnifiedLLMOptions - ) { - try { - const { endpoint, apiKey: liteLLMApiKey } = this.getLiteLLMConfig(); - - // Convert UnifiedLLMOptions to LiteLLM-specific options (excluding tools and systemPrompt) - const liteLLMOptions = this.convertToLiteLLMOptions(options, endpoint); - - return await LiteLLMClient.callLiteLLMWithMessages( - liteLLMApiKey || apiKey, - modelName, - openaiMessages, - liteLLMOptions - ); - } catch (error) { - const errorMessage = error instanceof Error ? 
error.message : String(error); - throw new Error(`LiteLLM call failed for model ${modelName}: ${errorMessage}`); - } - } - - /** - * Get LiteLLM configuration from localStorage - */ - private static getLiteLLMConfig(): { endpoint: string, apiKey: string } { - const endpoint = localStorage.getItem(this.LITELLM_ENDPOINT_KEY) || ''; - const apiKey = localStorage.getItem(this.LITELLM_API_KEY_KEY) || ''; - - logger.debug('LiteLLM config:', { - hasEndpoint: Boolean(endpoint), - hasApiKey: Boolean(apiKey), - endpointLength: endpoint.length, - }); - - if (!endpoint) { - throw new Error('LiteLLM endpoint not configured. Please configure in AI Chat settings.'); - } - - return { endpoint, apiKey }; - } - - /** - * Convert unified options to OpenAI-specific format - */ - private static convertToOpenAIOptions(options?: UnifiedLLMOptions): any { - if (!options) {return {};} - - return { - max_tokens: options.maxTokens, - temperature: options.temperature, - top_p: options.topP, - frequency_penalty: options.frequencyPenalty, - presence_penalty: options.presencePenalty, - response_format: options.responseFormat, - n: options.n, - stream: options.stream, - signal: options.signal, - systemPrompt: options.systemPrompt, - tools: options.tools, - tool_choice: options.tool_choice, - }; - } - - /** - * Convert unified options to LiteLLM-specific format - */ - private static convertToLiteLLMOptions(options?: UnifiedLLMOptions, endpoint?: string): any { - if (!options) {return { endpoint };} - - // Transform tools for LiteLLM/Anthropic format - let transformedTools = options.tools; - if (options.tools) { - transformedTools = options.tools.map(tool => { - // If the tool is already in the correct format with 'function' property, return as is - if ('function' in tool) { - return tool; - } - - // Transform OpenAI format to Anthropic format - // OpenAI: { type: 'function', name: '...', description: '...', parameters: {...} } - // Anthropic expects: { type: 'function', function: { name: '...', description: '...', parameters: {...} } } - if (tool.type === 'function') { - return { - type: 'function', - function: { - name: tool.name, - description: tool.description, - parameters: tool.parameters - } - }; - } - - // Default: return as is if we don't recognize the format - return tool; - }); - } - - return { - endpoint: options.endpoint || endpoint, - max_tokens: options.maxTokens, - temperature: options.temperature, - top_p: options.topP, - frequency_penalty: options.frequencyPenalty, - presence_penalty: options.presencePenalty, - response_format: options.responseFormat, - n: options.n, - stream: options.stream, - signal: options.signal, - systemPrompt: options.systemPrompt, - tools: transformedTools, - tool_choice: options.tool_choice, - }; - } - - /** - * Test if a model is available and working - */ - static async testModel( - modelName: string, - apiKey?: string - ): Promise<{ success: boolean, error?: string }> { - try { - await this.callLLM( - apiKey || '', - modelName, - 'Hello, this is a test message. Please respond with "OK".', - { - systemPrompt: 'You are a helpful AI assistant for testing purposes.', - maxTokens: 5 - } - ); - return { success: true }; - } catch (error) { - return { - success: false, - error: error instanceof Error ? 
error.message : 'Unknown error occurred', - }; - } - } - - /** - * Get all configured models from localStorage - */ - static getConfiguredModels(): ModelOption[] { - try { - return JSON.parse(localStorage.getItem(this.MODEL_OPTIONS_KEY) || '[]') as ModelOption[]; - } catch (e) { - logger.error('Error parsing model options:', e); - return []; - } - } - - /** - * Parse unified response to determine action type - * Equivalent to OpenAIClient.parseOpenAIResponse - */ - static parseResponse(response: UnifiedLLMResponse): ParsedLLMAction { - if (response.functionCall) { - return { - type: 'tool_call', - name: response.functionCall.name, - args: response.functionCall.arguments, - }; - } - - if (response.text) { - const rawContent = response.text; - // Attempt to parse text as JSON tool call (fallback for some models) - if (rawContent.trim().startsWith('{') && rawContent.includes('"action":"tool"')) { - try { - const contentJson = JSON.parse(rawContent); - if (contentJson.action === 'tool' && contentJson.toolName) { - return { - type: 'tool_call', - name: contentJson.toolName, - args: contentJson.toolArgs || {}, - }; - } - // Fallback to treating it as text if JSON structure is not a valid tool call - return { type: 'final_answer', answer: rawContent }; - } catch (e) { - // If JSON parsing fails, treat it as plain text - return { type: 'final_answer', answer: rawContent }; - } - } else { - // Treat as plain text final answer - return { type: 'final_answer', answer: rawContent }; - } - } - - return { - type: 'error', - error: 'No valid response from LLM', - }; - } -} - -/** - * Standardized structure for parsed LLM action - */ -export type ParsedLLMAction = - | { type: 'tool_call', name: string, args: Record } - | { type: 'final_answer', answer: string } - | { type: 'error', error: string }; diff --git a/front_end/panels/ai_chat/evaluation/framework/GenericToolEvaluator.ts b/front_end/panels/ai_chat/evaluation/framework/GenericToolEvaluator.ts index 616fc5aa416..59cfe03794a 100644 --- a/front_end/panels/ai_chat/evaluation/framework/GenericToolEvaluator.ts +++ b/front_end/panels/ai_chat/evaluation/framework/GenericToolEvaluator.ts @@ -78,8 +78,6 @@ export class GenericToolEvaluator { `GenericToolEvaluator.toolExecution:${testCase.tool}` ); - // 3. Store the raw tool response for debugging - const rawResponse = toolResult; // Call afterToolExecution hook if (this.hooks?.afterToolExecution) { @@ -87,15 +85,14 @@ export class GenericToolEvaluator { await this.hooks.afterToolExecution(testCase, tool, toolResult); } - // 4. Extract success/failure and output + // 3. Extract success/failure and error from tool result const success = this.isSuccessfulResult(toolResult); - const output = this.extractOutput(toolResult); const error = this.extractError(toolResult); const result: TestResult = { testId: testCase.id, status: success ? 'passed' : 'failed', - output, + output: toolResult, // Use raw tool result directly error: error ? ErrorHandlingUtils.formatUserFriendlyError(error, undefined) : undefined, duration: Date.now() - startTime, timestamp: Date.now(), @@ -105,8 +102,8 @@ export class GenericToolEvaluator { ? 
`Successfully executed ${testCase.tool}` : `${testCase.tool} execution failed: ${error}`, }, - // Add raw response for debugging - rawResponse, + // Store full tool response for debugging and display + rawResponse: toolResult, }; // Call beforeEvaluation hook @@ -214,49 +211,6 @@ export class GenericToolEvaluator { return true; } - /** - * Extract the meaningful output from any tool result - */ - private extractOutput(result: unknown): unknown { - if (typeof result === 'object' && result !== null) { - // Common output patterns - if ('data' in result) return this.sanitizeOutputIfNeeded(result.data); - if ('output' in result) return this.sanitizeOutputIfNeeded(result.output); - if ('result' in result) return this.sanitizeOutputIfNeeded(result.result); - if ('value' in result) return this.sanitizeOutputIfNeeded(result.value); - - // For tools that return success + other fields - if ('success' in result) { - const resultObj = result as Record; - const { success, error, ...output } = resultObj; - return this.sanitizeOutputIfNeeded(output); - } - } - return this.sanitizeOutputIfNeeded(result); - } - - /** - * Sanitize output data if it contains URLs or dynamic content - */ - private sanitizeOutputIfNeeded(output: unknown): unknown { - if (typeof output === 'string' && this.looksLikeUrl(output)) { - return SanitizationUtils.sanitizeUrl(output); - } - - if (typeof output === 'object' && output !== null) { - // Deep clone and sanitize - return SanitizationUtils.sanitizeOutput(output); - } - - return output; - } - - /** - * Check if a string looks like a URL - */ - private looksLikeUrl(str: string): boolean { - return str.startsWith('http://') || str.startsWith('https://') || str.includes('://'); - } /** * Extract error message from any tool result diff --git a/front_end/panels/ai_chat/evaluation/framework/judges/LLMEvaluator.ts b/front_end/panels/ai_chat/evaluation/framework/judges/LLMEvaluator.ts index 640c1fd45fb..91a2ed2562b 100644 --- a/front_end/panels/ai_chat/evaluation/framework/judges/LLMEvaluator.ts +++ b/front_end/panels/ai_chat/evaluation/framework/judges/LLMEvaluator.ts @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import { UnifiedLLMClient } from '../../../core/UnifiedLLMClient.js'; +import { LLMClient } from '../../../LLM/LLMClient.js'; import type { TestCase, LLMJudgeResult, ValidationConfig } from '../types.js'; import { createLogger } from '../../../core/Logger.js'; import { ErrorHandlingUtils } from '../../utils/ErrorHandlingUtils.js'; @@ -25,6 +25,15 @@ export class LLMEvaluator { this.defaultModel = defaultModel; } + /** + * Helper function to detect provider from user's settings + */ + private detectProvider(modelName: string): 'openai' | 'litellm' { + // Respect user's provider selection from settings + const selectedProvider = localStorage.getItem('ai_chat_provider') || 'openai'; + return selectedProvider as 'openai' | 'litellm'; + } + /** * Evaluate tool output using an LLM judge (supports both text and vision) */ @@ -84,16 +93,18 @@ export class LLMEvaluator { for (let attempt = 1; attempt <= maxRetries; attempt++) { try { - const response = await UnifiedLLMClient.callLLM( - this.apiKey, - model, - prompt, - { - systemPrompt: PromptTemplates.buildSystemPrompt({ hasVision: false }), - temperature: llmConfig.temperature ?? 
0, - responseFormat: { type: 'json_object' }, - } - ); + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(model), + model: model, + messages: [ + { role: 'system', content: PromptTemplates.buildSystemPrompt({ hasVision: false }) }, + { role: 'user', content: prompt } + ], + systemPrompt: PromptTemplates.buildSystemPrompt({ hasVision: false }), + temperature: llmConfig.temperature ?? 0 + }); + const response = llmResponse.text || ''; // Clean response before parsing const cleanedResponse = ResponseParsingUtils.cleanResponseText(response); diff --git a/front_end/panels/ai_chat/tools/CritiqueTool.ts b/front_end/panels/ai_chat/tools/CritiqueTool.ts index d64570a597e..685f98a6334 100644 --- a/front_end/panels/ai_chat/tools/CritiqueTool.ts +++ b/front_end/panels/ai_chat/tools/CritiqueTool.ts @@ -4,7 +4,8 @@ import { AgentService } from '../core/AgentService.js'; import { createLogger } from '../core/Logger.js'; -import { UnifiedLLMClient } from '../core/UnifiedLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; +import type { LLMProvider } from '../LLM/LLMTypes.js'; import { AIChatPanel } from '../ui/AIChatPanel.js'; import type { Tool } from './Tools.js'; @@ -52,6 +53,21 @@ export class CritiqueTool implements Tool name = 'critique_tool'; description = 'Evaluates if finalresponse satisfies the user\'s requirements and provides feedback if needed.'; + /** + * Helper method to detect provider from model name + */ + private detectProvider(modelName: string): LLMProvider { + // OpenAI patterns + if (modelName.startsWith('gpt-') || + modelName.startsWith('o1-') || + modelName.startsWith('o4-')) { + return 'openai'; + } + + // Everything else goes to LiteLLM + return 'litellm'; + } + schema = { type: 'object', properties: { @@ -166,19 +182,25 @@ Return a JSON array of requirement statements. 
Example format: try { const modelName = AIChatPanel.getMiniModel(); - const response = await UnifiedLLMClient.callLLM( - apiKey, - modelName, - userPrompt, - { systemPrompt, temperature: 0.1 } - ); + const llm = LLMClient.getInstance(); + const provider = this.detectProvider(modelName); + + const response = await llm.call({ + provider, + model: modelName, + messages: [ + { role: 'user', content: userPrompt } + ], + systemPrompt, + temperature: 0.1, + }); - if (!response) { + if (!response.text) { return { success: false, requirements: [], error: 'No response received' }; } // Parse the JSON array from the response - const requirementsMatch = response.match(/\[(.*)\]/s); + const requirementsMatch = response.text.match(/\[(.*)\]/s); if (!requirementsMatch) { return { success: false, requirements: [], error: 'Failed to parse requirements' }; } @@ -254,19 +276,25 @@ ${JSON.stringify(evaluationSchema, null, 2)}`; try { const modelName = AIChatPanel.getMiniModel(); - const response = await UnifiedLLMClient.callLLM( - apiKey, - modelName, - userPrompt, - { systemPrompt, temperature: 0.1 } - ); + const llm = LLMClient.getInstance(); + const provider = this.detectProvider(modelName); + + const response = await llm.call({ + provider, + model: modelName, + messages: [ + { role: 'user', content: userPrompt } + ], + systemPrompt, + temperature: 0.1, + }); - if (!response) { + if (!response.text) { return { success: false, error: 'No response received' }; } // Extract JSON object from the response - const jsonMatch = response.match(/\{[\s\S]*\}/); + const jsonMatch = response.text.match(/\{[\s\S]*\}/); if (!jsonMatch) { return { success: false, error: 'Failed to parse evaluation criteria' }; } @@ -309,14 +337,20 @@ Be concise, specific, and constructive.`; try { const modelName = AIChatPanel.getMiniModel(); - const response = await UnifiedLLMClient.callLLM( - apiKey, - modelName, - userPrompt, - { systemPrompt, temperature: 0.7 } - ); + const llm = LLMClient.getInstance(); + const provider = this.detectProvider(modelName); + + const response = await llm.call({ + provider, + model: modelName, + messages: [ + { role: 'user', content: userPrompt } + ], + systemPrompt, + temperature: 0.7, + }); - return response || 'The plan does not meet all requirements, but no specific feedback could be generated.'; + return response.text || 'The plan does not meet all requirements, but no specific feedback could be generated.'; } catch (error: any) { logger.error('Error generating feedback', error); return 'Failed to generate detailed feedback, but the plan does not meet all requirements.'; diff --git a/front_end/panels/ai_chat/tools/FullPageAccessibilityTreeToMarkdownTool.ts b/front_end/panels/ai_chat/tools/FullPageAccessibilityTreeToMarkdownTool.ts index 7df9d6c1aa1..0748a732106 100644 --- a/front_end/panels/ai_chat/tools/FullPageAccessibilityTreeToMarkdownTool.ts +++ b/front_end/panels/ai_chat/tools/FullPageAccessibilityTreeToMarkdownTool.ts @@ -3,7 +3,7 @@ // found in the LICENSE file. 
import { AgentService } from '../core/AgentService.js'; -import { UnifiedLLMClient } from '../core/UnifiedLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; import { AIChatPanel } from '../ui/AIChatPanel.js'; import { GetAccessibilityTreeTool, type Tool, type ErrorResult } from './Tools.js'; @@ -25,6 +25,15 @@ export class FullPageAccessibilityTreeToMarkdownTool implements Tool { // Call LLM using the unified client - const response = await UnifiedLLMClient.callLLM( - params.apiKey, - AIChatPanel.getNanoModel(), - params.userPrompt, - { - systemPrompt: params.systemPrompt, - temperature: 0.2, // Lower temperature for more deterministic results - } - ); + const modelName = AIChatPanel.getNanoModel(); + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(modelName), + model: modelName, + messages: [ + { role: 'system', content: params.systemPrompt }, + { role: 'user', content: params.userPrompt } + ], + systemPrompt: params.systemPrompt, + temperature: 0.2 // Lower temperature for more deterministic results + }); + const response = llmResponse.text; // Process the response - UnifiedLLMClient returns string directly const markdownContent = response || ''; diff --git a/front_end/panels/ai_chat/tools/SchemaBasedExtractorTool.ts b/front_end/panels/ai_chat/tools/SchemaBasedExtractorTool.ts index dfb72b4cd47..4f6e561242c 100644 --- a/front_end/panels/ai_chat/tools/SchemaBasedExtractorTool.ts +++ b/front_end/panels/ai_chat/tools/SchemaBasedExtractorTool.ts @@ -7,7 +7,7 @@ import * as Protocol from '../../../generated/protocol.js'; import * as Utils from '../common/utils.js'; import { AgentService } from '../core/AgentService.js'; import { createLogger } from '../core/Logger.js'; -import { UnifiedLLMClient } from '../core/UnifiedLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; import { AIChatPanel } from '../ui/AIChatPanel.js'; import { NodeIDsToURLsTool, type Tool } from './Tools.js'; @@ -63,6 +63,15 @@ export class SchemaBasedExtractorTool implements Tool { try { const context = await this.setupExecution(args); @@ -237,12 +246,18 @@ IMPORTANT: Only extract data that you can see in the accessibility tree above. 
D } const modelName = AIChatPanel.getMiniModel(); - const result = await UnifiedLLMClient.callLLM( - apiKey, - modelName, - extractionPrompt, - { systemPrompt, temperature: 0.1, strictJsonMode: true } - ); + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(modelName), + model: modelName, + messages: [ + { role: 'system', content: systemPrompt }, + { role: 'user', content: extractionPrompt } + ], + systemPrompt: systemPrompt, + temperature: 0.1 + }); + const result = llmResponse.text; logger.debug(`JSON extraction successful on attempt ${attempt}`); return result; @@ -357,12 +372,18 @@ CRITICAL: Only use nodeIds that you can actually see in the accessibility tree a try { const modelName = AIChatPanel.getMiniModel(); - const result = await UnifiedLLMClient.callLLM( - apiKey, - modelName, - extractionPrompt, - { systemPrompt, temperature: 0.1, strictJsonMode: true } - ); + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(modelName), + model: modelName, + messages: [ + { role: 'system', content: systemPrompt }, + { role: 'user', content: extractionPrompt } + ], + systemPrompt: systemPrompt, + temperature: 0.1 + }); + const result = llmResponse.text; return result; } catch (error) { diff --git a/front_end/panels/ai_chat/tools/Tools.ts b/front_end/panels/ai_chat/tools/Tools.ts index 1cdbb893eb9..35914cdfc23 100644 --- a/front_end/panels/ai_chat/tools/Tools.ts +++ b/front_end/panels/ai_chat/tools/Tools.ts @@ -16,9 +16,8 @@ import type { LogLine } from '../common/log.js'; import * as Utils from '../common/utils.js'; import { getXPathByBackendNodeId } from '../common/utils.js'; import { AgentService } from '../core/AgentService.js'; -import { OpenAIClient } from '../core/OpenAIClient.js'; import type { DevToolsContext } from '../core/State.js'; -import { UnifiedLLMClient } from '../core/UnifiedLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; import { AIChatPanel } from '../ui/AIChatPanel.js'; import { ChatMessageEntity } from '../ui/ChatView.js'; @@ -1828,6 +1827,15 @@ Important guidelines: - Choose the most semantically appropriate element when multiple options exist.`; } + /** + * Helper function to detect provider from user's settings + */ + private detectProvider(modelName: string): 'openai' | 'litellm' { + // Respect user's provider selection from settings + const selectedProvider = localStorage.getItem('ai_chat_provider') || 'openai'; + return selectedProvider as 'openai' | 'litellm'; + } + async execute(args: { objective: string, offset?: number, chunkSize?: number, maxRetries?: number }): Promise { const { objective, offset = 0, chunkSize = 60000, maxRetries = 1 } = args; // Default offset 0, chunkSize 60000, maxRetries 1 let currentTry = 0; @@ -1891,26 +1899,32 @@ Important guidelines: - Prefer the most direct path to accomplishing the objective. 
- Choose the most semantically appropriate element when multiple options exist.`; - // Use UnifiedLLMClient with function call support - const messages = [{ - entity: ChatMessageEntity.USER as const, - text: promptGetAction - }]; - - const response = await UnifiedLLMClient.callLLMWithMessages( - apiKey, - modelNameForAction, - messages, - { - systemPrompt: this.getSystemPrompt(), - tools: [{ - type: 'function', + // Use LLMClient with function call support + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(modelNameForAction), + model: modelNameForAction, + messages: [ + { role: 'system', content: this.getSystemPrompt() }, + { role: 'user', content: promptGetAction } + ], + systemPrompt: this.getSystemPrompt(), + tools: [{ + type: 'function', + function: { name: performActionTool.name, description: performActionTool.description, parameters: performActionTool.schema - }], - }, - ); + } + }], + temperature: 0.4 + }); + + // Convert LLMResponse to expected format + const response = { + text: llmResponse.text, + functionCall: llmResponse.functionCall + }; // --- Parse the Tool Call Response --- if (!response.functionCall || response.functionCall.name !== performActionTool.name) { @@ -2704,6 +2718,15 @@ CRITICAL: return value === undefined || value === null || value === ''; } + /** + * Helper function to detect provider from user's settings + */ + private detectProvider(modelName: string): 'openai' | 'litellm' { + // Respect user's provider selection from settings + const selectedProvider = localStorage.getItem('ai_chat_provider') || 'openai'; + return selectedProvider as 'openai' | 'litellm'; + } + async execute(args: { objective: string, schema: Record, offset?: number, chunkSize?: number, maxRetries?: number }): Promise { const { objective, schema, offset = 0, chunkSize = 60000, maxRetries = 1 } = args; // Default offset 0, chunkSize 60000, maxRetries 1 let currentTry = 0; @@ -2760,15 +2783,19 @@ ${lastError ? `Previous attempt failed with this error: "${lastError}". 
Consider Extract NodeIDs according to the provided objective and schema, then return a structured JSON with NodeIDs instead of content.`; logger.info('SchemaBasedDataExtractionTool: Prompt:', promptExtractData); - // Use UnifiedLLMClient to call the LLM - const response = await UnifiedLLMClient.callLLM( - apiKey, - modelNameForExtraction, - promptExtractData, - { - systemPrompt: this.getSystemPrompt(), - }, - ); + // Use LLMClient to call the LLM + const llm = LLMClient.getInstance(); + const llmResponse = await llm.call({ + provider: this.detectProvider(modelNameForExtraction), + model: modelNameForExtraction, + messages: [ + { role: 'system', content: this.getSystemPrompt() }, + { role: 'user', content: promptExtractData } + ], + systemPrompt: this.getSystemPrompt(), + temperature: 0.7 + }); + const response = llmResponse.text; logger.info('SchemaBasedDataExtractionTool: Response:', response); // Process the LLM response - this now contains NodeIDs instead of content diff --git a/front_end/panels/ai_chat/ui/AIChatPanel.ts b/front_end/panels/ai_chat/ui/AIChatPanel.ts index a341277ba9f..037e13212cf 100644 --- a/front_end/panels/ai_chat/ui/AIChatPanel.ts +++ b/front_end/panels/ai_chat/ui/AIChatPanel.ts @@ -10,7 +10,7 @@ import * as UI from '../../../ui/legacy/legacy.js'; import * as Lit from '../../../ui/lit/lit.js'; import * as VisualLogging from '../../../ui/visual_logging/visual_logging.js'; import {AgentService, Events as AgentEvents} from '../core/AgentService.js'; -import { LiteLLMClient } from '../core/LiteLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; import { createLogger } from '../core/Logger.js'; const logger = createLogger('AIChatPanel'); @@ -603,7 +603,7 @@ export class AIChatPanel extends UI.Panel.Panel { } // Always fetch fresh models from LiteLLM - const models = await LiteLLMClient.fetchModels(apiKey, endpoint); + const models = await LLMClient.fetchLiteLLMModels(apiKey, endpoint); // Check if wildcard model exists const hadWildcard = models.some(model => model.id === '*'); diff --git a/front_end/panels/ai_chat/ui/SettingsDialog.ts b/front_end/panels/ai_chat/ui/SettingsDialog.ts index 35696068742..5f92ec88363 100644 --- a/front_end/panels/ai_chat/ui/SettingsDialog.ts +++ b/front_end/panels/ai_chat/ui/SettingsDialog.ts @@ -4,7 +4,7 @@ import * as i18n from '../../../core/i18n/i18n.js'; import * as UI from '../../../ui/legacy/legacy.js'; -import { LiteLLMClient } from '../core/LiteLLMClient.js'; +import { LLMClient } from '../LLM/LLMClient.js'; import { createLogger } from '../core/Logger.js'; const logger = createLogger('SettingsDialog'); @@ -740,7 +740,7 @@ export class SettingsDialog { throw new Error(i18nString(UIStrings.endpointRequired)); } - const result = await LiteLLMClient.testConnection(liteLLMApiKey, model, endpoint); + const result = await LLMClient.testLiteLLMConnection(liteLLMApiKey, model, endpoint); if (result.success) { testStatus.textContent = '✓'; @@ -831,7 +831,7 @@ export class SettingsDialog { throw new Error(i18nString(UIStrings.endpointRequired)); } - const result = await LiteLLMClient.testConnection(liteLLMApiKey, modelName, endpoint); + const result = await LLMClient.testLiteLLMConnection(liteLLMApiKey, modelName, endpoint); if (result.success) { modelTestStatus.textContent = `Test passed: ${result.message}`;
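// --- Usage sketch: the call pattern the migrated sites above converge on ----
// A minimal sketch assembled only from what this diff itself shows:
// LLMClient.getInstance(), call() taking provider/model/messages/systemPrompt/
// temperature/tools, and the text/functionCall fields on its response
// (see the LLMEvaluator, CritiqueTool, and Tools.ts hunks). The helper names
// below are illustrative; any option or behavior not visible in the diff is an
// assumption, not a documented part of the new API.

import { LLMClient } from '../LLM/LLMClient.js';

// Mirrors the detectProvider() helpers added at the migrated call sites:
// the user's provider choice from settings wins, defaulting to OpenAI.
function detectProvider(): 'openai' | 'litellm' {
  return (localStorage.getItem('ai_chat_provider') || 'openai') as 'openai' | 'litellm';
}

// Plain-text completion, as used by CritiqueTool and the markdown tools.
async function askForText(model: string, systemPrompt: string, userPrompt: string): Promise<string> {
  const llm = LLMClient.getInstance();
  const response = await llm.call({
    provider: detectProvider(),
    model,
    messages: [{ role: 'user', content: userPrompt }],
    systemPrompt,
    temperature: 0.1,
  });
  // Plain answers arrive on `text`; tool invocations on `functionCall`.
  return response.text ?? '';
}

// Tool-calling completion, following the action-agent hunk in Tools.ts,
// where the tool schema is wrapped in the chat/completions-style
// { type: 'function', function: {...} } envelope.
async function askForToolCall(
  model: string,
  systemPrompt: string,
  userPrompt: string,
  tool: { name: string, description: string, schema: object },
) {
  const llm = LLMClient.getInstance();
  const response = await llm.call({
    provider: detectProvider(),
    model,
    messages: [
      { role: 'system', content: systemPrompt },
      { role: 'user', content: userPrompt },
    ],
    systemPrompt,
    tools: [{
      type: 'function',
      function: { name: tool.name, description: tool.description, parameters: tool.schema },
    }],
    temperature: 0.4,
  });
  // The caller checks functionCall.name against the expected tool, as Tools.ts does.
  return response.functionCall;
}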