
Commit

Merge pull request #4 from argilla-io/feat/choose-models
add support for custom BASE_URL, MODEL, API_KEY
davidberenstein1957 authored Dec 3, 2024
2 parents da59bd9 + 9feda8c commit ec33fc2
Showing 20 changed files with 428 additions and 712 deletions.
1 change: 0 additions & 1 deletion .python-version

This file was deleted.

52 changes: 32 additions & 20 deletions README.md
@@ -20,25 +20,13 @@ hf_oauth_scopes:

<h1 align="center">
<br>
Synthetic Data Generator
🧬 Synthetic Data Generator
<br>
</h1>
<h3 align="center">Build datasets using natural language</h3>

![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png)

<p align="center">
<a href="https://pypi.org/project/synthetic-dataset-generator/">
<img alt="CI" src="https://img.shields.io/pypi/v/synthetic-dataset-generator.svg?style=flat-round&logo=pypi&logoColor=white">
</a>
<a href="https://pepy.tech/project/synthetic-dataset-generator">
<img alt="CI" src="https://static.pepy.tech/personalized-badge/synthetic-dataset-generator?period=month&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads/month">
</a>
<a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg"/>
</a>
</p>

<p align="center">
<a href="https://twitter.com/argilla_io">
<img src="https://img.shields.io/badge/twitter-black?logo=x"/>
@@ -78,21 +66,29 @@ You can simply install the package with:
pip install synthetic-dataset-generator
```

### Quickstart

```python
from synthetic_dataset_generator.app import demo

demo.launch()
```

### Environment Variables

- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained).
- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.

Optionally, you can set the following environment variables to customize the generation process.

- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.

Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:

- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.

## Quickstart

```bash
python app.py
```

### Argilla integration

Argilla is an open-source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
@@ -104,3 +100,19 @@ Argilla is a open source tool for data curation. It allows you to annotate and r
Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.

Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information.

## Development

Install the dependencies:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

Run the app:

```bash
python app.py
```
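
The `BASE_URL`, `MODEL`, and `API_KEY` variables introduced in the README changes above can be gathered in one place before the app starts. The following is a minimal sketch only: the `GenerationSettings` dataclass, the `load_generation_settings` helper, and the fallback from `API_KEY` to `HF_TOKEN` are assumptions made for illustration, since this excerpt does not show how the package itself reads these variables.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class GenerationSettings:
    """Hypothetical container for the optional generation variables."""

    base_url: Optional[str]
    model: Optional[str]
    api_key: Optional[str]
    argilla_enabled: bool


def load_generation_settings(
    env: Optional[Mapping[str, str]] = None,
) -> GenerationSettings:
    """Read BASE_URL, MODEL, API_KEY and the Argilla variables from `env`.

    Assumption: when API_KEY is unset, HF_TOKEN is used as the key.
    The README excerpt does not state this fallback.
    """
    env = os.environ if env is None else env
    api_key = env.get("API_KEY") or env.get("HF_TOKEN")
    # Argilla is treated as enabled only when both of its variables are set.
    argilla_enabled = bool(
        env.get("ARGILLA_API_URL") and env.get("ARGILLA_API_KEY")
    )
    return GenerationSettings(
        base_url=env.get("BASE_URL") or None,
        model=env.get("MODEL") or None,
        api_key=api_key or None,
        argilla_enabled=argilla_enabled,
    )


if __name__ == "__main__":
    settings = load_generation_settings(
        {
            "BASE_URL": "https://api.openai.com/v1/",
            "MODEL": "gpt-4o",
            "API_KEY": "sk-...",
        }
    )
    print(settings.base_url, settings.model, settings.argilla_enabled)
```

Passing a plain mapping instead of `os.environ` keeps the helper easy to exercise without touching the real environment.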
36 changes: 1 addition & 35 deletions app.py
@@ -1,38 +1,4 @@
from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
from src.distilabel_dataset_generator.apps.eval import app as eval_app
from src.distilabel_dataset_generator.apps.faq import app as faq_app
from src.distilabel_dataset_generator.apps.sft import app as sft_app
from src.distilabel_dataset_generator.apps.textcat import app as textcat_app

theme = "argilla/argilla-theme"

css = """
button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
button.hf-login {background: var(--neutral-800); color: white}
button.hf-login:hover {background: var(--neutral-700); color: white}
.tabitem { border: 0; padding-inline: 0}
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
.group_padding{padding: .55em}
.gallery-item {background: var(--background-fill-secondary); text-align: left}
.gallery {white-space: wrap}
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
#system_prompt_examples {
color: var(--body-text-color) !important;
background-color: var(--block-background-fill) !important;
}
.container {padding-inline: 0 !important}
"""

demo = TabbedInterface(
    [textcat_app, sft_app, eval_app, faq_app],
    ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
    css=css,
    title="Synthetic Data Generator",
    head="Synthetic Data Generator",
    theme=theme,
)

from distilabel_dataset_generator.app import demo

if __name__ == "__main__":
    demo.launch()
16 changes: 12 additions & 4 deletions pyproject.toml
@@ -5,6 +5,18 @@ description = "Build datasets using natural language"
authors = [
{name = "davidberenstein1957", email = "[email protected]"},
]
tags = [
"gradio",
"synthetic-data",
"huggingface",
"argilla",
"generative-ai",
"ai",
]
requires-python = "<3.13,>=3.10"
readme = "README.md"
license = {text = "Apache 2"}

dependencies = [
"distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
"gradio[oauth]<5.0.0",
@@ -14,14 +26,10 @@ dependencies = [
"gradio-huggingfacehub-search>=0.0.7",
"argilla>=2.4.0",
]
requires-python = "<3.13,>=3.10"
readme = "README.md"
license = {text = "apache 2"}

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"


[tool.pdm]
distribution = true
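
The `requires-python = "<3.13,>=3.10"` constraint in the `pyproject.toml` changes above is normally enforced by the installer. As an illustrative sketch only, the same range can be expressed as a tuple comparison on `(major, minor)`; the names `MIN_VERSION`, `MAX_VERSION`, and `is_supported` are hypothetical and do not come from the project.

```python
import sys
from typing import Tuple

# Bounds taken from `requires-python = "<3.13,>=3.10"`.
MIN_VERSION: Tuple[int, int] = (3, 10)  # inclusive lower bound
MAX_VERSION: Tuple[int, int] = (3, 13)  # exclusive upper bound


def is_supported(version: Tuple[int, int]) -> bool:
    """Return True when `version` (major, minor) lies in [3.10, 3.13)."""
    return MIN_VERSION <= version < MAX_VERSION


if __name__ == "__main__":
    current = (sys.version_info.major, sys.version_info.minor)
    print(f"Python {current[0]}.{current[1]} supported: {is_supported(current)}")
```

Tuples compare element by element, so `(3, 9)` falls below the range and `(3, 13)` sits exactly on the excluded upper bound.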
74 changes: 50 additions & 24 deletions src/distilabel_dataset_generator/__init__.py
@@ -1,39 +1,64 @@
import os
import warnings
from pathlib import Path
from typing import Optional, Union
from typing import Optional

import argilla as rg
import distilabel
import distilabel.distiset
from distilabel.llms import InferenceEndpointsLLM
from distilabel.utils.card.dataset_card import (
    DistilabelDatasetCard,
    size_categories_parser,
)
from huggingface_hub import DatasetCardData, HfApi, upload_file
from huggingface_hub import DatasetCardData, HfApi
from pydantic import (
    ValidationError,
    model_validator,
)

HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
HF_TOKENS = [token for token in HF_TOKENS if token]

if len(HF_TOKENS) == 0:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )
class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            warnings.warn(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
    ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
    ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        raise ValidationError(
            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )


class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
@@ -138,3 +163,4 @@ def _get_card(


distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag
distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
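
The branching in the validator added to `CustomInferenceEndpointsLLM` above can be read as a small decision table over `model_id`, `endpoint_name`, and `base_url`. The sketch below restates only that branching as a plain function so it can be run without distilabel or pydantic. The function name `check_endpoint_args` and its string return values are hypothetical, the Magpie and tokenizer checks are left out, and it raises `ValueError` (the exception pydantic validators conventionally raise) rather than constructing `ValidationError` as the diff does.

```python
import warnings
from typing import Optional


def check_endpoint_args(
    model_id: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    base_url: Optional[str] = None,
) -> str:
    """Return the name of the argument that selects the endpoint.

    Mirrors the order of the checks in the validator: warn when `base_url`
    is combined with another selector, accept exactly one selector, and
    reject every other combination.
    """
    if base_url and (model_id or endpoint_name):
        warnings.warn(
            "`base_url` will be ignored or overwritten because `model_id`"
            " or `endpoint_name` is also provided."
        )

    if base_url and not (model_id or endpoint_name):
        return "base_url"
    if model_id and not endpoint_name:
        return "model_id"
    if endpoint_name and not model_id:
        return "endpoint_name"

    # Reached when both `model_id` and `endpoint_name` are set,
    # or when none of the three arguments is set.
    raise ValueError(
        "Only one of `model_id` or `endpoint_name` must be provided."
        f" Found model_id={model_id!r}, endpoint_name={endpoint_name!r},"
        f" base_url={base_url!r}."
    )


if __name__ == "__main__":
    print(check_endpoint_args(base_url="https://api.openai.com/v1/"))
    print(check_endpoint_args(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct"))
```

The first accepting branch is the one that lets a bare `base_url` pass, which is the case a custom OpenAI-compatible endpoint relies on.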
4 changes: 3 additions & 1 deletion src/distilabel_dataset_generator/_tabbedinterface.py
@@ -68,7 +68,9 @@ def __init__(
with gr.Column(scale=3):
pass
with gr.Column(scale=2):
gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2)
gr.LoginButton(
value="Sign in", variant="hf-login", size="sm", scale=2
)
with Tabs():
for interface, tab_name in zip(interface_list, tab_names, strict=False):
with Tab(label=tab_name):
38 changes: 38 additions & 0 deletions src/distilabel_dataset_generator/app.py
@@ -0,0 +1,38 @@
from distilabel_dataset_generator._tabbedinterface import TabbedInterface
from distilabel_dataset_generator.apps.eval import app as eval_app
from distilabel_dataset_generator.apps.faq import app as faq_app
from distilabel_dataset_generator.apps.sft import app as sft_app
from distilabel_dataset_generator.apps.textcat import app as textcat_app

theme = "argilla/argilla-theme"

css = """
button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)}
button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)}
button.hf-login {background: var(--neutral-800); color: white}
button.hf-login:hover {background: var(--neutral-700); color: white}
.tabitem { border: 0; padding-inline: 0}
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
.group_padding{padding: .55em}
.gallery-item {background: var(--background-fill-secondary); text-align: left}
.gallery {white-space: wrap}
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
#system_prompt_examples {
color: var(--body-text-color) !important;
background-color: var(--block-background-fill) !important;
}
.container {padding-inline: 0 !important}
"""

demo = TabbedInterface(
    [textcat_app, sft_app, eval_app, faq_app],
    ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
    css=css,
    title="Synthetic Data Generator",
    head="Synthetic Data Generator",
    theme=theme,
)


if __name__ == "__main__":
    demo.launch()
Empty file.