Skip to content

Commit

Permalink
Ollama Text Embedder with new format (#252)
Browse files Browse the repository at this point in the history
* add tests

* add ollama text embedder

* add init for text embedder

* format with black

* lint with ruff

* add meta to return message
  • Loading branch information
AlistairLR112 authored Feb 9, 2024
1 parent 9bd4417 commit a2e5e8f
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export OllamaTextEmbedder so it can be imported directly from this package.
from .text_embedder import OllamaTextEmbedder

# Names exported by `from <package> import *`.
__all__ = ["OllamaTextEmbedder"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import component


@component
class OllamaTextEmbedder:
    """
    Haystack component that turns a single text into an embedding by POSTing it to an Ollama embeddings endpoint.
    """

    def __init__(
        self,
        model: str = "orca-mini",
        url: str = "http://localhost:11434/api/embeddings",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        timeout: int = 120,
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
            Default is "orca-mini".
        :param url: The URL of the embeddings endpoint of a running Ollama instance.
            Default is "http://localhost:11434/api/embeddings".
        :param generation_kwargs: Optional arguments sent to Ollama under the "options" key of every request,
            such as temperature, top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :param timeout: The number of seconds before the HTTP request to the Ollama API times out.
            Default is 120 seconds.
        """
        self.model = model
        self.url = url
        # A missing (or empty) mapping is normalised to a fresh empty dict.
        self.generation_kwargs = generation_kwargs if generation_kwargs else {}
        self.timeout = timeout

    def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Build the JSON body for a POST request to the Ollama service.

        :param text: Text that is to be converted to an embedding.
        :param generation_kwargs: Per-call options; on key collisions they take precedence over the
            options given at construction time.
        :return: A dictionary with the keys "model", "prompt" and "options".
        """
        # Start from a copy of the instance-level options so they are never mutated.
        options = dict(self.generation_kwargs)
        if generation_kwargs:
            options.update(generation_kwargs)
        return {"model": self.model, "prompt": text, "options": options}

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Request an embedding for the given text from the configured Ollama endpoint.

        :param text: Text to be converted to an embedding.
        :param generation_kwargs: Optional arguments merged over the instance-level ones and sent under
            the "options" key, such as temperature, top_p, etc. See the
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :return: The decoded JSON response of the endpoint (expected to hold the key "embedding" with a
            list of floats) extended with the key "meta", a dictionary containing the model name and the
            elapsed time of the HTTP request.
        :raises requests.HTTPError: If the endpoint answers with an error status code.
        """
        response = requests.post(
            url=self.url,
            json=self._create_json_payload(text, generation_kwargs),
            timeout=self.timeout,
        )
        # Fail loudly on 4xx/5xx instead of trying to parse an error body.
        response.raise_for_status()

        output = response.json()
        output["meta"] = {"model": self.model, "duration": response.elapsed}
        return output
43 changes: 43 additions & 0 deletions integrations/ollama/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder
from requests import HTTPError


class TestOllamaTextEmbedder:
    """Tests for ``OllamaTextEmbedder``: constructor handling plus integration-marked calls to ``run``."""

    def test_init_defaults(self):
        """Constructing without arguments stores the default settings."""
        default_embedder = OllamaTextEmbedder()

        assert default_embedder.model == "orca-mini"
        assert default_embedder.url == "http://localhost:11434/api/embeddings"
        assert default_embedder.generation_kwargs == {}
        assert default_embedder.timeout == 120

    def test_init(self):
        """Explicit constructor arguments are stored unchanged on the instance."""
        custom_embedder = OllamaTextEmbedder(
            model="llama2",
            url="http://my-custom-endpoint:11434/api/embeddings",
            generation_kwargs={"temperature": 0.5},
            timeout=3000,
        )

        assert custom_embedder.model == "llama2"
        assert custom_embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
        assert custom_embedder.generation_kwargs == {"temperature": 0.5}
        assert custom_embedder.timeout == 3000

    @pytest.mark.integration
    def test_model_not_found(self):
        """Running with the model name "cheese" is expected to raise an HTTPError."""
        unknown_model_embedder = OllamaTextEmbedder(model="cheese")

        with pytest.raises(HTTPError):
            unknown_model_embedder.run("hello")

    @pytest.mark.integration
    def test_run(self):
        """A successful run returns a dict holding a float embedding and the model name in its meta."""
        orca_embedder = OllamaTextEmbedder(model="orca-mini")

        output = orca_embedder.run("hello")

        assert isinstance(output, dict)
        assert all(isinstance(value, float) for value in output["embedding"])
        assert output["meta"]["model"] == "orca-mini"

0 comments on commit a2e5e8f

Please sign in to comment.