From 125db534cb04b4e21453698d065a7fd55f610930 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 11 Jul 2024 19:39:02 +0800 Subject: [PATCH] fix: Fix http empty headers error (#1710) --- dbgpt/model/cluster/worker/remote_worker.py | 3 ++- dbgpt/rag/embedding/embeddings.py | 10 ++++++---- dbgpt/rag/embedding/rerank.py | 10 ++++++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/dbgpt/model/cluster/worker/remote_worker.py b/dbgpt/model/cluster/worker/remote_worker.py index 70cbbb02d..5025ca990 100644 --- a/dbgpt/model/cluster/worker/remote_worker.py +++ b/dbgpt/model/cluster/worker/remote_worker.py @@ -161,5 +161,6 @@ async def async_embeddings(self, params: Dict) -> List[List[float]]: def _get_trace_headers(self): span_id = root_tracer.get_current_span_id() headers = self.headers.copy() - headers.update({DBGPT_TRACER_SPAN_ID: span_id}) + if span_id: + headers.update({DBGPT_TRACER_SPAN_ID: span_id}) return headers diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index 97820cc5d..7d14c0fb5 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -693,9 +693,10 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: """ # Call OpenAI Embedding API headers = {} - if self.pass_trace_id: + current_span_id = root_tracer.get_current_span_id() + if self.pass_trace_id and current_span_id: # Set the trace ID if available - headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id() + headers[DBGPT_TRACER_SPAN_ID] = current_span_id res = self.session.post( # type: ignore self.api_url, json={"input": texts, "model": self.model_name}, @@ -726,9 +727,10 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]: List[float] corresponds to a single input text. """ headers = {"Authorization": f"Bearer {self.api_key}"} - if self.pass_trace_id: + current_span_id = root_tracer.get_current_span_id() + if self.pass_trace_id and current_span_id: # Set the trace ID if available - headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id() + headers[DBGPT_TRACER_SPAN_ID] = current_span_id async with aiohttp.ClientSession( headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) ) as session: diff --git a/dbgpt/rag/embedding/rerank.py b/dbgpt/rag/embedding/rerank.py index 0c8fb8009..b797d0fa7 100644 --- a/dbgpt/rag/embedding/rerank.py +++ b/dbgpt/rag/embedding/rerank.py @@ -117,9 +117,10 @@ def predict(self, query: str, candidates: List[str]) -> List[float]: if not candidates: return [] headers = {} - if self.pass_trace_id: + current_span_id = root_tracer.get_current_span_id() + if self.pass_trace_id and current_span_id: # Set the trace ID if available - headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id() + headers[DBGPT_TRACER_SPAN_ID] = current_span_id data = {"model": self.model_name, "query": query, "documents": candidates} response = self.session.post( # type: ignore self.api_url, json=data, timeout=self.timeout, headers=headers @@ -130,9 +131,10 @@ def predict(self, query: str, candidates: List[str]) -> List[float]: async def apredict(self, query: str, candidates: List[str]) -> List[float]: """Predict the rank scores of the candidates asynchronously.""" headers = {"Authorization": f"Bearer {self.api_key}"} - if self.pass_trace_id: + current_span_id = root_tracer.get_current_span_id() + if self.pass_trace_id and current_span_id: # Set the trace ID if available - headers[DBGPT_TRACER_SPAN_ID] = root_tracer.get_current_span_id() + headers[DBGPT_TRACER_SPAN_ID] = current_span_id async with aiohttp.ClientSession( headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) ) as session: