From e91d747692dc397c64b5781afce850ae87428e69 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 26 Aug 2024 21:07:48 +0800 Subject: [PATCH] check nan and inf --- langport/model/executor/generation/chatgpt.py | 8 +++++--- langport/model/executor/generation/huggingface.py | 5 +++++ requirements.txt | 4 ++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/langport/model/executor/generation/chatgpt.py b/langport/model/executor/generation/chatgpt.py index e5b5d97..9942cb3 100644 --- a/langport/model/executor/generation/chatgpt.py +++ b/langport/model/executor/generation/chatgpt.py @@ -53,8 +53,10 @@ def __init__( api_key=api_key, ) - openai.api_base = api_url - openai.api_key = api_key + self.client = openai.OpenAI( + base_url = api_url, + api_key=api_key, + ) self._context_len = 2048 @@ -92,7 +94,7 @@ def inference(self, worker: "GenerationModelWorker"): "role": role, "content": section }) - response = openai.ChatCompletion.create( + response = self.client.chat.completions.create( model=self.model_name, messages=messages, stream=True, diff --git a/langport/model/executor/generation/huggingface.py b/langport/model/executor/generation/huggingface.py index c7b059e..9f36246 100644 --- a/langport/model/executor/generation/huggingface.py +++ b/langport/model/executor/generation/huggingface.py @@ -307,6 +307,11 @@ def generate(self, inputs: BatchingTask, if task.temperature < 1e-5 or task.top_p < 1e-8: # greedy token = int(torch.argmax(last_token_logits)) else: + # 检测 nan 和 inf,并将它们替换为 0 + last_token_logits[torch.isnan(last_token_logits)] = 0 + last_token_logits[torch.isinf(last_token_logits)] = 0 + # 对 logits 进行处理,确保没有 inf、nan 或小于 0 的值 + last_token_logits = torch.clamp(last_token_logits, min=0) # 将所有小于 0 的值设为 0 probs = torch.softmax(last_token_logits, dim=-1) sampled_tensor = torch.multinomial(probs, num_samples=2, replacement=False) token = int(sampled_tensor[0].item()) diff --git a/requirements.txt b/requirements.txt index c694ad6..f84db93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,8 @@ fastapi==0.111.0 httpx==0.24.0 numpy==1.24.3 psutil==5.9.5 -pydantic==2.7.4 -pydantic-settings==2.3.4 +pydantic==2.8.2 +pydantic-settings==2.4.0 requests==2.30.0 shortuuid==1.0.11 tenacity==8.2.2