diff --git a/.python-version b/.python-version deleted file mode 100644 index 9f675fa..0000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -synthetic-data-generator diff --git a/README.md b/README.md index 16c61ad..ca77bd8 100644 --- a/README.md +++ b/README.md @@ -20,25 +20,13 @@ hf_oauth_scopes:


- Synthetic Data Generator + 🧬 Synthetic Data Generator

Build datasets using natural language

![Synthetic Data Generator](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/ui-full.png) -

- -CI - - -CI - - - - -

-

@@ -78,21 +66,29 @@ You can simply install the package with: pip install synthetic-dataset-generator ``` +### Quickstart + +```python +from distilabel_dataset_generator.app import demo + +demo.launch() +``` + ### Environment Variables -- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained). +- `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints. + +Optionally, you can set the following environment variables to customize the generation process. + +- `BASE_URL`: The base URL for any OpenAI-compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`. +- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`. +- `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`. Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables: - `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla. - `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla. -## Quickstart - -```bash -python app.py -``` - ### Argilla integration Argilla is a open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/). 
@@ -104,3 +100,19 @@ Argilla is a open source tool for data curation. It allows you to annotate and r Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps. Check out the [distilabel library](https://github.com/argilla-io/distilabel) for more information. + +## Development + +Install the dependencies: + +```bash +python -m venv .venv +source .venv/bin/activate +pip install -e . +``` + +Run the app: + +```bash +python app.py +``` diff --git a/app.py b/app.py index 04b9409..a952cb7 100644 --- a/app.py +++ b/app.py @@ -1,38 +1,4 @@ -from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface -from src.distilabel_dataset_generator.apps.eval import app as eval_app -from src.distilabel_dataset_generator.apps.faq import app as faq_app -from src.distilabel_dataset_generator.apps.sft import app as sft_app -from src.distilabel_dataset_generator.apps.textcat import app as textcat_app - -theme = "argilla/argilla-theme" - -css = """ -button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)} -button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)} -button.hf-login {background: var(--neutral-800); color: white} -button.hf-login:hover {background: var(--neutral-700); color: white} -.tabitem { border: 0; padding-inline: 0} -.main_ui_logged_out{opacity: 0.3; pointer-events: none} -.group_padding{padding: .55em} -.gallery-item {background: var(--background-fill-secondary); text-align: left} -.gallery {white-space: wrap} -#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none} -#system_prompt_examples { - color: var(--body-text-color) !important; - background-color: var(--block-background-fill) !important; -} -.container {padding-inline: 0 !important} -""" - -demo = TabbedInterface( - [textcat_app, sft_app, eval_app, faq_app], - ["Text 
Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"], - css=css, - title="Synthetic Data Generator", - head="Synthetic Data Generator", - theme=theme, -) - +from distilabel_dataset_generator.app import demo if __name__ == "__main__": demo.launch() diff --git a/pyproject.toml b/pyproject.toml index 47c3a51..ddf19ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,18 @@ description = "Build datasets using natural language" authors = [ {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"}, ] +tags = [ + "gradio", + "synthetic-data", + "huggingface", + "argilla", + "generative-ai", + "ai", +] +requires-python = "<3.13,>=3.10" +readme = "README.md" +license = {text = "Apache 2"} + dependencies = [ "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1", "gradio[oauth]<5.0.0", @@ -14,14 +26,10 @@ dependencies = [ "gradio-huggingfacehub-search>=0.0.7", "argilla>=2.4.0", ] -requires-python = "<3.13,>=3.10" -readme = "README.md" -license = {text = "apache 2"} [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" - [tool.pdm] distribution = true diff --git a/src/distilabel_dataset_generator/__init__.py b/src/distilabel_dataset_generator/__init__.py index 1c9126c..9b8c50d 100644 --- a/src/distilabel_dataset_generator/__init__.py +++ b/src/distilabel_dataset_generator/__init__.py @@ -1,39 +1,64 @@ -import os import warnings -from pathlib import Path -from typing import Optional, Union +from typing import Optional -import argilla as rg import distilabel import distilabel.distiset +from distilabel.llms import InferenceEndpointsLLM from distilabel.utils.card.dataset_card import ( DistilabelDatasetCard, size_categories_parser, ) -from huggingface_hub import DatasetCardData, HfApi, upload_file +from huggingface_hub import DatasetCardData, HfApi +from pydantic import ( + ValidationError, + model_validator, +) -HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] -HF_TOKENS = 
[token for token in HF_TOKENS if token] -if len(HF_TOKENS) == 0: - raise ValueError( - "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints." - ) +class CustomInferenceEndpointsLLM(InferenceEndpointsLLM): + @model_validator(mode="after") # type: ignore + def only_one_of_model_id_endpoint_name_or_base_url_provided( + self, + ) -> "InferenceEndpointsLLM": + """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also + provided, a warning will be shown informing the user that the provided `base_url` will be ignored in + favour of the dynamically calculated one.""" + + if self.base_url and (self.model_id or self.endpoint_name): + warnings.warn( # type: ignore + f"Since the `base_url={self.base_url}` is available and either one of `model_id`" + " or `endpoint_name` is also provided, the `base_url` will either be ignored" + " or overwritten with the one generated from either of those args, for serverless" + " or dedicated inference endpoints, respectively." + ) + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," + " set a `tokenizer_id` and try again." 
+ ) -ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") -ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") -if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: - ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") - ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") + if ( + self.model_id + and self.tokenizer_id is None + and self.structured_output is not None + ): + self.tokenizer_id = self.model_id -if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: - warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set") - argilla_client = None -else: - argilla_client = rg.Argilla( - api_url=ARGILLA_API_URL, - api_key=ARGILLA_API_KEY, - ) + if self.base_url and not (self.model_id or self.endpoint_name): + return self + + if self.model_id and not self.endpoint_name: + return self + + if self.endpoint_name and not self.model_id: + return self + + raise ValueError( + f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is" + f" provided too, it will be overwritten instead. Found `model_id`={self.model_id}," + f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." 
+ ) class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset): @@ -138,3 +163,4 @@ def _get_card( distilabel.distiset.Distiset = CustomDistisetWithAdditionalTag +distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM diff --git a/src/distilabel_dataset_generator/_tabbedinterface.py b/src/distilabel_dataset_generator/_tabbedinterface.py index 277004f..4263c06 100644 --- a/src/distilabel_dataset_generator/_tabbedinterface.py +++ b/src/distilabel_dataset_generator/_tabbedinterface.py @@ -68,7 +68,9 @@ def __init__( with gr.Column(scale=3): pass with gr.Column(scale=2): - gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2) + gr.LoginButton( + value="Sign in", variant="hf-login", size="sm", scale=2 + ) with Tabs(): for interface, tab_name in zip(interface_list, tab_names, strict=False): with Tab(label=tab_name): diff --git a/src/distilabel_dataset_generator/app.py b/src/distilabel_dataset_generator/app.py new file mode 100644 index 0000000..53ec94f --- /dev/null +++ b/src/distilabel_dataset_generator/app.py @@ -0,0 +1,38 @@ +from distilabel_dataset_generator._tabbedinterface import TabbedInterface +from distilabel_dataset_generator.apps.eval import app as eval_app +from distilabel_dataset_generator.apps.faq import app as faq_app +from distilabel_dataset_generator.apps.sft import app as sft_app +from distilabel_dataset_generator.apps.textcat import app as textcat_app + +theme = "argilla/argilla-theme" + +css = """ +button[role="tab"][aria-selected="true"] { border: 0; background: var(--neutral-800); color: white; border-top-right-radius: var(--radius-md); border-top-left-radius: var(--radius-md)} +button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-primary-background-fill)} +button.hf-login {background: var(--neutral-800); color: white} +button.hf-login:hover {background: var(--neutral-700); color: white} +.tabitem { border: 0; padding-inline: 0} +.main_ui_logged_out{opacity: 0.3; pointer-events: none} 
+.group_padding{padding: .55em} +.gallery-item {background: var(--background-fill-secondary); text-align: left} +.gallery {white-space: wrap} +#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none} +#system_prompt_examples { + color: var(--body-text-color) !important; + background-color: var(--block-background-fill) !important; +} +.container {padding-inline: 0 !important} +""" + +demo = TabbedInterface( + [textcat_app, sft_app, eval_app, faq_app], + ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"], + css=css, + title="Synthetic Data Generator", + head="Synthetic Data Generator", + theme=theme, +) + + +if __name__ == "__main__": + demo.launch() diff --git a/src/distilabel_dataset_generator/apps/__init__.py b/src/distilabel_dataset_generator/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/distilabel_dataset_generator/apps/base.py b/src/distilabel_dataset_generator/apps/base.py index 0b6cc4f..aead4df 100644 --- a/src/distilabel_dataset_generator/apps/base.py +++ b/src/distilabel_dataset_generator/apps/base.py @@ -1,6 +1,6 @@ import io import uuid -from typing import Any, Callable, List, Tuple, Union +from typing import List, Union import argilla as rg import gradio as gr @@ -10,161 +10,11 @@ from gradio import OAuthToken from huggingface_hub import HfApi, upload_file -from src.distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, +from distilabel_dataset_generator.constants import TEXTCAT_TASK +from distilabel_dataset_generator.utils import ( get_argilla_client, - get_login_button, - list_orgs, - swap_visibility, ) -TEXTCAT_TASK = "text_classification" -SFT_TASK = "supervised_fine_tuning" - - -def get_main_ui( - default_dataset_descriptions: List[str], - default_system_prompts: List[str], - default_datasets: List[pd.DataFrame], - fn_generate_system_prompt: Callable, - fn_generate_dataset: Callable, - task: str, -): - def fn_generate_sample_dataset(system_prompt, progress=gr.Progress()): - 
if system_prompt in default_system_prompts: - index = default_system_prompts.index(system_prompt) - if index < len(default_datasets): - return default_datasets[index] - if task == TEXTCAT_TASK: - result = fn_generate_dataset( - system_prompt=system_prompt, - difficulty="high school", - clarity="clear", - labels=[], - num_labels=1, - num_rows=1, - progress=progress, - is_sample=True, - ) - else: - result = fn_generate_dataset( - system_prompt=system_prompt, - num_turns=1, - num_rows=1, - progress=progress, - is_sample=True, - ) - return result - - with gr.Blocks( - title="🧬 Synthetic Data Generator", - head="🧬 Synthetic Data Generator", - css=_LOGGED_OUT_CSS, - ) as app: - with gr.Row(): - gr.HTML( - """

How does it work?

""" - ) - with gr.Row(): - gr.Markdown( - "Want to run this locally or with other LLMs? Take a look at the FAQ tab. distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation." - ) - with gr.Row(): - gr.Column() - get_login_button() - gr.Column() - - gr.Markdown("## Iterate on a sample dataset") - with gr.Column() as main_ui: - ( - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - ) = get_iterate_on_sample_dataset_ui( - default_dataset_descriptions=default_dataset_descriptions, - default_system_prompts=default_system_prompts, - default_datasets=default_datasets, - task=task, - ) - gr.Markdown("## Generate full dataset") - gr.Markdown( - "Once you're satisfied with the sample, generate a larger dataset and push it to Argilla or the Hugging Face Hub." - ) - with gr.Row(variant="panel") as custom_input_ui: - pass - - ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) = get_push_to_ui(default_datasets) - - sample_dataset.change( - fn=lambda x: x, - inputs=[sample_dataset], - outputs=[final_dataset], - ) - - btn_generate_system_prompt.click( - fn=fn_generate_system_prompt, - inputs=[dataset_description], - outputs=[system_prompt], - show_progress=True, - ).then( - fn=fn_generate_sample_dataset, - inputs=[system_prompt], - outputs=[sample_dataset], - show_progress=True, - ) - - btn_generate_sample_dataset.click( - fn=fn_generate_sample_dataset, - inputs=[system_prompt], - outputs=[sample_dataset], - show_progress=True, - ) - - app.load(fn=swap_visibility, outputs=main_ui) - app.load(get_org_dropdown, outputs=[org_name]) - - return ( - app, - main_ui, - 
custom_input_ui, - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) - def validate_argilla_user_workspace_dataset( dataset_name: str, @@ -195,186 +45,6 @@ def validate_argilla_user_workspace_dataset( return "" -def get_org_dropdown(oauth_token: Union[OAuthToken, None]): - orgs = list_orgs(oauth_token) - return gr.Dropdown( - label="Organization", - choices=orgs, - value=orgs[0] if orgs else None, - allow_custom_value=True, - ) - - -def get_push_to_ui(default_datasets): - with gr.Column() as push_to_ui: - ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - ) = get_argilla_tab() - ( - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - ) = get_hf_tab() - final_dataset = get_final_dataset_row(default_datasets) - success_message = get_success_message_row() - return ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - final_dataset, - success_message, - ) - - -def get_iterate_on_sample_dataset_ui( - default_dataset_descriptions: List[str], - default_system_prompts: List[str], - default_datasets: List[pd.DataFrame], - task: str, -): - with gr.Column(): - dataset_description = gr.TextArea( - label="Give a precise description of your desired application. 
Check the examples for inspiration.", - value=default_dataset_descriptions[0], - lines=2, - ) - examples = gr.Examples( - elem_id="system_prompt_examples", - examples=[[example] for example in default_dataset_descriptions], - inputs=[dataset_description], - ) - with gr.Row(): - gr.Column(scale=1) - btn_generate_system_prompt = gr.Button( - value="Generate system prompt and sample dataset", variant="primary" - ) - gr.Column(scale=1) - - system_prompt = gr.TextArea( - label="System prompt for dataset generation. You can tune it and regenerate the sample.", - value=default_system_prompts[0], - lines=2 if task == TEXTCAT_TASK else 5, - ) - - with gr.Row(): - sample_dataset = gr.Dataframe( - value=default_datasets[0], - label=( - "Sample dataset. Text truncated to 256 tokens." - if task == TEXTCAT_TASK - else "Sample dataset. Prompts and completions truncated to 256 tokens." - ), - interactive=False, - wrap=True, - ) - - with gr.Row(): - gr.Column(scale=1) - btn_generate_sample_dataset = gr.Button( - value="Generate sample dataset", variant="primary" - ) - gr.Column(scale=1) - - return ( - dataset_description, - examples, - btn_generate_system_prompt, - system_prompt, - sample_dataset, - btn_generate_sample_dataset, - ) - - -def get_argilla_tab() -> Tuple[Any]: - with gr.Tab(label="Argilla"): - if get_argilla_client() is not None: - with gr.Row(variant="panel"): - dataset_name = gr.Textbox( - label="Dataset name", - placeholder="dataset_name", - value="my-distiset", - ) - add_to_existing_dataset = gr.Checkbox( - label="Allow adding records to existing dataset", - info="When selected, you do need to ensure the dataset options are the same as in the existing dataset.", - value=False, - interactive=True, - scale=1, - ) - - with gr.Row(variant="panel"): - btn_generate_full_dataset_argilla = gr.Button( - value="Generate", variant="primary", scale=2 - ) - btn_generate_and_push_to_argilla = gr.Button( - value="Generate and Push to Argilla", - variant="primary", - scale=2, - ) 
- btn_push_to_argilla = gr.Button( - value="Push to Argilla", variant="primary", scale=2 - ) - else: - gr.Markdown( - "Please add `ARGILLA_API_URL` and `ARGILLA_API_KEY` to use Argilla or export the dataset to the Hugging Face Hub." - ) - return ( - dataset_name, - add_to_existing_dataset, - btn_generate_full_dataset_argilla, - btn_generate_and_push_to_argilla, - btn_push_to_argilla, - ) - - -def get_hf_tab() -> Tuple[Any]: - with gr.Tab("Hugging Face Hub"): - with gr.Row(variant="panel"): - org_name = get_org_dropdown() - repo_name = gr.Textbox( - label="Repo name", - placeholder="dataset_name", - value="my-distiset", - ) - private = gr.Checkbox( - label="Private dataset", - value=True, - interactive=True, - scale=1, - ) - with gr.Row(variant="panel"): - btn_generate_full_dataset = gr.Button( - value="Generate", variant="primary", scale=2 - ) - btn_generate_and_push_to_hub = gr.Button( - value="Generate and Push to Hub", variant="primary", scale=2 - ) - btn_push_to_hub = gr.Button(value="Push to Hub", variant="primary", scale=2) - return ( - org_name, - repo_name, - private, - btn_generate_full_dataset, - btn_generate_and_push_to_hub, - btn_push_to_hub, - ) - - def push_pipeline_code_to_hub( pipeline_code: str, org_name: str, @@ -455,24 +125,6 @@ def validate_push_to_hub(org_name, repo_name): return repo_id -def get_final_dataset_row(default_datasets) -> gr.Dataframe: - with gr.Row(): - final_dataset = gr.Dataframe( - value=default_datasets[0], - label="Generated dataset", - interactive=False, - wrap=True, - min_width=300, - ) - return final_dataset - - -def get_success_message_row() -> gr.Markdown: - with gr.Row(): - success_message = gr.Markdown(visible=False) - return success_message - - def show_success_message(org_name, repo_name) -> gr.Markdown: client = get_argilla_client() if client is None: diff --git a/src/distilabel_dataset_generator/apps/eval.py b/src/distilabel_dataset_generator/apps/eval.py index 6e4a60a..1136fe1 100644 --- 
a/src/distilabel_dataset_generator/apps/eval.py +++ b/src/distilabel_dataset_generator/apps/eval.py @@ -16,25 +16,23 @@ from gradio_huggingfacehub_search import HuggingfaceHubSearch from huggingface_hub import HfApi -from src.distilabel_dataset_generator.apps.base import ( +from distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) -from src.distilabel_dataset_generator.pipelines.embeddings import ( +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE +from distilabel_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, ) -from src.distilabel_dataset_generator.pipelines.eval import ( +from distilabel_dataset_generator.pipelines.eval import ( generate_pipeline_code, get_custom_evaluator, get_ultrafeedback_evaluator, ) -from src.distilabel_dataset_generator.utils import ( +from distilabel_dataset_generator.utils import ( column_to_list, extract_column_names, get_argilla_client, diff --git a/src/distilabel_dataset_generator/apps/sft.py b/src/distilabel_dataset_generator/apps/sft.py index fad57d1..f9655d3 100644 --- a/src/distilabel_dataset_generator/apps/sft.py +++ b/src/distilabel_dataset_generator/apps/sft.py @@ -9,28 +9,25 @@ from distilabel.distiset import Distiset from huggingface_hub import HfApi -from src.distilabel_dataset_generator.apps.base import ( +from distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) -from src.distilabel_dataset_generator.pipelines.embeddings import ( +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, SFT_AVAILABLE +from distilabel_dataset_generator.pipelines.embeddings 
import ( get_embeddings, get_sentence_embedding_dimensions, ) -from src.distilabel_dataset_generator.pipelines.sft import ( +from distilabel_dataset_generator.pipelines.sft import ( DEFAULT_DATASET_DESCRIPTIONS, generate_pipeline_code, get_magpie_generator, get_prompt_generator, get_response_generator, ) -from src.distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, +from distilabel_dataset_generator.utils import ( get_argilla_client, get_org_dropdown, swap_visibility, @@ -352,170 +349,177 @@ def hide_pipeline_code_visibility(): ###################### -with gr.Blocks(css=_LOGGED_OUT_CSS) as app: +with gr.Blocks() as app: with gr.Column() as main_ui: - gr.Markdown(value="## 1. Describe the dataset you want") - with gr.Row(): - with gr.Column(scale=2): - dataset_description = gr.Textbox( - label="Dataset description", - placeholder="Give a precise description of your desired dataset.", - ) - with gr.Accordion("Temperature", open=False): - temperature = gr.Slider( - minimum=0.1, - maximum=1, - value=0.8, - step=0.1, + if not SFT_AVAILABLE: + gr.Markdown( + value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models." + ) + else: + gr.Markdown(value="## 1. Describe the dataset you want") + with gr.Row(): + with gr.Column(scale=2): + dataset_description = gr.Textbox( + label="Dataset description", + placeholder="Give a precise description of your desired dataset.", + ) + with gr.Accordion("Temperature", open=False): + temperature = gr.Slider( + minimum=0.1, + maximum=1, + value=0.8, + step=0.1, + interactive=True, + show_label=False, + ) + load_btn = gr.Button( + "Create dataset", + variant="primary", + ) + with gr.Column(scale=2): + examples = gr.Examples( + examples=DEFAULT_DATASET_DESCRIPTIONS, + inputs=[dataset_description], + cache_examples=False, + label="Examples", + ) + with gr.Column(scale=1): + pass + + gr.HTML(value="
") + gr.Markdown(value="## 2. Configure your dataset") + with gr.Row(equal_height=False): + with gr.Column(scale=2): + system_prompt = gr.Textbox( + label="System prompt", + placeholder="You are a helpful assistant.", + ) + num_turns = gr.Number( + value=1, + label="Number of turns in the conversation", + minimum=1, + maximum=4, + step=1, interactive=True, - show_label=False, + info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).", ) - load_btn = gr.Button( - "Create dataset", - variant="primary", - ) - with gr.Column(scale=2): - examples = gr.Examples( - examples=DEFAULT_DATASET_DESCRIPTIONS, - inputs=[dataset_description], - cache_examples=False, - label="Examples", - ) - with gr.Column(scale=1): - pass - - gr.HTML(value="
") - gr.Markdown(value="## 2. Configure your dataset") - with gr.Row(equal_height=False): - with gr.Column(scale=2): - system_prompt = gr.Textbox( - label="System prompt", - placeholder="You are a helpful assistant.", - ) - num_turns = gr.Number( - value=1, - label="Number of turns in the conversation", - minimum=1, - maximum=4, - step=1, - interactive=True, - info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).", - ) - btn_apply_to_sample_dataset = gr.Button( - "Refresh dataset", variant="secondary" - ) - with gr.Column(scale=3): - dataframe = gr.Dataframe( - headers=["prompt", "completion"], - wrap=True, - height=500, - interactive=False, - ) - - gr.HTML(value="
") - gr.Markdown(value="## 3. Generate your dataset") - with gr.Row(equal_height=False): - with gr.Column(scale=2): - org_name = get_org_dropdown() - repo_name = gr.Textbox( - label="Repo name", - placeholder="dataset_name", - value=f"my-distiset-{str(uuid.uuid4())[:8]}", - interactive=True, - ) - num_rows = gr.Number( - label="Number of rows", - value=10, - interactive=True, - scale=1, - ) - private = gr.Checkbox( - label="Private dataset", - value=False, - interactive=True, - scale=1, - ) - btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2) - with gr.Column(scale=3): - success_message = gr.Markdown(visible=True) - with gr.Accordion( - "Do you want to go further? Customize and run with Distilabel", - open=False, - visible=False, - ) as pipeline_code_ui: - code = generate_pipeline_code( - system_prompt=system_prompt.value, - num_turns=num_turns.value, - num_rows=num_rows.value, + btn_apply_to_sample_dataset = gr.Button( + "Refresh dataset", variant="secondary" ) - pipeline_code = gr.Code( - value=code, - language="python", - label="Distilabel Pipeline Code", + with gr.Column(scale=3): + dataframe = gr.Dataframe( + headers=["prompt", "completion"], + wrap=True, + height=500, + interactive=False, ) - load_btn.click( - fn=generate_system_prompt, - inputs=[dataset_description, temperature], - outputs=[system_prompt], - show_progress=True, - ).then( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) + gr.HTML(value="
") + gr.Markdown(value="## 3. Generate your dataset") + with gr.Row(equal_height=False): + with gr.Column(scale=2): + org_name = get_org_dropdown() + repo_name = gr.Textbox( + label="Repo name", + placeholder="dataset_name", + value=f"my-distiset-{str(uuid.uuid4())[:8]}", + interactive=True, + ) + num_rows = gr.Number( + label="Number of rows", + value=10, + interactive=True, + scale=1, + ) + private = gr.Checkbox( + label="Private dataset", + value=False, + interactive=True, + scale=1, + ) + btn_push_to_hub = gr.Button( + "Push to Hub", variant="primary", scale=2 + ) + with gr.Column(scale=3): + success_message = gr.Markdown(visible=True) + with gr.Accordion( + "Do you want to go further? Customize and run with Distilabel", + open=False, + visible=False, + ) as pipeline_code_ui: + code = generate_pipeline_code( + system_prompt=system_prompt.value, + num_turns=num_turns.value, + num_rows=num_rows.value, + ) + pipeline_code = gr.Code( + value=code, + language="python", + label="Distilabel Pipeline Code", + ) + + load_btn.click( + fn=generate_system_prompt, + inputs=[dataset_description, temperature], + outputs=[system_prompt], + show_progress=True, + ).then( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) - btn_apply_to_sample_dataset.click( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) + btn_apply_to_sample_dataset.click( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) - btn_push_to_hub.click( - fn=validate_argilla_user_workspace_dataset, - inputs=[repo_name], - outputs=[success_message], - show_progress=True, - ).then( - fn=validate_push_to_hub, - inputs=[org_name, repo_name], - outputs=[success_message], - show_progress=True, - ).success( - fn=hide_success_message, - outputs=[success_message], - show_progress=True, - ).success( - 
fn=hide_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - ).success( - fn=push_dataset, - inputs=[ - org_name, - repo_name, - system_prompt, - num_turns, - num_rows, - private, - ], - outputs=[success_message], - show_progress=True, - ).success( - fn=show_success_message, - inputs=[org_name, repo_name], - outputs=[success_message], - ).success( - fn=generate_pipeline_code, - inputs=[system_prompt, num_turns, num_rows], - outputs=[pipeline_code], - ).success( - fn=show_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - ) + btn_push_to_hub.click( + fn=validate_argilla_user_workspace_dataset, + inputs=[repo_name], + outputs=[success_message], + show_progress=True, + ).then( + fn=validate_push_to_hub, + inputs=[org_name, repo_name], + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_success_message, + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + ).success( + fn=push_dataset, + inputs=[ + org_name, + repo_name, + system_prompt, + num_turns, + num_rows, + private, + ], + outputs=[success_message], + show_progress=True, + ).success( + fn=show_success_message, + inputs=[org_name, repo_name], + outputs=[success_message], + ).success( + fn=generate_pipeline_code, + inputs=[system_prompt, num_turns, num_rows], + outputs=[pipeline_code], + ).success( + fn=show_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + ) - app.load(fn=swap_visibility, outputs=main_ui) - app.load(fn=get_org_dropdown, outputs=[org_name]) + app.load(fn=swap_visibility, outputs=main_ui) + app.load(fn=get_org_dropdown, outputs=[org_name]) diff --git a/src/distilabel_dataset_generator/apps/textcat.py b/src/distilabel_dataset_generator/apps/textcat.py index 2666d0a..43988ef 100644 --- a/src/distilabel_dataset_generator/apps/textcat.py +++ b/src/distilabel_dataset_generator/apps/textcat.py @@ -9,15 +9,13 @@ from distilabel.distiset 
import Distiset from huggingface_hub import HfApi +from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE from src.distilabel_dataset_generator.apps.base import ( hide_success_message, show_success_message, validate_argilla_user_workspace_dataset, validate_push_to_hub, ) -from src.distilabel_dataset_generator.pipelines.base import ( - DEFAULT_BATCH_SIZE, -) from src.distilabel_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, @@ -30,7 +28,6 @@ get_textcat_generator, ) from src.distilabel_dataset_generator.utils import ( - _LOGGED_OUT_CSS, get_argilla_client, get_org_dropdown, get_preprocess_labels, @@ -334,7 +331,7 @@ def hide_pipeline_code_visibility(): ###################### -with gr.Blocks(css=_LOGGED_OUT_CSS) as app: +with gr.Blocks() as app: with gr.Column() as main_ui: gr.Markdown("## 1. Describe the dataset you want") with gr.Row(): diff --git a/src/distilabel_dataset_generator/constants.py b/src/distilabel_dataset_generator/constants.py new file mode 100644 index 0000000..4732a09 --- /dev/null +++ b/src/distilabel_dataset_generator/constants.py @@ -0,0 +1,62 @@ +import os +import warnings + +import argilla as rg + +# Tasks +TEXTCAT_TASK = "text_classification" +SFT_TASK = "supervised_fine_tuning" + +# Hugging Face +HF_TOKEN = os.getenv("HF_TOKEN") +if HF_TOKEN is None: + raise ValueError( + "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints." 
+ ) + +# Inference +DEFAULT_BATCH_SIZE = 5 +MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") +API_KEYS = ( + [os.getenv("HF_TOKEN")] + + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] + + [os.getenv("API_KEY")] +) +API_KEYS = [token for token in API_KEYS if token] +BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/") + +if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0: + raise ValueError( + "API_KEY is not set. Ensure you have set the API_KEY environment variable that has access to the Hugging Face Inference Endpoints." + ) +if "Qwen2" not in MODEL and "Llama-3" not in MODEL: + SFT_AVAILABLE = False + warnings.warn( + "SFT_AVAILABLE is set to False because the model is not a Qwen or Llama model." + ) + MAGPIE_PRE_QUERY_TEMPLATE = None +else: + SFT_AVAILABLE = True + if "Qwen2" in MODEL: + MAGPIE_PRE_QUERY_TEMPLATE = "qwen2" + else: + MAGPIE_PRE_QUERY_TEMPLATE = "llama3" + +# Embeddings +STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M" + +# Argilla +ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") +ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") +if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: + ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") + ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") + +if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: + warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set") + argilla_client = None +else: + argilla_client = rg.Argilla( + api_url=ARGILLA_API_URL, + api_key=ARGILLA_API_KEY, + ) diff --git a/src/distilabel_dataset_generator/pipelines/__init__.py b/src/distilabel_dataset_generator/pipelines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/distilabel_dataset_generator/pipelines/base.py b/src/distilabel_dataset_generator/pipelines/base.py index ec54f95..22510c2 100644 --- a/src/distilabel_dataset_generator/pipelines/base.py +++ b/src/distilabel_dataset_generator/pipelines/base.py @@ 
-1,12 +1,10 @@ -from src.distilabel_dataset_generator import HF_TOKENS +from distilabel_dataset_generator.constants import API_KEYS -DEFAULT_BATCH_SIZE = 5 TOKEN_INDEX = 0 -MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" def _get_next_api_key(): global TOKEN_INDEX - api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)] + api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)] TOKEN_INDEX += 1 return api_key diff --git a/src/distilabel_dataset_generator/pipelines/embeddings.py b/src/distilabel_dataset_generator/pipelines/embeddings.py index bcd99ef..3275713 100644 --- a/src/distilabel_dataset_generator/pipelines/embeddings.py +++ b/src/distilabel_dataset_generator/pipelines/embeddings.py @@ -3,8 +3,9 @@ from sentence_transformers import SentenceTransformer from sentence_transformers.models import StaticEmbedding -# Initialize a StaticEmbedding module -static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output") +from distilabel_dataset_generator.constants import STATIC_EMBEDDING_MODEL + +static_embedding = StaticEmbedding.from_model2vec(STATIC_EMBEDDING_MODEL) model = SentenceTransformer(modules=[static_embedding]) diff --git a/src/distilabel_dataset_generator/pipelines/eval.py b/src/distilabel_dataset_generator/pipelines/eval.py index cf1d25b..ee2959a 100644 --- a/src/distilabel_dataset_generator/pipelines/eval.py +++ b/src/distilabel_dataset_generator/pipelines/eval.py @@ -5,18 +5,16 @@ UltraFeedback, ) -from src.distilabel_dataset_generator.pipelines.base import ( - MODEL, - _get_next_api_key, -) -from src.distilabel_dataset_generator.utils import extract_column_names +from distilabel_dataset_generator.constants import BASE_URL, MODEL +from distilabel_dataset_generator.pipelines.base import _get_next_api_key +from distilabel_dataset_generator.utils import extract_column_names def get_ultrafeedback_evaluator(aspect, is_sample): ultrafeedback_evaluator = UltraFeedback( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, 
api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0, @@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample) custom_evaluator = TextGeneration( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), structured_output={"format": "json", "schema": structured_output}, generation_kwargs={ @@ -62,7 +60,8 @@ def generate_ultrafeedback_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" -os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +BASE_URL = "{BASE_URL}" +os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]") data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries @@ -76,8 +75,8 @@ def generate_ultrafeedback_pipeline_code( ultrafeedback_evaluator = UltraFeedback( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0, "max_new_tokens": 2048, @@ -101,7 +100,8 @@ def generate_ultrafeedback_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" -os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +BASE_URL = "{BASE_URL}" +os.environ["API_KEY"] = "hf_xxx" #
https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}") data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries @@ -119,8 +119,8 @@ def generate_ultrafeedback_pipeline_code( aspect=aspect, llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0, "max_new_tokens": 2048, @@ -157,6 +157,7 @@ def generate_custom_pipeline_code( from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" CUSTOM_TEMPLATE = "{prompt_template}" os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained @@ -171,7 +172,7 @@ def generate_custom_pipeline_code( custom_evaluator = TextGeneration( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=os.environ["HF_TOKEN"], structured_output={{"format": "json", "schema": {structured_output}}}, generation_kwargs={{ diff --git a/src/distilabel_dataset_generator/pipelines/sft.py b/src/distilabel_dataset_generator/pipelines/sft.py index 240e973..920f40d 100644 --- a/src/distilabel_dataset_generator/pipelines/sft.py +++ b/src/distilabel_dataset_generator/pipelines/sft.py @@ -1,10 +1,12 @@ from distilabel.llms import InferenceEndpointsLLM from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration -from src.distilabel_dataset_generator.pipelines.base import ( +from distilabel_dataset_generator.constants import ( + BASE_URL, + MAGPIE_PRE_QUERY_TEMPLATE, MODEL, - _get_next_api_key, ) +from
distilabel_dataset_generator.pipelines.base import _get_next_api_key INFORMATION_SEEKING_PROMPT = ( "You are an AI assistant designed to provide accurate and concise information on a wide" @@ -144,6 +146,7 @@ def get_prompt_generator(temperature): api_key=_get_next_api_key(), model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, generation_kwargs={ "temperature": temperature, "max_new_tokens": 2048, @@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), - magpie_pre_query_template="llama3", + magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, generation_kwargs={ "temperature": 0.9, "do_sample": True, @@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), - magpie_pre_query_template="llama3", + magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, generation_kwargs={ "temperature": 0.9, "do_sample": True, @@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.8, @@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample): llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.8, @@ -247,14 +254,16 @@ def generate_pipeline_code(system_prompt, num_turns, num_rows): from distilabel.llms import InferenceEndpointsLLM MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" SYSTEM_PROMPT = "{system_prompt}" -os.environ["HF_TOKEN"] = "hf_xxx" # 
https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained +os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained with Pipeline(name="sft") as pipeline: magpie = MagpieGenerator( llm=InferenceEndpointsLLM( model_id=MODEL, tokenizer_id=MODEL, + base_url=BASE_URL, magpie_pre_query_template="llama3", generation_kwargs={{ "temperature": 0.9, @@ -262,7 +271,7 @@ def generate_pipeline_code(system_prompt, num_turns, num_rows): "max_new_tokens": 2048, "stop_sequences": {_STOP_SEQUENCES} }}, - api_key=os.environ["HF_TOKEN"], + api_key=os.environ["API_KEY"], ), n_turns={num_turns}, num_rows={num_rows}, diff --git a/src/distilabel_dataset_generator/pipelines/textcat.py b/src/distilabel_dataset_generator/pipelines/textcat.py index e17f594..1c88e86 100644 --- a/src/distilabel_dataset_generator/pipelines/textcat.py +++ b/src/distilabel_dataset_generator/pipelines/textcat.py @@ -1,5 +1,4 @@ import random -from pydantic import BaseModel, Field from typing import List from distilabel.llms import InferenceEndpointsLLM @@ -8,12 +7,11 @@ TextClassification, TextGeneration, ) +from pydantic import BaseModel, Field -from src.distilabel_dataset_generator.pipelines.base import ( - MODEL, - _get_next_api_key, -) -from src.distilabel_dataset_generator.utils import get_preprocess_labels +from distilabel_dataset_generator.constants import BASE_URL, MODEL +from distilabel_dataset_generator.pipelines.base import _get_next_api_key +from distilabel_dataset_generator.utils import get_preprocess_labels PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
@@ -73,7 +71,7 @@ def get_prompt_generator(temperature): llm=InferenceEndpointsLLM( api_key=_get_next_api_key(), model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, structured_output={"format": "json", "schema": TextClassificationTask}, generation_kwargs={ "temperature": temperature, @@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample): textcat_generator = GenerateTextClassificationData( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.9, @@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels): labeller_generator = TextClassification( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, + base_url=BASE_URL, api_key=_get_next_api_key(), generation_kwargs={ "temperature": 0.7, @@ -149,8 +147,9 @@ def generate_pipeline_code( from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"} MODEL = "{MODEL}" +BASE_URL = "{BASE_URL}" TEXT_CLASSIFICATION_TASK = "{system_prompt}" -os.environ["HF_TOKEN"] = ( +os.environ["API_KEY"] = ( "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained ) @@ -161,8 +160,8 @@ def generate_pipeline_code( textcat_generation = GenerateTextClassificationData( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ "temperature": 0.8, "max_new_tokens": 2048, @@ -205,8 +204,8 @@ def generate_pipeline_code( textcat_labeller = TextClassification( llm=InferenceEndpointsLLM( model_id=MODEL, - tokenizer_id=MODEL, - api_key=os.environ["HF_TOKEN"], + base_url=BASE_URL, + api_key=os.environ["API_KEY"], generation_kwargs={{ 
"temperature": 0.8, "max_new_tokens": 2048, diff --git a/src/distilabel_dataset_generator/utils.py b/src/distilabel_dataset_generator/utils.py index 68b0b77..b894a87 100644 --- a/src/distilabel_dataset_generator/utils.py +++ b/src/distilabel_dataset_generator/utils.py @@ -6,40 +6,13 @@ import numpy as np import pandas as pd from gradio.oauth import ( - OAUTH_CLIENT_ID, - OAUTH_CLIENT_SECRET, - OAUTH_SCOPES, - OPENID_PROVIDER_URL, + OAuthToken, get_space, ) from huggingface_hub import whoami from jinja2 import Environment, meta -from src.distilabel_dataset_generator import argilla_client - -_LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}" - - -_CHECK_IF_SPACE_IS_SET = ( - all( - [ - OAUTH_CLIENT_ID, - OAUTH_CLIENT_SECRET, - OAUTH_SCOPES, - OPENID_PROVIDER_URL, - ] - ) - or get_space() is None -) - -if _CHECK_IF_SPACE_IS_SET: - from gradio.oauth import OAuthToken -else: - OAuthToken = str - - -def get_login_button(): - return gr.LoginButton(value="Sign in!", size="sm", scale=2).activate() +from distilabel_dataset_generator.constants import argilla_client def get_duplicate_button(): @@ -85,13 +58,6 @@ def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None): ) -def get_token(oauth_token: Union[OAuthToken, None]): - if oauth_token: - return oauth_token.token - else: - return "" - - def swap_visibility(oauth_token: Union[OAuthToken, None]): if oauth_token: return gr.update(elem_classes=["main_ui_logged_in"]) @@ -99,28 +65,6 @@ def swap_visibility(oauth_token: Union[OAuthToken, None]): return gr.update(elem_classes=["main_ui_logged_out"]) -def get_base_app(): - with gr.Blocks( - title="🧬 Synthetic Data Generator", - head="🧬 Synthetic Data Generator", - css=_LOGGED_OUT_CSS, - ) as app: - with gr.Row(): - gr.Markdown( - "Want to run this locally or with other LLMs? Take a look at the FAQ tab. 
distilabel Synthetic Data Generator is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation." - ) - with gr.Row(): - gr.Column() - get_login_button() - gr.Column() - - gr.Markdown("## Iterate on a sample dataset") - with gr.Column() as main_ui: - pass - - return app - - def get_argilla_client() -> Union[rg.Argilla, None]: return argilla_client