Add MixtureOfAgentsLLM (#735)
* Update `RuntimeParametersMixin` to handle `list`s

* Check if `generation_kwargs` is present

* Add `get_generation_kwargs` method

* Add `_num_generations_param_supported` attribute to avoid code
duplication

* Refactor `OllamaLLM` and `VertexAILLM`

* Add `MixtureOfAgents` LLM

* Add docstrings

* Fix unit tests

* Update docstrings

* Fix missing #

* Update Arena Hard tasks docstrings

* Fix cross-reference

* Update unit tests

* Rename to `MixtureOfAgentsLLM`

* Update `_extra_serializable_fields` to work with `List[_Serializable]`
attributes

* Remove `from_dict` method for `Step`

* Update to render list

* Update handling `List[RuntimeParametersMixin]` attributes

* Fix unit tests

* Remove test code

* Add `MixtureOfAgentsLLM` docstring example

* Add aliases for runtime parameter names
gabrielmbmb authored Jun 18, 2024
1 parent 9d6a152 commit d736dd7
Showing 44 changed files with 799 additions and 507 deletions.
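
For context before the per-file diffs: `MixtureOfAgentsLLM` implements the Mixture-of-Agents approach, in which several proposer LLMs draft candidate responses and an aggregator LLM synthesizes them into a final answer. A minimal usage sketch follows; the constructor arguments (`aggregator_llm`, `proposers_llms`, `rounds`) and the models chosen are illustrative assumptions, since the new `src/distilabel/llms/moa.py` file is not rendered on this page.

```python
from distilabel.llms import InferenceEndpointsLLM, MixtureOfAgentsLLM

# Hedged sketch: the argument names below are assumptions based on the
# Mixture-of-Agents design, not taken from the diffs shown on this page.
llm = MixtureOfAgentsLLM(
    aggregator_llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
    proposers_llms=[
        InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3-70B-Instruct"),
        InferenceEndpointsLLM(model_id="mistralai/Mixtral-8x7B-Instruct-v0.1"),
    ],
    rounds=1,  # assumed: number of proposal rounds before aggregation
)

llm.load()
output = llm.generate(
    inputs=[[{"role": "user", "content": "Explain Mixture-of-Agents briefly."}]]
)
```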
@@ -139,7 +139,7 @@ When working with model providers behind an API, there's no direct way of access
 ```

 !!! Note
-    Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.structured_outputs.instructor.InstructorStructuredOutputType] to see the expected format
+    Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.typing.InstructorStructuredOutputType] to see the expected format
     of the `structured_output` dict variable.

 The following is the same example you can see with `outlines`'s `JSON` section for comparison purposes.
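
For readers following the cross-reference fix above, a `structured_output` dict of the kind `InstructorStructuredOutputType` describes might look like the sketch below; the exact keys are an assumption drawn from typical `instructor` usage, so check the linked type for the authoritative format.

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


# Assumed shape: a pydantic model as the "schema" plus an output "format".
structured_output = {
    "schema": User,
    "format": "json",  # assumption: instructor-style JSON mode
}
```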
31 changes: 19 additions & 12 deletions src/distilabel/cli/pipeline/utils.py
@@ -120,7 +120,8 @@ def get_pipeline(config: str) -> "BasePipeline":
         FileNotFoundError: If the configuration file does not exist.
     """
     if valid_http_url(config):
-        return Pipeline.from_dict(get_config_from_url(config))
+        data = get_config_from_url(config)
+        return Pipeline.from_dict(data)

     if Path(config).is_file():
         return Pipeline.from_file(config)
@@ -200,32 +201,38 @@ def _build_steps_panel(pipeline: "BasePipeline") -> "Panel":
     from rich.table import Table

     def _add_rows(
-        table: Table, runtime_params: List[Dict[str, Any]], prefix: str = ""
+        table: Table,
+        runtime_params: List[Dict[str, Any]],
+        prefix: str = "",
     ) -> None:
         for param in runtime_params:
+            if isinstance(param, str):
+                _add_rows(table, runtime_params[param], f"{prefix}{param}.")
+                continue
+
             # nested (for example `LLM` in `Task`)
             if "runtime_parameters_info" in param:
                 _add_rows(
                     table=table,
                     runtime_params=param["runtime_parameters_info"],
                     prefix=f"{prefix}{param['name']}.",
                 )
-                continue
-
             # `LLM` special case
-            if "keys" in param:
+            elif "keys" in param:
                 _add_rows(
                     table=table,
                     runtime_params=param["keys"],
                     prefix=f"{prefix}{param['name']}.",
                 )
-                continue
-
-            optional = param.get("optional", "")
-            if optional != "":
-                optional = "Yes" if optional else "No"
+                return
+            else:
+                optional = param.get("optional", "")
+                if optional != "":
+                    optional = "Yes" if optional else "No"

-            table.add_row(prefix + param["name"], param.get("description"), optional)
+                table.add_row(
+                    prefix + param["name"], param.get("description"), optional
+                )

     steps = []
     for step_name, runtime_params in pipeline.get_runtime_parameters_info().items():
@@ -239,7 +246,7 @@ def _add_rows(
         expand=True,
     )

-    table.add_column("Runtime parameter", style="dim", width=50)
+    table.add_column("Runtime parameter", style="dim", width=60)
     table.add_column("Description", width=100)
     table.add_column("Optional", justify="right")
     _add_rows(table, runtime_params)
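
To make the recursion in `_add_rows` concrete, here is a hypothetical payload of the shape it walks: a plain parameter, a nested `runtime_parameters_info` entry (an `LLM` inside a `Task`), and a `keys` entry for free-form `generation_kwargs`. The real output of `pipeline.get_runtime_parameters_info()` may differ in detail.

```python
# Hypothetical structure mirroring what `_add_rows` recurses over.
runtime_params = [
    {"name": "input_batch_size", "description": "Batch size.", "optional": True},
    {
        "name": "llm",
        "runtime_parameters_info": [  # nested case: `LLM` in a `Task`
            {
                "name": "generation_kwargs",
                "keys": [  # `LLM` special case: free-form kwargs
                    {"name": "temperature", "optional": True},
                    {"name": "max_new_tokens", "optional": True},
                ],
            },
        ],
    },
]
# Rows render with dotted prefixes, e.g. `llm.generation_kwargs.temperature`.
```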
2 changes: 2 additions & 0 deletions src/distilabel/llms/__init__.py
@@ -23,6 +23,7 @@
 from distilabel.llms.llamacpp import LlamaCppLLM
 from distilabel.llms.mistral import MistralLLM
 from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.moa import MixtureOfAgentsLLM
 from distilabel.llms.ollama import OllamaLLM
 from distilabel.llms.openai import OpenAILLM
 from distilabel.llms.together import TogetherLLM
@@ -43,6 +44,7 @@
     "LlamaCppLLM",
     "MistralLLM",
     "CudaDevicePlacementMixin",
+    "MixtureOfAgentsLLM",
     "OllamaLLM",
     "OpenAILLM",
     "TogetherLLM",
41 changes: 6 additions & 35 deletions src/distilabel/llms/anthropic.py
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import os
 from typing import (
     TYPE_CHECKING,
-    Any,
     List,
     Literal,
     Optional,
@@ -28,13 +26,14 @@

 from httpx import AsyncClient
 from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from typing_extensions import override

 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
-from distilabel.utils.itertools import grouper
+from distilabel.steps.tasks.typing import (
+    FormattedInput,
+    InstructorStructuredOutputType,
+)

 if TYPE_CHECKING:
     from anthropic import AsyncAnthropic
@@ -112,11 +111,7 @@ class User(BaseModel):
     llm.load()

     # Synchronous request
     output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
-
-    # Asynchronous request
-    output = await llm.agenerate(input=[{"role": "user", "content": "Create a user profile for the following marathon"}])
     ```
     """
@@ -148,6 +143,8 @@ class User(BaseModel):
         )
     )

+    _num_generations_param_supported = False
+
     _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
     _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)
@@ -282,29 +279,3 @@ async def agenerate(  # type: ignore
             )
             generations.append(content)
         return generations
-
-    # TODO: remove this function once Anthropic client allows `n` parameter
-    @override
-    def generate(
-        self,
-        inputs: List["FormattedInput"],
-        num_generations: int = 1,
-        **kwargs: Any,
-    ) -> List["GenerateOutput"]:
-        """Method to generate a list of responses asynchronously, returning the output
-        synchronously awaiting for the response of each input sent to `agenerate`.
-        """
-
-        async def agenerate(
-            inputs: List["FormattedInput"], **kwargs: Any
-        ) -> "GenerateOutput":
-            """Internal function to parallelize the asynchronous generation of responses."""
-            tasks = [
-                asyncio.create_task(self.agenerate(input=input, **kwargs))
-                for input in inputs
-                for _ in range(num_generations)
-            ]
-            return [outputs[0] for outputs in await asyncio.gather(*tasks)]
-
-        outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
-        return list(grouper(outputs, n=num_generations, incomplete="ignore"))
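
The block removed above is exactly the duplication the `_num_generations_param_supported` flag eliminates: every client whose API lacks an `n`-style parameter previously carried its own copy of this fan-out. Below is a rough sketch of how a base class could consume the flag; the real `AsyncLLM.generate` is not part of the diff rendered here, so names and signatures are assumptions.

```python
import asyncio
from typing import Any, List


class AsyncLLMSketch:
    """Illustrative only; the real `distilabel.llms.base.AsyncLLM` may differ."""

    _num_generations_param_supported: bool = True

    async def agenerate(self, input: Any, **kwargs: Any) -> List[str]:
        raise NotImplementedError  # each provider subclass implements this

    async def _agenerate_all(
        self, inputs: List[Any], num_generations: int = 1, **kwargs: Any
    ) -> List[List[str]]:
        if self._num_generations_param_supported:
            # One request per input; the API returns all generations at once.
            tasks = [
                self.agenerate(input=i, num_generations=num_generations, **kwargs)
                for i in inputs
            ]
            return await asyncio.gather(*tasks)

        # APIs like Anthropic's lack an `n` parameter, so issue
        # `num_generations` requests per input and regroup the outputs.
        tasks = [
            self.agenerate(input=i, **kwargs)
            for i in inputs
            for _ in range(num_generations)
        ]
        flat = await asyncio.gather(*tasks)
        return [
            [gen for batch in flat[i : i + num_generations] for gen in batch]
            for i in range(0, len(flat), num_generations)
        ]
```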
4 changes: 0 additions & 4 deletions src/distilabel/llms/anyscale.py
@@ -50,11 +50,7 @@ class AnyscaleLLM(OpenAILLM):
     llm.load()

     # Synchronous request
     output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
-
-    # Asynchronous request
-    output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
     ```
     """
4 changes: 0 additions & 4 deletions src/distilabel/llms/azure.py
@@ -103,11 +103,7 @@ class User(BaseModel):
     llm.load()

     # Synchronous request
     output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
-
-    # Asynchronous request
-    output = await llm.agenerate(input=[{"role": "user", "content": "Create a user profile for the following marathon"}])
     ```
     """
(Diff truncated: the remaining changed files are not rendered on this page.)