Review OllamaLLM and TogetherInferenceLLM (#305)
* Minor type-hint and style fixes within `OllamaLLM`

* Update URL to `ollama` in GitHub

* Fix typo in `TogetherInferenceLLM` docstring

* Exclude image models from `TogetherInferenceLLM` model listing

* Fix `test_together.py` due to `available_models` update
alvarobartt authored Jan 29, 2024
1 parent cc5d08e commit c4ce874
Showing 3 changed files with 28 additions and 16 deletions.
17 changes: 9 additions & 8 deletions src/distilabel/llm/ollama.py
@@ -50,7 +50,7 @@ def __init__(
self,
model: str,
task: "Task",
max_new_tokens: int = None, # num_predict
max_new_tokens: Union[int, None] = None,
temperature: Union[float, None] = None,
top_k: Union[int, None] = None,
top_p: Union[float, None] = None,
@@ -59,12 +59,11 @@ def __init__(
prompt_formatting_fn: Union[Callable[..., str], None] = None,
) -> None:
"""
Initializes the OllamaLLM class and align with https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
Initializes the OllamaLLM class and aligns with https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
Args:
model (str): the model to be used for generation.
task (Task): the task to be performed by the LLM.
max_new_tokens (int, optional): the maximum number of tokens to be generated.
Defaults to `None`.
temperature (float, optional): the temperature to be used for generation.
@@ -90,7 +89,7 @@ def __init__(
ValueError: if the Ollama API request failed.
Examples:
>>> from distilabel.tasks.text_generation import TextGenerationTask as Task
>>> from distilabel.tasks import TextGenerationTask as Task
>>> from distilabel.llm import OllamaLLM
>>> task = Task()
>>> llm = OllamaLLM(model="notus", task=task)
@@ -117,8 +116,8 @@ def model_name(self) -> str:
return self.model

def _api_available(self):
"""calls GET {OLLAMA_HOST}"""
msg = f"Could not connect to Ollama as {self.OLLAMA_HOST}. Check https://github.com/jmorganca/ollama for deployment guide."
"""Calls GET {OLLAMA_HOST}"""
msg = f"Could not connect to Ollama as {self.OLLAMA_HOST}. Check https://github.com/ollama/ollama for deployment guide."
try:
response = request.urlopen(self.OLLAMA_HOST)
if response.getcode() != 200:
@@ -145,7 +144,9 @@ def _api_model_available(self):
before_sleep=before_sleep_log(logger, logging.INFO),
after=after_log(logger, logging.INFO),
)
def _text_generation_with_backoff(self, prompt: str, **kwargs) -> str:
def _text_generation_with_backoff(
self, prompt: List[Dict[str, str]], **kwargs
) -> str:
"""Calls POST {OLLAMA_HOST}/api/chat"""
# Request payload
payload = {
@@ -215,7 +216,7 @@ def _generate(
]
output = []
for response in responses:
raw_output = response.get("message", {}).get("content")
raw_output = response.get("message", {}).get("content") # type: ignore
try:
parsed_response = self.task.parse_output(raw_output.strip())
except Exception as e:
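For reference, the docstring hunk above already carries a minimal usage example; the sketch below just rounds it out. The imports and the constructor call mirror the docstring verbatim, while the trailing `generate(...)` call and its `inputs` keyword are assumptions about the shared `LLM.generate` interface rather than something shown in this diff, and the snippet presumes an Ollama server is reachable at the configured `OLLAMA_HOST` with the `notus` model already pulled.

```python
from distilabel.llm import OllamaLLM
from distilabel.tasks import TextGenerationTask as Task

# Mirrors the docstring example in the hunk above; `notus` must already be
# available in the local Ollama server (e.g. via `ollama pull notus`).
task = Task()
llm = OllamaLLM(model="notus", task=task)

# Assumed usage of the common `generate` interface (not shown in this diff);
# the input dict follows the text-generation task's single `input` field.
generations = llm.generate(inputs=[{"input": "Summarize what this commit changes."}])
print(generations)
```
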
11 changes: 7 additions & 4 deletions src/distilabel/llm/together.py
@@ -49,7 +49,7 @@ def __init__(
prompt_format: Union["SupportedFormats", None] = None,
prompt_formatting_fn: Union[Callable[..., str], None] = None,
) -> None:
"""Initializes the OpenAILLM class.
"""Initializes the TogetherInferenceLLM class.
Args:
task (Task): the task to be performed by the LLM.
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
AssertionError: if the provided `model` is not available in Together Inference.
Examples:
>>> from distilabel.tasks.text_generation import TextGenerationTask as Task
>>> from distilabel.tasks import TextGenerationTask as Task
>>> from distilabel.llm import TogetherInferenceLLM
>>> task = Task()
>>> llm = TogetherInferenceLLM(model="togethercomputer/llama-2-7b", task=task, prompt_format="llama2")
@@ -152,8 +152,11 @@ def __rich_repr__(self) -> Generator[Any, None, None]:
@cached_property
def available_models(self) -> List[str]:
"""Returns the list of available models in Together Inference."""
# TODO: exclude the image models
return [model["name"] for model in together.Models.list()]
return [
model["name"]
for model in together.Models.list()
if model["display_type"] != "image"
]

@property
def model_name(self) -> str:
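The `available_models` change above is the substance of the "exclude image models" bullet in the commit message: image checkpoints cannot serve the text-generation tasks this class drives, so they are dropped from the cached listing. Below is a minimal, self-contained sketch of the same filter, with a hard-coded payload standing in for the authenticated `together.Models.list()` call; the `name` and `display_type` fields come from the diff, while the non-Llama entries are invented purely for illustration.

```python
from typing import Dict, List

# Stand-in for the response of `together.Models.list()`; only the two fields
# read by `available_models` are reproduced, and the last two entries are
# hypothetical examples.
MODELS: List[Dict[str, str]] = [
    {"name": "togethercomputer/llama-2-7b", "display_type": "chat"},
    {"name": "hypothetical/image-model", "display_type": "image"},
    {"name": "hypothetical/language-model", "display_type": "language"},
]


def available_models(models: List[Dict[str, str]]) -> List[str]:
    """Same filter as the updated property: keep every non-image model name."""
    return [model["name"] for model in models if model["display_type"] != "image"]


print(available_models(MODELS))
# ['togethercomputer/llama-2-7b', 'hypothetical/language-model']
```
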
16 changes: 12 additions & 4 deletions tests/llm/test_together.py
@@ -23,7 +23,9 @@
class TestTogetherInferenceLLM:
def test_available_models(self) -> None:
together.Models.list = mock.MagicMock(
return_value=[{"name": "togethercomputer/llama-2-7b"}]
return_value=[
{"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
]
)
llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-7b",
@@ -34,7 +36,9 @@ def test_available_models(self) -> None:

def test_inference_kwargs(self) -> None:
together.Models.list = mock.MagicMock(
return_value=[{"name": "togethercomputer/llama-2-7b"}]
return_value=[
{"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
]
)
llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-7b",
@@ -60,7 +64,9 @@ def test_inference_kwargs(self) -> None:

def test__generate_single_output(self) -> None:
together.Models.list = mock.MagicMock(
return_value=[{"name": "togethercomputer/llama-2-7b"}]
return_value=[
{"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
]
)
llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-7b",
@@ -87,7 +93,9 @@ def test__generate_single_output(self) -> None:

def test__generate(self) -> None:
together.Models.list = mock.MagicMock(
return_value=[{"name": "togethercomputer/llama-2-7b"}]
return_value=[
{"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
]
)
llm = TogetherInferenceLLM(
model="togethercomputer/llama-2-7b",
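Every test in this file patches `together.Models.list` before constructing the LLM, and the update simply extends the mocked payload with the `display_type` field the new filter reads. A compact sketch of that stubbing pattern in isolation is shown below; it assumes the `together` package is importable (as it is in the test suite), and the second, image-typed entry is invented only to show what gets filtered out.

```python
from unittest import mock

import together  # the SDK whose `Models.list` the tests patch

# Same stubbing pattern as the tests: replace the network call with a canned
# payload that now carries the `display_type` field read by `available_models`.
together.Models.list = mock.MagicMock(
    return_value=[
        {"name": "togethercomputer/llama-2-7b", "display_type": "chat"},
        {"name": "hypothetical/image-model", "display_type": "image"},  # filtered out
    ]
)

# What the updated property computes from that payload.
names = [m["name"] for m in together.Models.list() if m["display_type"] != "image"]
assert names == ["togethercomputer/llama-2-7b"]
```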
