From c4ce874cc59c5be890ab2ea443cbd48a5222e008 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome
Date: Mon, 29 Jan 2024 19:04:08 +0900
Subject: [PATCH] Review `OllamaLLM` and `TogetherInferenceLLM` (#305)

* Minor type-hint and style fixes within `OllamaLLM`

* Update URL to `ollama` in GitHub

* Fix typo in `TogetherInferenceLLM` docstring

* Exclude image models from `TogetherInferenceLLM` model listing

* Fix `test_together.py` due to `available_models` update
---
 src/distilabel/llm/ollama.py   | 17 +++++++++--------
 src/distilabel/llm/together.py | 11 +++++++----
 tests/llm/test_together.py     | 16 ++++++++++++----
 3 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/src/distilabel/llm/ollama.py b/src/distilabel/llm/ollama.py
index 8d6fa966f0..47b4694c7a 100644
--- a/src/distilabel/llm/ollama.py
+++ b/src/distilabel/llm/ollama.py
@@ -50,7 +50,7 @@ def __init__(
         self,
         model: str,
         task: "Task",
-        max_new_tokens: int = None,  # num_predict
+        max_new_tokens: Union[int, None] = None,
         temperature: Union[float, None] = None,
         top_k: Union[int, None] = None,
         top_p: Union[float, None] = None,
@@ -59,12 +59,11 @@ def __init__(
         prompt_formatting_fn: Union[Callable[..., str], None] = None,
     ) -> None:
         """
-        Initializes the OllamaLLM class and align with https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+        Initializes the OllamaLLM class and aligns with https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
 
         Args:
             model (str): the model to be used for generation.
             task (Task): the task to be performed by the LLM.
-            max_new_tokens (int, optional): the maximum number of tokens to be generated.
                 Defaults to `None`.
             temperature (float, optional): the temperature to be used for generation.
@@ -90,7 +89,7 @@ def __init__(
             ValueError: if the Ollama API request failed.
 
         Examples:
-            >>> from distilabel.tasks.text_generation import TextGenerationTask as Task
+            >>> from distilabel.tasks import TextGenerationTask as Task
             >>> from distilabel.llm import OllamaLLM
             >>> task = Task()
             >>> llm = OllamaLLM(model="notus", task=task)
@@ -117,8 +116,8 @@ def model_name(self) -> str:
         return self.model
 
     def _api_available(self):
-        """calls GET {OLLAMA_HOST}"""
-        msg = f"Could not connect to Ollama as {self.OLLAMA_HOST}. Check https://github.com/jmorganca/ollama for deployment guide."
+        """Calls GET {OLLAMA_HOST}"""
+        msg = f"Could not connect to Ollama as {self.OLLAMA_HOST}. Check https://github.com/ollama/ollama for deployment guide."
         try:
             response = request.urlopen(self.OLLAMA_HOST)
             if response.getcode() != 200:
@@ -145,7 +144,9 @@ def _api_model_available(self):
         before_sleep=before_sleep_log(logger, logging.INFO),
         after=after_log(logger, logging.INFO),
     )
-    def _text_generation_with_backoff(self, prompt: str, **kwargs) -> str:
+    def _text_generation_with_backoff(
+        self, prompt: List[Dict[str, str]], **kwargs
+    ) -> str:
         """Calls POST {OLLAMA_HOST}/api/chat"""
         # Request payload
         payload = {
@@ -215,7 +216,7 @@ def _generate(
         ]
         output = []
         for response in responses:
-            raw_output = response.get("message", {}).get("content")
+            raw_output = response.get("message", {}).get("content")  # type: ignore
             try:
                 parsed_response = self.task.parse_output(raw_output.strip())
             except Exception as e:
diff --git a/src/distilabel/llm/together.py b/src/distilabel/llm/together.py
index a6d43ad7f4..8c22d4046e 100644
--- a/src/distilabel/llm/together.py
+++ b/src/distilabel/llm/together.py
@@ -49,7 +49,7 @@ def __init__(
         prompt_format: Union["SupportedFormats", None] = None,
         prompt_formatting_fn: Union[Callable[..., str], None] = None,
     ) -> None:
-        """Initializes the OpenAILLM class.
+        """Initializes the TogetherInferenceLLM class.
 
         Args:
             task (Task): the task to be performed by the LLM.
@@ -96,7 +96,7 @@ def __init__(
             AssertionError: if the provided `model` is not available in Together Inference.
 
         Examples:
-            >>> from distilabel.tasks.text_generation import TextGenerationTask as Task
+            >>> from distilabel.tasks import TextGenerationTask as Task
             >>> from distilabel.llm import TogetherInferenceLLM
             >>> task = Task()
             >>> llm = TogetherInferenceLLM(model="togethercomputer/llama-2-7b", task=task, prompt_format="llama2")
@@ -152,8 +152,11 @@ def __rich_repr__(self) -> Generator[Any, None, None]:
     @cached_property
     def available_models(self) -> List[str]:
         """Returns the list of available models in Together Inference."""
-        # TODO: exclude the image models
-        return [model["name"] for model in together.Models.list()]
+        return [
+            model["name"]
+            for model in together.Models.list()
+            if model["display_type"] != "image"
+        ]
 
     @property
     def model_name(self) -> str:
diff --git a/tests/llm/test_together.py b/tests/llm/test_together.py
index 92a9de528f..0c159c8bc2 100644
--- a/tests/llm/test_together.py
+++ b/tests/llm/test_together.py
@@ -23,7 +23,9 @@ class TestTogetherInferenceLLM:
 
     def test_available_models(self) -> None:
         together.Models.list = mock.MagicMock(
-            return_value=[{"name": "togethercomputer/llama-2-7b"}]
+            return_value=[
+                {"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
+            ]
         )
         llm = TogetherInferenceLLM(
             model="togethercomputer/llama-2-7b",
@@ -34,7 +36,9 @@ def test_available_models(self) -> None:
 
     def test_inference_kwargs(self) -> None:
         together.Models.list = mock.MagicMock(
-            return_value=[{"name": "togethercomputer/llama-2-7b"}]
+            return_value=[
+                {"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
+            ]
         )
         llm = TogetherInferenceLLM(
             model="togethercomputer/llama-2-7b",
@@ -60,7 +64,9 @@ def test_inference_kwargs(self) -> None:
 
     def test__generate_single_output(self) -> None:
         together.Models.list = mock.MagicMock(
-            return_value=[{"name": "togethercomputer/llama-2-7b"}]
+            return_value=[
+                {"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
+            ]
         )
         llm = TogetherInferenceLLM(
             model="togethercomputer/llama-2-7b",
@@ -87,7 +93,9 @@ def test__generate_single_output(self) -> None:
 
     def test__generate(self) -> None:
         together.Models.list = mock.MagicMock(
-            return_value=[{"name": "togethercomputer/llama-2-7b"}]
+            return_value=[
+                {"name": "togethercomputer/llama-2-7b", "display_type": "chat"}
+            ]
         )
         llm = TogetherInferenceLLM(
             model="togethercomputer/llama-2-7b",
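
Note (illustration, not part of the patch): the behavioral change behind the test updates is that `TogetherInferenceLLM.available_models` now filters the payload of `together.Models.list()` by `display_type`, which is why every mocked entry above carries a `"display_type": "chat"` key. A minimal standalone sketch of that filter follows, using a hand-written stand-in for the API payload; the image model name below is hypothetical.

    # Stand-in for the list of dicts returned by `together.Models.list()`,
    # assumed (per the diff) to carry "name" and "display_type" keys.
    models = [
        {"name": "togethercomputer/llama-2-7b", "display_type": "chat"},
        {"name": "example/image-model", "display_type": "image"},  # hypothetical entry
    ]

    # Mirrors the new list comprehension in `available_models`:
    # image models are excluded from the listing.
    available_models = [m["name"] for m in models if m["display_type"] != "image"]
    print(available_models)  # ['togethercomputer/llama-2-7b']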