From 365309ac840612b077f5603a91a9e86bb5ee32b9 Mon Sep 17 00:00:00 2001 From: yym68686 Date: Fri, 7 Feb 2025 17:37:29 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Enhance=20timeout=20calculation=20f?= =?UTF-8?q?or=20multi-provider=20requests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 292996c..b663e11 100644 --- a/main.py +++ b/main.py @@ -820,7 +820,7 @@ def get_timeout_value(provider_timeouts, original_model): return timeout_value # 在 process_request 函数中更新成功和失败计数 -async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, role=None): +async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, role=None, num_matching_providers=1): url = provider['base_url'] parsed_url = urlparse(url) # print("parsed_url", parsed_url) @@ -905,7 +905,8 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A timeout_value = get_timeout_value(app.state.provider_timeouts["global_time_out"], original_model) if timeout_value is None: timeout_value = app.state.timeouts.get("default", DEFAULT_TIMEOUT) - # print("timeout_value", timeout_value) + timeout_value = timeout_value * num_matching_providers + # print("timeout_value", channel_id, timeout_value) proxy = safe_get(provider, "preferences", "proxy", default=None) # print("proxy", proxy) @@ -1187,8 +1188,17 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques current_index = (start_index + index) % num_matching_providers index += 1 provider = matching_providers[current_index] + + if provider['provider'].startswith("sk-") and provider['provider'] in app.state.api_list: + local_provider_api_index = app.state.api_list.index(provider['provider']) + local_provider_scheduling_algorithm = safe_get(config, 'api_keys', local_provider_api_index, "preferences", "SCHEDULING_ALGORITHM", default="fixed_priority") + local_provider_matching_providers = await get_right_order_providers(request_model, config, local_provider_api_index, local_provider_scheduling_algorithm) + local_provider_num_matching_providers = len(local_provider_matching_providers) + else: + local_provider_num_matching_providers = 1 + try: - response = await process_request(request, provider, endpoint, role) + response = await process_request(request, provider, endpoint, role, local_provider_num_matching_providers) return response except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectError) as e: