update deployment with API providers #20

Merged: 16 commits, Jan 10, 2025
18 changes: 11 additions & 7 deletions README.md
@@ -28,7 +28,7 @@ hf_oauth_scopes:

## Introduction

- Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also wathh the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
+ Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action.
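As a concrete starting point, a minimal launch sketch pieced together from the example scripts added in this PR (the token value is a placeholder; `HF_TOKEN` is described below):

```python
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # your Hugging Face token

launch()  # starts the Gradio app
```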

Supported Tasks:

@@ -76,21 +76,25 @@ launch()

- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. You can find some configuration examples in the [examples](examples/) folder.

- Optionally, you can set the following environment variables to customize the generation process.
+ You can set the following environment variables to customize the generation process.

- `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
- `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
- `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
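For example, a short sketch that tightens those limits before launching (the values are illustrative, not recommendations):

```python
import os

from synthetic_dataset_generator import launch

os.environ["MAX_NUM_TOKENS"] = "512"  # shorter completions than the 2048 default
os.environ["MAX_NUM_ROWS"] = "100"  # a small trial run instead of 1000 rows
os.environ["DEFAULT_BATCH_SIZE"] = "2"  # smaller batches for slower endpoints

launch()
```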

- Optionally, you can use different models and APIs. For providers outside of Hugging Face, we provide an integration through [LiteLLM](https://docs.litellm.ai/docs/providers).
+ Optionally, you can use different API providers and models.

- - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`, `http://127.0.0.1:11434/v1/`.
- - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
  - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+ - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+ - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
+ - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
+ - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.

- SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2.
+ SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.

- - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
+ - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+ - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
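As a sketch of the custom-template case (the model id and pre-query string below are hypothetical placeholders; check your model's chat template for the real prefix):

```python
import os

from synthetic_dataset_generator import launch

# Hypothetical model family whose chat template opens user turns with "<|user|>\n"
os.environ["HF_TOKEN"] = "hf_..."  # required for Hugging Face Inference Endpoints
os.environ["MODEL"] = "my-org/my-chat-model"  # hypothetical instruct model id
os.environ["TOKENIZER_ID"] = "my-org/my-chat-model"  # hypothetical tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|user|>\n"  # custom pre-query string

launch()
```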

Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:

@@ -9,7 +9,10 @@
from synthetic_dataset_generator import launch

# Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL
os.environ["ARGILLA_API_URL"] = "https://[your-owner-name]-[your_space_name].hf.space"
os.environ["ARGILLA_API_KEY"] = "my_api_key"
os.environ["HF_TOKEN"] = "hf_..."
os.environ["ARGILLA_API_URL"] = (
"https://[your-owner-name]-[your_space_name].hf.space" # argilla base url
)
os.environ["ARGILLA_API_KEY"] = "my_api_key" # argilla api key

launch()
14 changes: 0 additions & 14 deletions examples/enforce_mapgie_template copy.py

This file was deleted.

2 changes: 1 addition & 1 deletion examples/fine-tune-modernbert-classifier.ipynb
@@ -530,7 +530,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.11"
}
},
"nbformat": 4,
19 changes: 19 additions & 0 deletions examples/hf-dedicated-or-tgi-deployment.py
@@ -0,0 +1,19 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "synthetic-dataset-generator",
# ]
# ///
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template
os.environ["TOKENIZER_ID"] = (
"meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint
)
os.environ["MODEL"] = None # model is linked to endpoint

launch()
15 changes: 15 additions & 0 deletions examples/hf-serverless-deployment.py
@@ -0,0 +1,15 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "synthetic-dataset-generator",
# ]
# ///
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model

launch()
22 changes: 22 additions & 0 deletions examples/ollama-deployment.py
@@ -0,0 +1,22 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "synthetic-dataset-generator",
# ]
# ///
# ollama serve
# ollama run qwen2.5:32b-instruct-q5_K_S
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url
os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id
os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
os.environ["MAX_NUM_ROWS"] = "10000"
os.environ["DEFAULT_BATCH_SIZE"] = "2"
os.environ["MAX_NUM_TOKENS"] = "1024"

launch()
15 changes: 0 additions & 15 deletions examples/ollama_local.py

This file was deleted.

18 changes: 18 additions & 0 deletions examples/openai-deployment.py
@@ -0,0 +1,18 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "synthetic-dataset-generator",
# ]
# ///

import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url
os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key
os.environ["MODEL"] = "gpt-4o" # model id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with OpenAI

launch()
16 changes: 0 additions & 16 deletions examples/openai_local.py

This file was deleted.

21 changes: 21 additions & 0 deletions examples/vllm-deployment.py
@@ -0,0 +1,21 @@
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "synthetic-dataset-generator",
# ]
# ///
# vllm serve Qwen/Qwen2.5-1.5B-Instruct
import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url
os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id
os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id
os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2"
os.environ["MAX_NUM_ROWS"] = "10000"
os.environ["DEFAULT_BATCH_SIZE"] = "2"
os.environ["MAX_NUM_TOKENS"] = "1024"

launch()
67 changes: 51 additions & 16 deletions pdm.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ readme = "README.md"
license = {text = "Apache 2"}

dependencies = [
"distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1,<2.0.0",
"distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@develop",
"gradio[oauth]>=5.4.0,<6.0.0",
"transformers>=4.44.2,<5.0.0",
"sentence-transformers>=3.2.0,<4.0.0",
10 changes: 10 additions & 0 deletions src/synthetic_dataset_generator/_distiset.py
@@ -81,6 +81,15 @@ def _get_card(
dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
)

+ keys = list(sample_records.keys())
+ if len(keys) == 2 and (
+     ("label" in keys and "text" in keys)
+     or ("labels" in keys and "text" in keys)
+ ):
+     task_categories = ["text-classification"]
+ elif "prompt" in keys or "messages" in keys:
+     task_categories = ["text-generation", "text2text-generation"]
+ else:
+     task_categories = []

readme_metadata = {}
if repo_id and token:
readme_metadata = self._extract_readme_metadata(repo_id, token)
@@ -90,6 +99,7 @@
      "size_categories": size_categories_parser(
          max(len(dataset) for dataset in self.values())
      ),
+     "task_categories": task_categories,
"tags": [
"synthetic",
"distilabel",
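For readability, the same column-based inference distilled into a standalone sketch (`infer_task_categories` is a hypothetical helper name, mirroring the logic in the hunk above):

```python
def infer_task_categories(sample: dict) -> list[str]:
    """Guess Hub task categories from the columns of one sample record."""
    keys = set(sample)
    if keys in ({"text", "label"}, {"text", "labels"}):
        return ["text-classification"]
    if "prompt" in keys or "messages" in keys:
        return ["text-generation", "text2text-generation"]
    return []
```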
9 changes: 7 additions & 2 deletions src/synthetic_dataset_generator/apps/base.py
@@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name):
return repo_id


- def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset:
+ def combine_datasets(
+     repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None]
+ ) -> Dataset:
      try:
          new_dataset = load_dataset(
-             repo_id, split="train", download_mode="force_redownload"
+             repo_id,
+             split="train",
+             download_mode="force_redownload",
+             token=oauth_token.token if oauth_token else None,
          )
          return concatenate_datasets([dataset, new_dataset])
      except Exception: