Improve Azure OpenAI support (#461)
BramVanroy authored Mar 28, 2024
1 parent aafd8fc commit 1240852
Showing 2 changed files with 50 additions and 9 deletions.
40 changes: 39 additions & 1 deletion README.md
@@ -29,7 +29,7 @@ In addition, the following extras are available:

- `hf-transformers`: for using models available in [transformers](https://github.com/huggingface/transformers) package via the `TransformersLLM` integration.
- `hf-inference-endpoints`: for using the [HuggingFace Inference Endpoints](https://huggingface.co/inference-endpoints) via the `InferenceEndpointsLLM` integration.
- `openai`: for using OpenAI API models via the `OpenAILLM` integration.
- `openai`: for using (Azure) OpenAI API models via the `OpenAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`.
- `ollama`: for using [Ollama](https://github.com/ollama/ollama) and their available models via their Python client.
@@ -92,6 +92,44 @@ rg_dataset = dataset.to_argilla()
rg_dataset.push_to_argilla(name="preference-dataset", workspace="admin")
```

## Azure OpenAI API

To use the Azure OpenAI API, you can use `distilabel.llm.OpenAILLM`, but in a slightly different way than with
the regular OpenAI API. For now, you will need to instantiate an `openai.AzureOpenAI` client yourself and pass
it to the `OpenAILLM` constructor via its `client` argument rather than relying on the `api_key` argument.
Secondly, instead of setting `model` to the name of the model you want to use (like `gpt-4`), you need to set it
to your Azure deployment name.

An example:

```python
from distilabel.llm import OpenAILLM
from distilabel.tasks import TextGenerationTask
from openai import AzureOpenAI


api_key = "<azure-super-secret-api-key>"
api_version = "2024-02-15-preview" # replace with your own
azure_endpoint = "https://<endpoint-name>.openai.azure.com"
deployment = "<deployment-name>"

client = AzureOpenAI(
    api_key=api_key,
    api_version=api_version,
    azure_endpoint=azure_endpoint,
)

llm = OpenAILLM(
    task=TextGenerationTask(),
    client=client,  # Important!
    model=deployment,  # Important!
)

messages = [{"input": "Write me a short poem."}]

print(llm.generate(messages))
```
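
In practice you may prefer to read the Azure credentials from environment variables instead of hardcoding them. A minimal sketch, assuming the same deployment as above and that you have exported `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` beforehand (these variable names are illustrative, not required by `distilabel`):

```python
import os

from distilabel.llm import OpenAILLM
from distilabel.tasks import TextGenerationTask
from openai import AzureOpenAI

# Read the credentials from the environment instead of hardcoding them.
# AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT are assumed to be set beforehand.
client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version="2024-02-15-preview",  # replace with your own
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
)

llm = OpenAILLM(
    task=TextGenerationTask(),
    client=client,
    model="<deployment-name>",  # your Azure deployment name
)
```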

## More examples

Find more examples of different use cases of `distilabel` under [`examples/`](./examples/).
19 changes: 11 additions & 8 deletions src/distilabel/llm/openai.py
@@ -31,7 +31,7 @@
from distilabel.utils.imports import _OPENAI_AVAILABLE

if _OPENAI_AVAILABLE:
    from openai import OpenAI
    from openai import OpenAI, AzureOpenAI

if TYPE_CHECKING:
    from distilabel.tasks.base import Task
@@ -45,7 +45,7 @@ def __init__(
        self,
        task: "Task",
        model: str = "gpt-3.5-turbo",
        client: Union["OpenAI", None] = None,
        client: Union["OpenAI", "AzureOpenAI", None] = None,
        api_key: Union[str, None] = None,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
@@ -117,12 +117,15 @@ def __init__(

        self.client = client or OpenAI(api_key=api_key, max_retries=6)

        try:
            assert (
                model in self.available_models
            ), f"Provided `model` is not available in your OpenAI account, available models are {self.available_models}"
        except AttributeError:
            logger.warning("Unable to check if model is available in your account.")
        if not isinstance(self.client, AzureOpenAI):
            # In Azure, the model in create() is actually the deployment_name and therefore it can be any arbitrarily
            # chosen name. So it is not a given that it is part of the available models
            try:
                assert (
                    model in self.available_models
                ), f"Provided `model` is not available in your OpenAI account, available models are {self.available_models}"
            except AttributeError:
                logger.warning("Unable to check if model is available in your account.")
        self.model = model

    def __rich_repr__(self) -> Generator[Any, None, None]:
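
The guard above exists because an Azure deployment name is an arbitrary label and will generally not appear in the account's model list, so the availability check is only meaningful for the regular OpenAI client. A minimal standalone sketch of that logic, not part of the diff, with an illustrative function name:

```python
from typing import Union

from openai import AzureOpenAI, OpenAI


def should_validate_model(client: Union[OpenAI, AzureOpenAI]) -> bool:
    # Azure deployments can have arbitrary names, so they cannot be checked
    # against the account's model list; skip validation for AzureOpenAI clients.
    return not isinstance(client, AzureOpenAI)
```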
