Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich committed Feb 28, 2024
1 parent 4a463ca commit f6bb399
Showing 1 changed file with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@


class NemoEmbedding(BaseEmbedding):
"""Nvidia NeMo embeddings.
"""
"""Nvidia NeMo embeddings."""

def __init__(
self,
Expand All @@ -37,24 +36,21 @@ def class_name(cls) -> str:
return "NemoEmbedding"

def _get_embedding(self, text: str, input_type: str):
payload = json.dumps({
"input": text,
"model": self.model_name,
"input_type": input_type
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps(
{"input": text, "model": self.model_name, "input_type": input_type}
)
headers = {"Content-Type": "application/json"}

response = requests.request(
"POST", self.api_endpoint_url, headers=headers, data=payload)
"POST", self.api_endpoint_url, headers=headers, data=payload
)
response = json.loads(response.text)

return response["data"][0]["embedding"]

def _get_query_embedding(self, query: str) -> List[float]:
return self._get_embedding(text, input_type="query")

def _get_text_embedding(self, text: str) -> List[float]:
return self._get_embedding(text, input_type="passage")

Expand Down

0 comments on commit f6bb399

Please sign in to comment.