From f5ab4cbe6d82d1b70fe056b852b2c72f78b06f5e Mon Sep 17 00:00:00 2001
From: davidberenstein1957
Date: Mon, 23 Dec 2024 17:15:45 +0100
Subject: [PATCH 01/15] update deployment with API providers

---
 README.md                                     | 13 ++-
 examples/enforce_mapgie_template copy.py      | 9 --
 ...a_local.py => hf-serverless_deployment.py} | 4 +-
 examples/ollama_deployment.py                 | 17 +++
 .../{openai_local.py => openai_deployment.py} | 3 +-
 examples/tgi_or_hf_dedicated.py               | 14 +++
 pdm.lock                                      | 67 ++++++++---
 pyproject.toml                                | 2 +-
 src/synthetic_dataset_generator/apps/chat.py  | 2 +-
 .../apps/textcat.py                           | 49 +++++---
 src/synthetic_dataset_generator/constants.py  | 75 ++++++++----
 .../pipelines/base.py                         | 81 ++++++++++++-
 .../pipelines/chat.py                         | 107 +++++++-----------
 .../pipelines/textcat.py                      | 62 ++--------
 14 files changed, 310 insertions(+), 195 deletions(-)
 delete mode 100644 examples/enforce_mapgie_template copy.py
 rename examples/{ollama_local.py => hf-serverless_deployment.py} (51%)
 create mode 100644 examples/ollama_deployment.py
 rename examples/{openai_local.py => openai_deployment.py} (63%)
 create mode 100644 examples/tgi_or_hf_dedicated.py

diff --git a/README.md b/README.md
index 557d3e2..a24c9ed 100644
--- a/README.md
+++ b/README.md
@@ -76,21 +76,24 @@ launch()
 - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.
 
 You can find some configuration examples in the [examples](examples/) folder.
 
-Optionally, you can set the following environment variables to customize the generation process.
+You can set the following environment variables to customize the generation process.
 
 - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
 - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
 - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
 
-Optionally, you can use different models and APIs. For providers outside of Hugging Face, we provide an integration through [LiteLLM](https://docs.litellm.ai/docs/providers).
+Optionally, you can use different API providers and models.
 
-- `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`, `http://127.0.0.1:11434/v1/`.
-- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `openai/gpt-4o`, `ollama/llama3.1`.
+- `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
 - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+- `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
+- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/v1/`.
+- `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
 
 SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2.
 
-- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. Llama3 and Qwen2 are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"` respectively. For other models, you can pass a custom pre-query template string.
+- `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
+- `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string.
 
 Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
diff --git a/examples/enforce_mapgie_template copy.py b/examples/enforce_mapgie_template copy.py
deleted file mode 100644
index ab01549..0000000
--- a/examples/enforce_mapgie_template copy.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# pip install synthetic-dataset-generator
-import os
-
-from synthetic_dataset_generator import launch
-
-os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "my_custom_template"
-os.environ["MODEL"] = "google/gemma-2-9b-it"
-
-launch()
diff --git a/examples/ollama_local.py b/examples/hf-serverless_deployment.py
similarity index 51%
rename from examples/ollama_local.py
rename to examples/hf-serverless_deployment.py
index e7ba9cd..5c12ee2 100644
--- a/examples/ollama_local.py
+++ b/examples/hf-serverless_deployment.py
@@ -4,7 +4,7 @@ from synthetic_dataset_generator import launch
 
 assert os.getenv("HF_TOKEN") # push the data to huggingface
 
-os.environ["BASE_URL"] = "http://127.0.0.1:11434/v1/"
-os.environ["MODEL"] = "llama3.1"
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
 
 launch()
diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py
new file mode 100644
index 0000000..f0aad22
--- /dev/null
+++ b/examples/ollama_deployment.py
@@ -0,0 +1,17 @@
+# pip install synthetic-dataset-generator
+# ollama serve
+# ollama run llama3.1:8b-instruct-q8_0
+import os
+
+from synthetic_dataset_generator import launch
+
+assert os.getenv("HF_TOKEN") # push the data to huggingface
+os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
+os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0"
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"
+os.environ["MAX_NUM_ROWS"] = "10000"
+os.environ["DEFAULT_BATCH_SIZE"] = "5"
+os.environ["MAX_NUM_TOKENS"] = "2048"
+
+launch()
diff --git a/examples/openai_local.py b/examples/openai_deployment.py
similarity index 63%
rename from examples/openai_local.py
rename to examples/openai_deployment.py
index b531f05..5d0231f 100644
--- a/examples/openai_local.py
+++ b/examples/openai_deployment.py
@@ -4,8 +4,9 @@ from synthetic_dataset_generator import launch
 
 assert os.getenv("HF_TOKEN") # push the data to huggingface
 
-os.environ["BASE_URL"] = "https://api.openai.com/v1/"
+os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"
 os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY")
 os.environ["MODEL"] = "gpt-4o"
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with
OpenAI launch() diff --git a/examples/tgi_or_hf_dedicated.py b/examples/tgi_or_hf_dedicated.py new file mode 100644 index 0000000..9f466af --- /dev/null +++ b/examples/tgi_or_hf_dedicated.py @@ -0,0 +1,14 @@ +# pip install synthetic-dataset-generator +import os + +from synthetic_dataset_generator import launch + +assert os.getenv("HF_TOKEN") # push the data to huggingface +os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" +os.environ["TOKENIZER_ID"] = ( + "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint +) +os.environ["MODEL"] = None # model is linked to endpoint + +launch() diff --git a/pdm.lock b/pdm.lock index 485413b..ad969f7 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:a6d7f86e9a168e7eb78801faafa8a9ea13270310ee9c21a0e23aeab277d973a0" +content_hash = "sha256:e95140895657d62ad438ff1815ddf1798abbb342ddd2649ae462620b8b3f5350" [[metadata.targets]] requires_python = ">=3.10,<3.13" @@ -491,8 +491,11 @@ files = [ [[package]] name = "distilabel" -version = "1.4.1" +version = "1.5.0" requires_python = ">=3.9" +git = "https://github.com/argilla-io/distilabel.git" +ref = "feat/add-magpie-support-llama-cpp-ollama" +revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6" summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs." groups = ["default"] dependencies = [ @@ -512,30 +515,30 @@ dependencies = [ "typer>=0.9.0", "universal-pathlib>=0.2.2", ] -files = [ - {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"}, - {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"}, -] [[package]] name = "distilabel" -version = "1.4.1" -extras = ["argilla", "hf-inference-endpoints", "instructor", "outlines"] +version = "1.5.0" +extras = ["argilla", "hf-inference-endpoints", "hf-transformers", "instructor", "llama-cpp", "ollama", "openai", "outlines"] requires_python = ">=3.9" +git = "https://github.com/argilla-io/distilabel.git" +ref = "feat/add-magpie-support-llama-cpp-ollama" +revision = "4e291e7bf1c27b734a683a3af1fefe58965d77d6" summary = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs." groups = ["default"] dependencies = [ "argilla>=2.0.0", - "distilabel==1.4.1", + "distilabel @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama", "huggingface-hub>=0.22.0", "instructor>=1.2.3", "ipython", + "llama-cpp-python>=0.2.0", "numba>=0.54.0", + "ollama>=0.1.7", + "openai>=1.0.0", "outlines>=0.0.40", -] -files = [ - {file = "distilabel-1.4.1-py3-none-any.whl", hash = "sha256:4643da7f3abae86a330d86d1498443ea56978e462e21ae3d106a4c6013386965"}, - {file = "distilabel-1.4.1.tar.gz", hash = "sha256:0c373be234e8f2982ec7f940d9a95585b15306b6ab5315f5a6a45214d8f34006"}, + "torch>=2.0.0", + "transformers>=4.34.1", ] [[package]] @@ -824,7 +827,7 @@ files = [ [[package]] name = "httpx" -version = "0.28.1" +version = "0.27.2" requires_python = ">=3.8" summary = "The next generation HTTP client." 
groups = ["default"] @@ -833,10 +836,11 @@ dependencies = [ "certifi", "httpcore==1.*", "idna", + "sniffio", ] files = [ - {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, - {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [[package]] @@ -1068,6 +1072,22 @@ files = [ {file = "lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80"}, ] +[[package]] +name = "llama-cpp-python" +version = "0.3.5" +requires_python = ">=3.8" +summary = "Python bindings for the llama.cpp library" +groups = ["default"] +dependencies = [ + "diskcache>=5.6.1", + "jinja2>=2.11.3", + "numpy>=1.20.0", + "typing-extensions>=4.5.0", +] +files = [ + {file = "llama_cpp_python-0.3.5.tar.gz", hash = "sha256:f5ce47499d53d3973e28ca5bdaf2dfe820163fa3fb67e3050f98e2e9b58d2cf6"}, +] + [[package]] name = "llvmlite" version = "0.43.0" @@ -1538,6 +1558,21 @@ files = [ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, ] +[[package]] +name = "ollama" +version = "0.4.4" +requires_python = "<4.0,>=3.8" +summary = "The official Python client for Ollama." +groups = ["default"] +dependencies = [ + "httpx<0.28.0,>=0.27.0", + "pydantic<3.0.0,>=2.9.0", +] +files = [ + {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"}, + {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"}, +] + [[package]] name = "openai" version = "1.57.4" diff --git a/pyproject.toml b/pyproject.toml index 016cc48..fdec71c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ readme = "README.md" license = {text = "Apache 2"} dependencies = [ - "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1,<2.0.0", + "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama", "gradio[oauth]>=5.4.0,<6.0.0", "transformers>=4.44.2,<5.0.0", "sentence-transformers>=3.2.0,<4.0.0", diff --git a/src/synthetic_dataset_generator/apps/chat.py b/src/synthetic_dataset_generator/apps/chat.py index 10b4bd1..bebdbd8 100644 --- a/src/synthetic_dataset_generator/apps/chat.py +++ b/src/synthetic_dataset_generator/apps/chat.py @@ -121,7 +121,7 @@ def generate_dataset( { "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. 
Original prompt: {system_prompt} Rewritten prompt: " } - for i in range(int(num_rows / 50)) + for i in range(int(num_rows / 100)) ] batch = list(prompt_rewriter.process(inputs=inputs)) prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt] diff --git a/src/synthetic_dataset_generator/apps/textcat.py b/src/synthetic_dataset_generator/apps/textcat.py index 9e2010c..c2edf6d 100644 --- a/src/synthetic_dataset_generator/apps/textcat.py +++ b/src/synthetic_dataset_generator/apps/textcat.py @@ -178,26 +178,41 @@ def generate_dataset( dataframe = pd.DataFrame(distiset_results) if multi_label: - dataframe["labels"] = dataframe["labels"].apply( - lambda x: list( - set( - [ - label.lower().strip() - if (label is not None and label.lower().strip() in labels) - else random.choice(labels) - for label in x - ] - ) - ) - ) + + def _validate_labels(x): + if isinstance(x, str): # single label + return [x.lower().strip()] + elif isinstance(x, list): # multiple labels + return [ + label.lower().strip() + for label in x + if label.lower().strip() in labels + ] + else: + return [random.choice(labels)] + + dataframe["labels"] = dataframe["labels"].apply(_validate_labels) dataframe = dataframe[dataframe["labels"].notna()] else: + + def _validate_labels(x): + if isinstance(x, str) and x.lower().strip() in labels: + return x.lower().strip() + elif isinstance(x, list): + options = [ + label.lower().strip() + for label in x + if isinstance(label, str) and label.lower().strip() in labels + ] + if options: + return random.choice(options) + else: + return random.choice(labels) + else: + return random.choice(labels) + dataframe = dataframe.rename(columns={"labels": "label"}) - dataframe["label"] = dataframe["label"].apply( - lambda x: x.lower().strip() - if x and x.lower().strip() in labels - else random.choice(labels) - ) + dataframe["label"] = dataframe["label"].apply(_validate_labels) dataframe = dataframe[dataframe["text"].notna()] progress(1.0, desc="Dataset created") diff --git a/src/synthetic_dataset_generator/constants.py b/src/synthetic_dataset_generator/constants.py index 3b6ea67..3583c7c 100644 --- a/src/synthetic_dataset_generator/constants.py +++ b/src/synthetic_dataset_generator/constants.py @@ -7,39 +7,63 @@ TEXTCAT_TASK = "text_classification" SFT_TASK = "supervised_fine_tuning" -# Hugging Face +# Inference +MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048)) +MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000)) +DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5)) + +# Models +MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") +TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None) +OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") +HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL") +if HUGGINGFACE_BASE_URL and MODEL: + raise ValueError( + "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints." + ) +if OPENAI_BASE_URL or OLLAMA_BASE_URL: + if not MODEL: + raise ValueError("`MODEL` is not set. Please provide a model id for inference.") + + + +# Check if multiple base URLs are provided +base_urls = [ + url for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL] if url +] +if len(base_urls) > 1: + raise ValueError( + f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time." 
+ ) +BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL + + +# API Keys HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: raise ValueError( "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints." ) -# Inference -MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048)) -MAX_NUM_ROWS: str | int = int(os.getenv("MAX_NUM_ROWS", 1000)) -DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5)) -MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") -BASE_URL = os.getenv("BASE_URL", default=None) - _API_KEY = os.getenv("API_KEY") -if _API_KEY: - API_KEYS = [_API_KEY] -else: - API_KEYS = [os.getenv("HF_TOKEN")] + [ - os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10) - ] +API_KEYS = ( + [_API_KEY] + if _API_KEY + else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)] +) API_KEYS = [token for token in API_KEYS if token] # Determine if SFT is available SFT_AVAILABLE = False llama_options = ["llama3", "llama-3", "llama 3"] qwen_options = ["qwen2", "qwen-2", "qwen 2"] -if os.getenv("MAGPIE_PRE_QUERY_TEMPLATE"): + +if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower(): SFT_AVAILABLE = True - passed_pre_query_template = os.getenv("MAGPIE_PRE_QUERY_TEMPLATE") - if passed_pre_query_template.lower() in llama_options: + if passed_pre_query_template in llama_options: MAGPIE_PRE_QUERY_TEMPLATE = "llama3" - elif passed_pre_query_template.lower() in qwen_options: + elif passed_pre_query_template in qwen_options: MAGPIE_PRE_QUERY_TEMPLATE = "qwen2" else: MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template @@ -54,12 +78,12 @@ SFT_AVAILABLE = True MAGPIE_PRE_QUERY_TEMPLATE = "qwen2" -if BASE_URL: +if OPENAI_BASE_URL: SFT_AVAILABLE = False if not SFT_AVAILABLE: warnings.warn( - message="`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints to generate chat data." + "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`." 
) MAGPIE_PRE_QUERY_TEMPLATE = None @@ -67,11 +91,12 @@ STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M" # Argilla -ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") -ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") -if ARGILLA_API_URL is None or ARGILLA_API_KEY is None: - ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER") - ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER") +ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv( + "ARGILLA_API_URL_SDG_REVIEWER" +) +ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv( + "ARGILLA_API_KEY_SDG_REVIEWER" +) if not ARGILLA_API_URL or not ARGILLA_API_KEY: warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty") diff --git a/src/synthetic_dataset_generator/pipelines/base.py b/src/synthetic_dataset_generator/pipelines/base.py index c520a2e..5102bd9 100644 --- a/src/synthetic_dataset_generator/pipelines/base.py +++ b/src/synthetic_dataset_generator/pipelines/base.py @@ -1,4 +1,15 @@ -from synthetic_dataset_generator.constants import API_KEYS +import gradio as gr +from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM + +from synthetic_dataset_generator.constants import ( + API_KEYS, + HUGGINGFACE_BASE_URL, + MAGPIE_PRE_QUERY_TEMPLATE, + MODEL, + OLLAMA_BASE_URL, + OPENAI_BASE_URL, + TOKENIZER_ID, +) TOKEN_INDEX = 0 @@ -8,3 +19,71 @@ def _get_next_api_key(): api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)] TOKEN_INDEX += 1 return api_key + + +def _get_llm(use_magpie_template=False, **kwargs): + if OPENAI_BASE_URL: + llm = OpenAILLM( + model=MODEL, + base_url=OPENAI_BASE_URL, + api_key=_get_next_api_key(), + **kwargs, + ) + if "generation_kwargs" in kwargs: + if "stop_sequences" in kwargs["generation_kwargs"]: + kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][ + "stop_sequences" + ] + del kwargs["generation_kwargs"]["stop_sequences"] + if "do_sample" in kwargs["generation_kwargs"]: + del kwargs["generation_kwargs"]["do_sample"] + elif OLLAMA_BASE_URL: + if "generation_kwargs" in kwargs: + if "max_new_tokens" in kwargs["generation_kwargs"]: + kwargs["generation_kwargs"]["num_predict"] = kwargs[ + "generation_kwargs" + ]["max_new_tokens"] + del kwargs["generation_kwargs"]["max_new_tokens"] + if "stop_sequences" in kwargs["generation_kwargs"]: + kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][ + "stop_sequences" + ] + del kwargs["generation_kwargs"]["stop_sequences"] + if "do_sample" in kwargs["generation_kwargs"]: + del kwargs["generation_kwargs"]["do_sample"] + options = kwargs["generation_kwargs"] + del kwargs["generation_kwargs"] + kwargs["generation_kwargs"] = {} + kwargs["generation_kwargs"]["options"] = options + llm = OllamaLLM( + model=MODEL, + host=OLLAMA_BASE_URL, + tokenizer_id=TOKENIZER_ID or MODEL, + **kwargs, + ) + elif HUGGINGFACE_BASE_URL: + kwargs["generation_kwargs"]["do_sample"] = True + llm = InferenceEndpointsLLM( + api_key=_get_next_api_key(), + base_url=HUGGINGFACE_BASE_URL, + tokenizer_id=TOKENIZER_ID or MODEL, + **kwargs, + ) + else: + llm = InferenceEndpointsLLM( + api_key=_get_next_api_key(), + tokenizer_id=TOKENIZER_ID or MODEL, + model_id=MODEL, + magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, + **kwargs, + ) + + return llm + + +try: + llm = _get_llm() + llm.load() + llm.generate([[{"content": "Hello, world!", "role": "user"}]]) +except Exception as e: + gr.Error(f"Error loading {llm.__class__.__name__}: {e}") diff --git a/src/synthetic_dataset_generator/pipelines/chat.py 
b/src/synthetic_dataset_generator/pipelines/chat.py index 5774cac..7b3a11a 100644 --- a/src/synthetic_dataset_generator/pipelines/chat.py +++ b/src/synthetic_dataset_generator/pipelines/chat.py @@ -1,4 +1,3 @@ -from distilabel.llms import InferenceEndpointsLLM from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration from synthetic_dataset_generator.constants import ( @@ -7,7 +6,7 @@ MAX_NUM_TOKENS, MODEL, ) -from synthetic_dataset_generator.pipelines.base import _get_next_api_key +from synthetic_dataset_generator.pipelines.base import _get_llm INFORMATION_SEEKING_PROMPT = ( "You are an AI assistant designed to provide accurate and concise information on a wide" @@ -149,18 +148,13 @@ def _get_output_mappings(num_turns): def get_prompt_generator(): + generation_kwargs = { + "temperature": 0.8, + "max_new_tokens": MAX_NUM_TOKENS, + "do_sample": True, + } prompt_generator = TextGeneration( - llm=InferenceEndpointsLLM( - api_key=_get_next_api_key(), - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - generation_kwargs={ - "temperature": 0.8, - "max_new_tokens": MAX_NUM_TOKENS, - "do_sample": True, - }, - ), + llm=_get_llm(generation_kwargs=generation_kwargs), system_prompt=PROMPT_CREATION_PROMPT, use_system_prompt=True, ) @@ -172,38 +166,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample): input_mappings = _get_output_mappings(num_turns) output_mappings = input_mappings.copy() if num_turns == 1: + generation_kwargs = { + "temperature": temperature, + "do_sample": True, + "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25), + "stop_sequences": _STOP_SEQUENCES, + } magpie_generator = Magpie( - llm=InferenceEndpointsLLM( - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), + llm=_get_llm( + generation_kwargs=generation_kwargs, magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, - generation_kwargs={ - "temperature": temperature, - "do_sample": True, - "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25), - "stop_sequences": _STOP_SEQUENCES, - }, + use_magpie_template=True, ), n_turns=num_turns, output_mappings=output_mappings, only_instruction=True, ) else: + generation_kwargs = { + "temperature": temperature, + "do_sample": True, + "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5), + "stop_sequences": _STOP_SEQUENCES, + } magpie_generator = Magpie( - llm=InferenceEndpointsLLM( - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), + llm=_get_llm( + generation_kwargs=generation_kwargs, magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE, - generation_kwargs={ - "temperature": temperature, - "do_sample": True, - "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5), - "stop_sequences": _STOP_SEQUENCES, - }, + use_magpie_template=True, ), end_with_user=True, n_turns=num_turns, @@ -214,50 +204,33 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample): def get_prompt_rewriter(): - prompt_rewriter = TextGeneration( - llm=InferenceEndpointsLLM( - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs={ - "temperature": 1, - }, - ), - ) + generation_kwargs = { + "temperature": 1, + } + prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs)) prompt_rewriter.load() return prompt_rewriter def get_response_generator(system_prompt, num_turns, temperature, is_sample): if num_turns == 1: + generation_kwargs = { + 
"temperature": temperature, + "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5), + } response_generator = TextGeneration( - llm=InferenceEndpointsLLM( - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs={ - "temperature": temperature, - "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5), - }, - ), + llm=_get_llm(generation_kwargs=generation_kwargs), system_prompt=system_prompt, output_mappings={"generation": "completion"}, input_mappings={"instruction": "prompt"}, ) else: + generation_kwargs = { + "temperature": temperature, + "max_new_tokens": MAX_NUM_TOKENS, + } response_generator = ChatGeneration( - llm=InferenceEndpointsLLM( - model_id=MODEL, - tokenizer_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs={ - "temperature": temperature, - "max_new_tokens": MAX_NUM_TOKENS, - }, - ), + llm=_get_llm(generation_kwargs=generation_kwargs), output_mappings={"generation": "completion"}, input_mappings={"conversation": "messages"}, ) @@ -293,7 +266,7 @@ def generate_pipeline_code(system_prompt, num_turns, num_rows, temperature): "max_new_tokens": {MAX_NUM_TOKENS}, "stop_sequences": {_STOP_SEQUENCES} }}, - api_key=os.environ["BASE_URL"], + api_key=os.environ["API_KEY"], ), n_turns={num_turns}, num_rows={num_rows}, diff --git a/src/synthetic_dataset_generator/pipelines/textcat.py b/src/synthetic_dataset_generator/pipelines/textcat.py index 4f5f11d..ab1547d 100644 --- a/src/synthetic_dataset_generator/pipelines/textcat.py +++ b/src/synthetic_dataset_generator/pipelines/textcat.py @@ -1,7 +1,6 @@ import random from typing import List -from distilabel.llms import InferenceEndpointsLLM, OpenAILLM from distilabel.steps.tasks import ( GenerateTextClassificationData, TextClassification, @@ -9,8 +8,12 @@ ) from pydantic import BaseModel, Field -from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL -from synthetic_dataset_generator.pipelines.base import _get_next_api_key +from synthetic_dataset_generator.constants import ( + BASE_URL, + MAX_NUM_TOKENS, + MODEL, +) +from synthetic_dataset_generator.pipelines.base import _get_llm from synthetic_dataset_generator.utils import get_preprocess_labels PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation. 
@@ -69,23 +72,10 @@ def get_prompt_generator(): "temperature": 0.8, "max_new_tokens": MAX_NUM_TOKENS, } - if BASE_URL: - llm = OpenAILLM( - model=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - structured_output=structured_output, - generation_kwargs=generation_kwargs, - ) - else: - generation_kwargs["do_sample"] = True - llm = InferenceEndpointsLLM( - api_key=_get_next_api_key(), - model_id=MODEL, - base_url=BASE_URL, - structured_output=structured_output, - generation_kwargs=generation_kwargs, - ) + llm = _get_llm( + structured_output=structured_output, + generation_kwargs=generation_kwargs, + ) prompt_generator = TextGeneration( llm=llm, @@ -103,21 +93,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample): "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS, "top_p": 0.95, } - if BASE_URL: - llm = OpenAILLM( - model=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs=generation_kwargs, - ) - else: - generation_kwargs["do_sample"] = True - llm = InferenceEndpointsLLM( - model_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs=generation_kwargs, - ) + llm = _get_llm(generation_kwargs=generation_kwargs) textcat_generator = GenerateTextClassificationData( llm=llm, @@ -134,21 +110,7 @@ def get_labeller_generator(system_prompt, labels, multi_label): "temperature": 0.01, "max_new_tokens": MAX_NUM_TOKENS, } - - if BASE_URL: - llm = OpenAILLM( - model=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs=generation_kwargs, - ) - else: - llm = InferenceEndpointsLLM( - model_id=MODEL, - base_url=BASE_URL, - api_key=_get_next_api_key(), - generation_kwargs=generation_kwargs, - ) + llm = _get_llm(generation_kwargs=generation_kwargs) labeller_generator = TextClassification( llm=llm, From 32d8669769f8d5dd0418d99e8ae9bbe97eada85c Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 23 Dec 2024 17:57:39 +0100 Subject: [PATCH 02/15] update examples --- examples/argilla_deployment.py | 6 ++++-- examples/ollama_deployment.py | 11 ++++------- examples/openai_deployment.py | 6 +++--- examples/tgi_or_hf_dedicated.py | 4 ++-- src/synthetic_dataset_generator/constants.py | 6 ++---- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/examples/argilla_deployment.py b/examples/argilla_deployment.py index 76ce637..1b127cd 100644 --- a/examples/argilla_deployment.py +++ b/examples/argilla_deployment.py @@ -4,7 +4,9 @@ from synthetic_dataset_generator import launch # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL -os.environ["ARGILLA_API_URL"] = "https://[your-owner-name]-[your_space_name].hf.space" -os.environ["ARGILLA_API_KEY"] = "my_api_key" +os.environ["ARGILLA_API_URL"] = ( + "https://[your-owner-name]-[your_space_name].hf.space" # argilla base url +) +os.environ["ARGILLA_API_KEY"] = "my_api_key" # argilla api key launch() diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py index f0aad22..66e0e49 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama_deployment.py @@ -6,12 +6,9 @@ from synthetic_dataset_generator import launch assert os.getenv("HF_TOKEN") # push the data to huggingface -os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" -os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0" -os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct" -os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" -os.environ["MAX_NUM_ROWS"] = "10000" -os.environ["DEFAULT_BATCH_SIZE"] = "5" 
-os.environ["MAX_NUM_TOKENS"] = "2048" +os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url +os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0" # model id +os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct" # tokenizer id +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template launch() diff --git a/examples/openai_deployment.py b/examples/openai_deployment.py index 5d0231f..59a9bc1 100644 --- a/examples/openai_deployment.py +++ b/examples/openai_deployment.py @@ -4,9 +4,9 @@ from synthetic_dataset_generator import launch assert os.getenv("HF_TOKEN") # push the data to huggingface -os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" -os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") -os.environ["MODEL"] = "gpt-4o" +os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url +os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key +os.environ["MODEL"] = "gpt-4o" # model id os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = None # chat data not supported with OpenAI launch() diff --git a/examples/tgi_or_hf_dedicated.py b/examples/tgi_or_hf_dedicated.py index 9f466af..a2def93 100644 --- a/examples/tgi_or_hf_dedicated.py +++ b/examples/tgi_or_hf_dedicated.py @@ -4,8 +4,8 @@ from synthetic_dataset_generator import launch assert os.getenv("HF_TOKEN") # push the data to huggingface -os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" -os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" +os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template os.environ["TOKENIZER_ID"] = ( "meta-llama/Llama-3.1-8B-Instruct" # tokenizer for model hosted on endpoint ) diff --git a/src/synthetic_dataset_generator/constants.py b/src/synthetic_dataset_generator/constants.py index 3583c7c..f177137 100644 --- a/src/synthetic_dataset_generator/constants.py +++ b/src/synthetic_dataset_generator/constants.py @@ -22,12 +22,10 @@ raise ValueError( "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints." ) -if OPENAI_BASE_URL or OLLAMA_BASE_URL: - if not MODEL: +if not MODEL: + if OPENAI_BASE_URL or OLLAMA_BASE_URL: raise ValueError("`MODEL` is not set. Please provide a model id for inference.") - - # Check if multiple base URLs are provided base_urls = [ url for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL] if url From 2841b26d0ae2aab12d1abd47e0de2b293386258c Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 24 Dec 2024 09:25:28 +0100 Subject: [PATCH 03/15] add randomisation of system prompts for generation --- README.md | 2 +- examples/ollama_deployment.py | 5 ++- src/synthetic_dataset_generator/apps/base.py | 9 ++++- src/synthetic_dataset_generator/apps/chat.py | 19 ++++----- src/synthetic_dataset_generator/apps/eval.py | 4 +- .../apps/textcat.py | 16 +++++--- src/synthetic_dataset_generator/constants.py | 2 +- .../pipelines/base.py | 40 +++++++++++++++++++ .../pipelines/chat.py | 9 ----- .../pipelines/textcat.py | 2 - src/synthetic_dataset_generator/utils.py | 5 +++ 11 files changed, 79 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index a24c9ed..029cd86 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ Optionally, you can use different API providers and models. - `MODEL`: The model to use for generating the dataset, e.g. 
`meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`. - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable. - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`. -- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/v1/`. +- `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`. - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`. SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2. diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py index 66e0e49..79a3186 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama_deployment.py @@ -9,6 +9,9 @@ os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0" # model id os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct" # tokenizer id -os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" +os.environ["MAX_NUM_ROWS"] = "10000" +os.environ["DEFAULT_BATCH_SIZE"] = "5" +os.environ["MAX_NUM_TOKENS"] = "1024" launch() diff --git a/src/synthetic_dataset_generator/apps/base.py b/src/synthetic_dataset_generator/apps/base.py index 39951b7..316f5ca 100644 --- a/src/synthetic_dataset_generator/apps/base.py +++ b/src/synthetic_dataset_generator/apps/base.py @@ -77,10 +77,15 @@ def validate_push_to_hub(org_name, repo_name): return repo_id -def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset: +def combine_datasets( + repo_id: str, dataset: Dataset, oauth_token: Union[OAuthToken, None] +) -> Dataset: try: new_dataset = load_dataset( - repo_id, split="train", download_mode="force_redownload" + repo_id, + split="train", + download_mode="force_redownload", + token=oauth_token.token, ) return concatenate_datasets([dataset, new_dataset]) except Exception: diff --git a/src/synthetic_dataset_generator/apps/chat.py b/src/synthetic_dataset_generator/apps/chat.py index bebdbd8..274207d 100644 --- a/src/synthetic_dataset_generator/apps/chat.py +++ b/src/synthetic_dataset_generator/apps/chat.py @@ -25,12 +25,12 @@ MODEL, SFT_AVAILABLE, ) +from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts from synthetic_dataset_generator.pipelines.chat import ( DEFAULT_DATASET_DESCRIPTIONS, generate_pipeline_code, get_magpie_generator, get_prompt_generator, - get_prompt_rewriter, get_response_generator, ) from synthetic_dataset_generator.pipelines.embeddings import ( @@ -40,6 +40,7 @@ from synthetic_dataset_generator.utils import ( get_argilla_client, get_org_dropdown, + get_random_repo_name, swap_visibility, ) @@ -106,7 +107,6 @@ def generate_dataset( ) -> pd.DataFrame: num_rows = test_max_num_rows(num_rows) progress(0.0, desc="(1/2) Generating instructions") - prompt_rewriter = get_prompt_rewriter() magpie_generator = get_magpie_generator( system_prompt, num_turns, temperature, is_sample ) @@ -117,14 +117,7 @@ def generate_dataset( batch_size = DEFAULT_BATCH_SIZE # create prompt rewrites - inputs = [ - { - "instruction": f"Rewrite this prompt keeping the same structure but 
highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: " - } - for i in range(int(num_rows / 100)) - ] - batch = list(prompt_rewriter.process(inputs=inputs)) - prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt] + prompt_rewrites = get_rewriten_prompts(system_prompt, num_rows) # create instructions n_processed = 0 @@ -142,6 +135,7 @@ def generate_dataset( batch = list(magpie_generator.process(inputs=inputs)) magpie_results.extend(batch[0]) n_processed += batch_size + random.seed(a=random.randint(0, 2**32 - 1)) progress(0.5, desc="(1/2) Generating instructions") # generate responses @@ -158,6 +152,7 @@ def generate_dataset( responses = list(response_generator.process(inputs=batch)) response_results.extend(responses[0]) n_processed += batch_size + random.seed(a=random.randint(0, 2**32 - 1)) for result in response_results: result["prompt"] = result["instruction"] result["completion"] = result["generation"] @@ -178,6 +173,7 @@ def generate_dataset( responses = list(response_generator.process(inputs=batch)) response_results.extend(responses[0]) n_processed += batch_size + random.seed(a=random.randint(0, 2**32 - 1)) for result in response_results: result["messages"].append( {"role": "assistant", "content": result["generation"]} @@ -236,7 +232,7 @@ def push_dataset_to_hub( dataframe = convert_dataframe_messages(dataframe) progress(0.7, desc="Creating dataset") dataset = Dataset.from_pandas(dataframe) - dataset = combine_datasets(repo_id, dataset) + dataset = combine_datasets(repo_id, dataset, oauth_token) progress(0.9, desc="Pushing dataset") distiset = Distiset({"default": dataset}) distiset.push_to_hub( @@ -600,4 +596,5 @@ def hide_pipeline_code_visibility(): outputs=[dataset_description, system_prompt, num_turns, dataframe], ) app.load(fn=get_org_dropdown, outputs=[org_name]) + app.load(fn=get_random_repo_name, outputs=[repo_name]) app.load(fn=swap_visibility, outputs=main_ui) diff --git a/src/synthetic_dataset_generator/apps/eval.py b/src/synthetic_dataset_generator/apps/eval.py index d33b3f1..81e0bcd 100644 --- a/src/synthetic_dataset_generator/apps/eval.py +++ b/src/synthetic_dataset_generator/apps/eval.py @@ -41,6 +41,7 @@ extract_column_names, get_argilla_client, get_org_dropdown, + get_random_repo_name, pad_or_truncate_list, process_columns, swap_visibility, @@ -359,7 +360,7 @@ def push_dataset_to_hub( ): repo_id = validate_push_to_hub(org_name, repo_name) dataset = Dataset.from_pandas(dataframe) - dataset = combine_datasets(repo_id, dataset) + dataset = combine_datasets(repo_id, dataset, oauth_token) distiset = Distiset({"default": dataset}) distiset.push_to_hub( repo_id=repo_id, @@ -907,3 +908,4 @@ def hide_pipeline_code_visibility(): app.load(fn=swap_visibility, outputs=main_ui) app.load(fn=get_org_dropdown, outputs=[org_name]) + app.load(fn=get_random_repo_name, outputs=[repo_name]) diff --git a/src/synthetic_dataset_generator/apps/textcat.py b/src/synthetic_dataset_generator/apps/textcat.py index c2edf6d..b2dcab7 100644 --- a/src/synthetic_dataset_generator/apps/textcat.py +++ b/src/synthetic_dataset_generator/apps/textcat.py @@ -20,6 +20,7 @@ validate_push_to_hub, ) from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE +from synthetic_dataset_generator.pipelines.base import get_rewriten_prompts from synthetic_dataset_generator.pipelines.embeddings import ( get_embeddings, get_sentence_embedding_dimensions, @@ -35,6 +36,7 @@ get_argilla_client, get_org_dropdown, 
     get_preprocess_labels,
+    get_random_repo_name,
     swap_visibility,
 )
 
@@ -106,7 +108,7 @@ def generate_dataset(
     )
     updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
     if multi_label:
-        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
+        updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is always better than applying too many labels."
     labeller_generator = get_labeller_generator(
         system_prompt=updated_system_prompt,
         labels=labels,
@@ -118,6 +120,7 @@ def generate_dataset(
     # create text classification data
     n_processed = 0
     textcat_results = []
+    rewritten_system_prompts = get_rewriten_prompts(system_prompt, num_rows)
     while n_processed < num_rows:
         progress(
             2 * 0.5 * n_processed / num_rows,
@@ -128,25 +131,24 @@ def generate_dataset(
         batch_size = min(batch_size, remaining_rows)
         inputs = []
         for _ in range(batch_size):
+            k = 1
             if multi_label:
                 num_labels = len(labels)
                 k = int(
                     random.betavariate(alpha=(num_labels - 1), beta=num_labels)
                     * num_labels
                 )
-            else:
-                k = 1
-
             sampled_labels = random.sample(labels, min(k, len(labels)))
             random.shuffle(sampled_labels)
             inputs.append(
                 {
-                    "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
+                    "task": f"{random.choice(rewritten_system_prompts)}. The text represents the following categories: {', '.join(sampled_labels)}"
                 }
             )
         batch = list(textcat_generator.process(inputs=inputs))
         textcat_results.extend(batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
 
     for result in textcat_results:
         result["text"] = result["input_text"]
@@ -164,6 +166,7 @@ def generate_dataset(
         labels_batch = list(labeller_generator.process(inputs=batch))
         labeller_results.extend(labels_batch[0])
         n_processed += batch_size
+        random.seed(a=random.randint(0, 2**32 - 1))
     progress(
         1,
         total=total_steps,
@@ -250,7 +253,7 @@ def push_dataset_to_hub(
         dataframe.reset_index(drop=True),
         features=features,
     )
-    dataset = combine_datasets(repo_id, dataset)
+    dataset = combine_datasets(repo_id, dataset, oauth_token)
     distiset = Distiset({"default": dataset})
     progress(0.9, desc="Pushing dataset")
     distiset.push_to_hub(
@@ -662,3 +665,4 @@ def hide_pipeline_code_visibility():
 
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
+    app.load(fn=get_random_repo_name, outputs=[repo_name])
diff --git a/src/synthetic_dataset_generator/constants.py b/src/synthetic_dataset_generator/constants.py
index f177137..134a706 100644
--- a/src/synthetic_dataset_generator/constants.py
+++ b/src/synthetic_dataset_generator/constants.py
@@ -81,7 +81,7 @@
 
 if not SFT_AVAILABLE:
     warnings.warn(
-        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`."
+        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vLLM."
) MAGPIE_PRE_QUERY_TEMPLATE = None diff --git a/src/synthetic_dataset_generator/pipelines/base.py b/src/synthetic_dataset_generator/pipelines/base.py index 5102bd9..79f7dfe 100644 --- a/src/synthetic_dataset_generator/pipelines/base.py +++ b/src/synthetic_dataset_generator/pipelines/base.py @@ -1,8 +1,13 @@ +import math +import random + import gradio as gr from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM +from distilabel.steps.tasks import TextGeneration from synthetic_dataset_generator.constants import ( API_KEYS, + DEFAULT_BATCH_SIZE, HUGGINGFACE_BASE_URL, MAGPIE_PRE_QUERY_TEMPLATE, MODEL, @@ -21,6 +26,41 @@ def _get_next_api_key(): return api_key +def _get_prompt_rewriter(): + generation_kwargs = { + "temperature": 1, + } + system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new." + prompt_rewriter = TextGeneration( + llm=_get_llm(generation_kwargs=generation_kwargs), + system_prompt=system_prompt, + use_system_prompt=True, + ) + prompt_rewriter.load() + return prompt_rewriter + + +def get_rewriten_prompts(prompt: str, num_rows: int): + prompt_rewriter = _get_prompt_rewriter() + # create prompt rewrites + inputs = [ + {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "} + for i in range(math.floor(num_rows / 100)) + ] + n_processed = 0 + prompt_rewrites = [prompt] + while n_processed < num_rows: + batch = list( + prompt_rewriter.process( + inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE] + ) + ) + prompt_rewrites += [entry["generation"] for entry in batch[0]] + n_processed += DEFAULT_BATCH_SIZE + random.seed(a=random.randint(0, 2**32 - 1)) + return prompt_rewrites + + def _get_llm(use_magpie_template=False, **kwargs): if OPENAI_BASE_URL: llm = OpenAILLM( diff --git a/src/synthetic_dataset_generator/pipelines/chat.py b/src/synthetic_dataset_generator/pipelines/chat.py index 7b3a11a..f70cb57 100644 --- a/src/synthetic_dataset_generator/pipelines/chat.py +++ b/src/synthetic_dataset_generator/pipelines/chat.py @@ -203,15 +203,6 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample): return magpie_generator -def get_prompt_rewriter(): - generation_kwargs = { - "temperature": 1, - } - prompt_rewriter = TextGeneration(llm=_get_llm(generation_kwargs=generation_kwargs)) - prompt_rewriter.load() - return prompt_rewriter - - def get_response_generator(system_prompt, num_turns, temperature, is_sample): if num_turns == 1: generation_kwargs = { diff --git a/src/synthetic_dataset_generator/pipelines/textcat.py b/src/synthetic_dataset_generator/pipelines/textcat.py index ab1547d..59702b0 100644 --- a/src/synthetic_dataset_generator/pipelines/textcat.py +++ b/src/synthetic_dataset_generator/pipelines/textcat.py @@ -94,7 +94,6 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample): "top_p": 0.95, } llm = _get_llm(generation_kwargs=generation_kwargs) - textcat_generator = GenerateTextClassificationData( llm=llm, difficulty=None if difficulty == "mixed" else difficulty, @@ -111,7 +110,6 @@ def get_labeller_generator(system_prompt, labels, multi_label): "max_new_tokens": MAX_NUM_TOKENS, } llm = _get_llm(generation_kwargs=generation_kwargs) - labeller_generator = TextClassification( llm=llm, context=system_prompt, diff --git a/src/synthetic_dataset_generator/utils.py b/src/synthetic_dataset_generator/utils.py index 967db1b..3884c21 100644 --- 
a/src/synthetic_dataset_generator/utils.py +++ b/src/synthetic_dataset_generator/utils.py @@ -1,4 +1,5 @@ import json +import uuid import warnings from typing import List, Optional, Union @@ -55,6 +56,10 @@ def list_orgs(oauth_token: Union[OAuthToken, None] = None): return organizations +def get_random_repo_name(): + return f"my-distiset-{str(uuid.uuid4())[:8]}" + + def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None): if oauth_token is not None: orgs = list_orgs(oauth_token) From e1cb58c14db16f7d55b7c3c068a02abdfb1b4a42 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 30 Dec 2024 09:20:30 +0100 Subject: [PATCH 04/15] update examples to not assert check on tokens --- examples/argilla_deployment.py | 1 + examples/hf-serverless_deployment.py | 2 +- examples/ollama_deployment.py | 10 +++++----- examples/openai_deployment.py | 2 +- examples/tgi_or_hf_dedicated.py | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/argilla_deployment.py b/examples/argilla_deployment.py index 1b127cd..bb19a06 100644 --- a/examples/argilla_deployment.py +++ b/examples/argilla_deployment.py @@ -4,6 +4,7 @@ from synthetic_dataset_generator import launch # Follow https://docs.argilla.io/latest/getting_started/quickstart/ to get your Argilla API key and URL +os.environ["HF_TOKEN"] = "hf_..." os.environ["ARGILLA_API_URL"] = ( "https://[your-owner-name]-[your_space_name].hf.space" # argilla base url ) diff --git a/examples/hf-serverless_deployment.py b/examples/hf-serverless_deployment.py index 5c12ee2..561602c 100644 --- a/examples/hf-serverless_deployment.py +++ b/examples/hf-serverless_deployment.py @@ -3,7 +3,7 @@ from synthetic_dataset_generator import launch -assert os.getenv("HF_TOKEN") # push the data to huggingface +os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py index 79a3186..0bc3243 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama_deployment.py @@ -5,13 +5,13 @@ from synthetic_dataset_generator import launch -assert os.getenv("HF_TOKEN") # push the data to huggingface +# os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url -os.environ["MODEL"] = "llama3.1:8b-instruct-q8_0" # model id -os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.1-8B-Instruct" # tokenizer id -os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" +os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id +os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2" os.environ["MAX_NUM_ROWS"] = "10000" -os.environ["DEFAULT_BATCH_SIZE"] = "5" +os.environ["DEFAULT_BATCH_SIZE"] = "2" os.environ["MAX_NUM_TOKENS"] = "1024" launch() diff --git a/examples/openai_deployment.py b/examples/openai_deployment.py index 59a9bc1..6c8617f 100644 --- a/examples/openai_deployment.py +++ b/examples/openai_deployment.py @@ -3,7 +3,7 @@ from synthetic_dataset_generator import launch -assert os.getenv("HF_TOKEN") # push the data to huggingface +os.environ["HF_TOKEN"] = "hf_..." 
# push the data to huggingface os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # openai base url os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") # openai api key os.environ["MODEL"] = "gpt-4o" # model id diff --git a/examples/tgi_or_hf_dedicated.py b/examples/tgi_or_hf_dedicated.py index a2def93..3c59bef 100644 --- a/examples/tgi_or_hf_dedicated.py +++ b/examples/tgi_or_hf_dedicated.py @@ -3,7 +3,7 @@ from synthetic_dataset_generator import launch -assert os.getenv("HF_TOKEN") # push the data to huggingface +os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface os.environ["HUGGINGFACE_BASE_URL"] = "http://127.0.0.1:3000/" # dedicated endpoint/TGI os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template os.environ["TOKENIZER_ID"] = ( From 2d84a88407e4cc53911c1c04b5a94eae6097a263 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 30 Dec 2024 09:26:01 +0100 Subject: [PATCH 05/15] add task categories --- src/synthetic_dataset_generator/_distiset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/synthetic_dataset_generator/_distiset.py b/src/synthetic_dataset_generator/_distiset.py index a0ec0dc..000b10a 100644 --- a/src/synthetic_dataset_generator/_distiset.py +++ b/src/synthetic_dataset_generator/_distiset.py @@ -81,6 +81,15 @@ def _get_card( dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + keys = list(sample_records.keys()) + if len(keys) != 2 or not ( + ("label" in keys and "text" in keys) + or ("labels" in keys and "text" in keys) + ): + task_categories = ["text-classification"] + elif "prompt" in keys or "messages" in keys: + task_categories = ["text-generation", "text2text-generation"] + readme_metadata = {} if repo_id and token: readme_metadata = self._extract_readme_metadata(repo_id, token) @@ -90,6 +99,7 @@ def _get_card( "size_categories": size_categories_parser( max(len(dataset) for dataset in self.values()) ), + "task_categories": task_categories, "tags": [ "synthetic", "distilabel", From 8dfc799ae5b2d34e4b68ee1c7a843499e122a2a6 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sun, 5 Jan 2025 08:38:37 +0100 Subject: [PATCH 06/15] add vllm deployment info --- README.md | 2 ++ examples/vllm_deployment.py | 16 ++++++++++++++++ src/synthetic_dataset_generator/constants.py | 11 ++++++++--- .../pipelines/base.py | 14 +++++++++++++- 4 files changed, 39 insertions(+), 4 deletions(-) create mode 100644 examples/vllm_deployment.py diff --git a/README.md b/README.md index 029cd86..b87c18b 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,8 @@ Optionally, you can use different API providers and models. - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`. - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`. - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`. +- `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`. + SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2. 
diff --git a/examples/vllm_deployment.py b/examples/vllm_deployment.py new file mode 100644 index 0000000..36f6d63 --- /dev/null +++ b/examples/vllm_deployment.py @@ -0,0 +1,16 @@ +# pip install synthetic-dataset-generator +# vllm serve Qwen/Qwen2.5-1.5B-Instruct +import os + +from synthetic_dataset_generator import launch + +# os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface +os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url +os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id +os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id +os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "qwen2" +os.environ["MAX_NUM_ROWS"] = "10000" +os.environ["DEFAULT_BATCH_SIZE"] = "2" +os.environ["MAX_NUM_TOKENS"] = "1024" + +launch() diff --git a/src/synthetic_dataset_generator/constants.py b/src/synthetic_dataset_generator/constants.py index 134a706..dee8a31 100644 --- a/src/synthetic_dataset_generator/constants.py +++ b/src/synthetic_dataset_generator/constants.py @@ -18,23 +18,28 @@ OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL") +VLLM_BASE_URL = os.getenv("VLLM_BASE_URL") + +# check if model is set correctly if HUGGINGFACE_BASE_URL and MODEL: raise ValueError( "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints." ) if not MODEL: - if OPENAI_BASE_URL or OLLAMA_BASE_URL: + if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL: raise ValueError("`MODEL` is not set. Please provide a model id for inference.") # Check if multiple base URLs are provided base_urls = [ - url for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL] if url + url + for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL] + if url ] if len(base_urls) > 1: raise ValueError( f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time." 
) -BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL +BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL # API Keys diff --git a/src/synthetic_dataset_generator/pipelines/base.py b/src/synthetic_dataset_generator/pipelines/base.py index 79f7dfe..34009e1 100644 --- a/src/synthetic_dataset_generator/pipelines/base.py +++ b/src/synthetic_dataset_generator/pipelines/base.py @@ -2,7 +2,7 @@ import random import gradio as gr -from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM +from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM from distilabel.steps.tasks import TextGeneration from synthetic_dataset_generator.constants import ( @@ -14,6 +14,7 @@ OLLAMA_BASE_URL, OPENAI_BASE_URL, TOKENIZER_ID, + VLLM_BASE_URL, ) TOKEN_INDEX = 0 @@ -109,6 +110,17 @@ def _get_llm(use_magpie_template=False, **kwargs): tokenizer_id=TOKENIZER_ID or MODEL, **kwargs, ) + elif VLLM_BASE_URL: + if "generation_kwargs" in kwargs: + if "do_sample" in kwargs["generation_kwargs"]: + del kwargs["generation_kwargs"]["do_sample"] + llm = ClientvLLM( + base_url=VLLM_BASE_URL, + model=MODEL, + tokenizer=TOKENIZER_ID or MODEL, + api_key=_get_next_api_key(), + **kwargs, + ) else: llm = InferenceEndpointsLLM( api_key=_get_next_api_key(), From 2ad72e360d54ad9fc6cb709620b5377f96ef2421 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Wed, 8 Jan 2025 10:27:49 +0100 Subject: [PATCH 07/15] Apply suggestions from code review Co-authored-by: Sara Han <127759186+sdiazlor@users.noreply.github.com> --- examples/ollama_deployment.py | 2 +- examples/vllm_deployment.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py index 0bc3243..a1a4d61 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama_deployment.py @@ -1,6 +1,6 @@ # pip install synthetic-dataset-generator # ollama serve -# ollama run llama3.1:8b-instruct-q8_0 +# ollama run qwen2.5:32b-instruct-q5_K_S import os from synthetic_dataset_generator import launch diff --git a/examples/vllm_deployment.py b/examples/vllm_deployment.py index 36f6d63..9024469 100644 --- a/examples/vllm_deployment.py +++ b/examples/vllm_deployment.py @@ -4,7 +4,7 @@ from synthetic_dataset_generator import launch -# os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface +os.environ["HF_TOKEN"] = "hf_..." 
# push the data to huggingface
 os.environ["VLLM_BASE_URL"] = "http://127.0.0.1:8000/" # vllm base url
 os.environ["MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct" # model id
 os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-1.5B-Instruct" # tokenizer id

From 0aaf1cd091ab50e10ad311604fb0c4db0b04ef77 Mon Sep 17 00:00:00 2001
From: davidberenstein1957
Date: Wed, 8 Jan 2025 11:12:48 +0100
Subject: [PATCH 08/15] resolve merge conflicts

---
 examples/argilla_deployment.py | 7 +-
 .../fine-tune-modernbert-classifier.ipynb | 538 ++++++++++++++++++
 .../fine-tune-smollm2-on-synthetic-data.ipynb | 310 ++++++++++
 3 files changed, 854 insertions(+), 1 deletion(-)
 create mode 100644 examples/fine-tune-modernbert-classifier.ipynb
 create mode 100644 examples/fine-tune-smollm2-on-synthetic-data.ipynb

diff --git a/examples/argilla_deployment.py b/examples/argilla_deployment.py
index bb19a06..fee3f0c 100644
--- a/examples/argilla_deployment.py
+++ b/examples/argilla_deployment.py
@@ -1,4 +1,9 @@
-# pip install synthetic-dataset-generator
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+# "synthetic-dataset-generator",
+# ]
+# ///
 import os

 from synthetic_dataset_generator import launch
diff --git a/examples/fine-tune-modernbert-classifier.ipynb b/examples/fine-tune-modernbert-classifier.ipynb
new file mode 100644
index 0000000..47b4ef3
--- /dev/null
+++ b/examples/fine-tune-modernbert-classifier.ipynb
@@ -0,0 +1,538 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine-tune ModernBERT for text classification using synthetic data\n",
+ "\n",
+ "LLMs are great general-purpose models, but they are not always the best choice for a specific task. Therefore, smaller and more specialized models are important for sustainable, efficient, and cheaper AI.\n",
+ "A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
+ "\n",
+ "In this example, we will fine-tune a ModernBERT model on a synthetic dataset generated from the synthetic-data-generator. 
This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds.\n", + "\n", + "## Install the dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install Pytorch & other libraries\n", + "%pip install \"torch==2.5.0\" \"torchvision==0.20.0\" \n", + "%pip install \"setuptools<71.0.0\" scikit-learn \n", + " \n", + "# Install Hugging Face libraries\n", + "%pip install --upgrade \\\n", + " \"datasets==3.1.0\" \\\n", + " \"accelerate==1.2.1\" \\\n", + " \"hf-transfer==0.1.8\"\n", + " \n", + "# ModernBERT is not yet available in an official release, so we need to install it from github\n", + "%pip install \"git+https://github.com/huggingface/transformers.git@6e0515e99c39444caae39472ee1b2fd76ece32f1\" --upgrade" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The problem\n", + "\n", + "The [nvidia/domain-classifier](https://huggingface.co/nvidia/domain-classifier), is a model that can classify the domain of a text which can help with curating data. This model is cool but is based on the Deberta V3 Base, which is an outdated architecture that requires custom code to run, has a context length of 512 tokens, and is not as fast as the ModernBERT model. The labels for the model are:\n", + "\n", + "```\n", + "'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'\n", + "```\n", + "\n", + "The data on which the model was trained is not available, so we cannot use it for our purposes. We can however generate a synthetic data to solve this problem." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "## Let's generate some data\n", + "\n", + "Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps 1) we come up with a dataset description, 2) iterate on the task configuration, and 3) generate and push the data to Hugging Face. A more detailed flow can be found in [this blogpost](https://huggingface.co/blog/synthetic-data-generator). \n", + "\n", + "\n", + "\n", + "For this example, we will generate 1000 examples with a temperature of 1. After some iteration, we come up with the following system prompt:\n", + "\n", + "```\n", + "Long texts (at least 2000 words) from various media sources like Wikipedia, Reddit, Common Crawl, websites, commercials, online forums, books, newspapers and folders that cover multiple topics. Classify the text based on its main subject matter into one of the following categories\n", + "```\n", + "\n", + "We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few minutes and we end up with a dataset with 1000 examples. 
The labels are nicely distributed across the categories, varied in length, and the texts look diverse and interesting.\n",
+ "\n",
+ "\n",
+ "The data is pushed to Argilla too, so we recommend inspecting and validating the labels before finetuning the model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Finetuning the ModernBERT model\n",
+ "\n",
+ "We mostly rely on the blog from [Philipp Schmid](https://www.philschmid.de/fine-tune-modern-bert-in-2025). I will use basic consumer hardware, my Apple M1 Max with 32GB of shared memory. We will use the `datasets` library to load the data and the `transformers` library to finetune the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'text': 'Recently, there has been an increase in property values within the suburban areas of several cities due to improvements in infrastructure and lifestyle amenities such as parks, retail stores, and educational institutions nearby. Additionally, new housing developments are emerging, catering to different family needs with varying sizes and price ranges. These changes have influenced investment decisions for many looking to buy or sell properties.',\n",
+ " 'label': 14}"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from datasets.arrow_dataset import Dataset\n",
+ "from datasets.dataset_dict import DatasetDict, IterableDatasetDict\n",
+ "from datasets.iterable_dataset import IterableDataset\n",
+ " \n",
+ "# Dataset id from huggingface.co/dataset\n",
+ "dataset_id = \"argilla/synthetic-domain-text-classification\"\n",
+ " \n",
+ "# Load raw dataset\n",
+ "train_dataset = load_dataset(dataset_id, split='train')\n",
+ "\n",
+ "split_dataset = train_dataset.train_test_split(test_size=0.1)\n",
+ "split_dataset['train'][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we need to tokenize the data. We will use the `AutoTokenizer` class from the `transformers` library to load the tokenizer."
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 900/900 [00:00<00:00, 4787.61 examples/s]\n", + "Map: 100%|██████████| 100/100 [00:00<00:00, 4163.70 examples/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "dict_keys(['labels', 'input_ids', 'attention_mask'])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoTokenizer\n", + " \n", + "# Model id to load the tokenizer\n", + "model_id = \"answerdotai/ModernBERT-base\"\n", + "\n", + "# Load Tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", + " \n", + "# Tokenize helper function\n", + "def tokenize(batch):\n", + " return tokenizer(batch['text'], padding='max_length', truncation=True, return_tensors=\"pt\")\n", + " \n", + "# Tokenize dataset\n", + "if \"label\" in split_dataset[\"train\"].features.keys():\n", + " split_dataset = split_dataset.rename_column(\"label\", \"labels\") # to match Trainer\n", + "tokenized_dataset = split_dataset.map(tokenize, batched=True, remove_columns=[\"text\"])\n", + " \n", + "tokenized_dataset[\"train\"].features.keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we need to prepare the model. We will use the `AutoModelForSequenceClassification` class from the `transformers` library to load the model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "from transformers import AutoModelForSequenceClassification\n", + " \n", + "# Model id to load the tokenizer\n", + "model_id = \"answerdotai/ModernBERT-base\"\n", + " \n", + "# Prepare model labels - useful for inference\n", + "labels = tokenized_dataset[\"train\"].features[\"labels\"].names\n", + "num_labels = len(labels)\n", + "label2id, id2label = dict(), dict()\n", + "for i, label in enumerate(labels):\n", + " label2id[label] = str(i)\n", + " id2label[str(i)] = label\n", + " \n", + "# Download the model from huggingface.co/models\n", + "model = AutoModelForSequenceClassification.from_pretrained(\n", + " model_id, num_labels=num_labels, label2id=label2id, id2label=id2label,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use a simple F1 score as the evaluation metric." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.metrics import f1_score\n", + " \n", + "# Metric helper method\n", + "def compute_metrics(eval_pred):\n", + " predictions, labels = eval_pred\n", + " predictions = np.argmax(predictions, axis=1)\n", + " score = f1_score(\n", + " labels, predictions, labels=labels, pos_label=1, average=\"weighted\"\n", + " )\n", + " return {\"f1\": float(score) if score == 1 else score}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we need to define the training arguments. 
We will use the `TrainingArguments` class from the `transformers` library to define the training arguments." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/transformers/training_args.py:2241: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from huggingface_hub import HfFolder\n", + "from transformers import Trainer, TrainingArguments\n", + " \n", + "# Define training args\n", + "training_args = TrainingArguments(\n", + " output_dir= \"ModernBERT-domain-classifier\",\n", + " per_device_train_batch_size=32,\n", + " per_device_eval_batch_size=16,\n", + " learning_rate=5e-5,\n", + "\t\tnum_train_epochs=5,\n", + " bf16=True, # bfloat16 training \n", + " optim=\"adamw_torch_fused\", # improved optimizer \n", + " # logging & evaluation strategies\n", + " logging_strategy=\"steps\",\n", + " logging_steps=100,\n", + " eval_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " save_total_limit=2,\n", + " load_best_model_at_end=True,\n", + " use_mps_device=True,\n", + " metric_for_best_model=\"f1\",\n", + " # push to hub parameters\n", + " push_to_hub=True,\n", + " hub_strategy=\"every_save\",\n", + " hub_token=HfFolder.get_token(),\n", + ")\n", + " \n", + "# Create a Trainer instance\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_dataset[\"train\"],\n", + " eval_dataset=tokenized_dataset[\"test\"],\n", + " compute_metrics=compute_metrics,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \n", + " 20%|██ | 29/145 [11:32<33:16, 17.21s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'eval_loss': 0.729780912399292, 'eval_f1': 0.7743598318036522, 'eval_runtime': 3.5337, 'eval_samples_per_second': 28.299, 'eval_steps_per_second': 1.981, 'epoch': 1.0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \n", + " 40%|████ | 58/145 [22:57<25:56, 17.89s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'eval_loss': 0.4369044005870819, 'eval_f1': 0.8310764765820946, 'eval_runtime': 3.3266, 'eval_samples_per_second': 30.061, 'eval_steps_per_second': 2.104, 'epoch': 2.0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \n", + " 60%|██████ | 87/145 [35:16<17:06, 17.70s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'eval_loss': 0.6091340184211731, 'eval_f1': 0.8399274488570763, 'eval_runtime': 3.2772, 'eval_samples_per_second': 30.514, 'eval_steps_per_second': 2.136, 'epoch': 3.0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 69%|██████▉ | 100/145 [41:03<18:02, 24.06s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'loss': 0.7663, 'grad_norm': 7.232136249542236, 'learning_rate': 1.5517241379310346e-05, 'epoch': 3.45}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \n", + " 80%|████████ | 116/145 [47:23<08:50, 18.30s/it]" + ] 
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.43516409397125244, 'eval_f1': 0.8797674004703547, 'eval_runtime': 3.2975, 'eval_samples_per_second': 30.326, 'eval_steps_per_second': 2.123, 'epoch': 4.0}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ "100%|██████████| 145/145 [1:00:40<00:00, 19.18s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.39272159337997437, 'eval_f1': 0.8914389523348718, 'eval_runtime': 3.5564, 'eval_samples_per_second': 28.118, 'eval_steps_per_second': 1.968, 'epoch': 5.0}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 145/145 [1:00:42<00:00, 25.12s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 3642.7783, 'train_samples_per_second': 1.235, 'train_steps_per_second': 0.04, 'train_loss': 0.535627057634551, 'epoch': 5.0}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "events.out.tfevents.1735555878.Davids-MacBook-Pro.local.23438.0: 100%|██████████| 9.32k/9.32k [00:00<00:00, 55.0kB/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/davidberenstein1957/domain-classifier/commit/915f4b03c230cc8f376f13729728f14347400041', commit_message='End of training', commit_description='', oid='915f4b03c230cc8f376f13729728f14347400041', pr_url=None, repo_url=RepoUrl('https://huggingface.co/davidberenstein1957/domain-classifier', endpoint='https://huggingface.co', repo_type='model', repo_id='davidberenstein1957/domain-classifier'), pr_revision=None, pr_num=None)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.train()\n",
+ "# Save processor and create model card\n",
+ "tokenizer.save_pretrained(\"ModernBERT-domain-classifier\")\n",
+ "trainer.create_model_card()\n",
+ "trainer.push_to_hub()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We get an F1 score of 0.89 on the test set, which is pretty good for the small dataset and time spent."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run inference\n",
+ "\n",
+ "We can now load the model and run inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use mps:0\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[{'label': 'health', 'score': 0.6779336333274841}]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import pipeline\n",
+ " \n",
+ "# load model from huggingface.co/models using our repository id\n",
+ "classifier = pipeline(\n",
+ " task=\"text-classification\", \n",
+ " model=\"argilla/ModernBERT-domain-classifier\", \n",
+ " device=0,\n",
+ ")\n",
+ " \n",
+ "sample = \"Smoking is bad for your health.\"\n",
+ " \n",
+ "classifier(sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "We have shown that we can generate a synthetic dataset from an LLM and finetune a ModernBERT model on it. This demonstrates the effectiveness of synthetic data and the novel ModernBERT model, which is a new and improved version of BERT models, with an 8192 token context length, significantly better downstream performance, and much faster processing speeds. 
\n",
+ "\n",
+ "Pretty cool for 20 minutes of generating data, and an hour of fine-tuning on consumer hardware."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/fine-tune-smollm2-on-synthetic-data.ipynb b/examples/fine-tune-smollm2-on-synthetic-data.ipynb
new file mode 100644
index 0000000..0114458
--- /dev/null
+++ b/examples/fine-tune-smollm2-on-synthetic-data.ipynb
@@ -0,0 +1,310 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine-tune a SmolLM on domain-specific synthetic data from an LLM\n",
+ "\n",
+ "Yes, smoll models can beat GPT4-like models on domain-specific tasks, but don't expect miracles. When comparing smoll vs large, consider all costs and gains, like the difference in performance and the value of using private and local models and data that you own.\n",
+ "\n",
+ "The [Hugging Face SmolLM models](https://github.com/huggingface/smollm) are blazingly fast and remarkably powerful. With its 135M, 360M and 1.7B parameter models, it is a great choice for a small and fast model. The great thing about SmolLM is that it is a general-purpose model that can be fine-tuned on domain-specific data.\n",
+ "\n",
+ "A lack of domain-specific datasets is a common problem for smaller and more specialized models. This is because it is difficult to find a dataset that is both representative and diverse enough for a specific task. We solve this problem by generating a synthetic dataset from an LLM using the `synthetic-data-generator`, which is available as a [Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) or on [GitHub](https://github.com/argilla-io/synthetic-data-generator).\n",
+ "\n",
+ "In this example, we will fine-tune a SmolLM2 model on a synthetic dataset generated from `meta-llama/Meta-Llama-3.1-8B-Instruct` with the `synthetic-data-generator`.\n",
+ "\n",
+ "## Install the dependencies\n",
+ "\n",
+ "We will install some basic dependencies for the fine-tuning with `trl` but we will use the Synthetic Data Generator UI to generate the synthetic dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install transformers datasets trl torch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The problem\n",
+ "\n",
+ "Reasoning data has proven to be a fundamental change in the performance of generative models. Reasoning is amazing but it also means the model generates more \"chatty\" output during the token generation process, causing the model to become slower and more expensive. For this reason, we want to create a model that can reason without being too chatty. Therefore, we will generate a concise reasoning dataset and fine-tune a SmolLM2 model on it.\n",
+ "\n",
+ "## Let's generate some data\n",
+ "\n",
+ "Let's go to the [hosted Hugging Face Space](https://huggingface.co/spaces/argilla/synthetic-data-generator) to generate the data. This is done in three steps: 1) we come up with a dataset description, 2) iterate on the task configuration, and 3) generate and push the data to Hugging Face. 
A more detailed flow can be found in [this blog post](https://huggingface.co/blog/synthetic-data-generator). \n", + "\n", + "\n", + "\n", + "For this example, we will generate 5000 chat data examples for a single turn in the conversation. All examples have been generated with a temperature of 1. After some iteration, we come up with the following system prompt:\n", + "\n", + "```\n", + "You are an AI assistant who provides brief and to-the-point responses with logical step-by-step reasoning. Your purpose is to offer straightforward explanations and answers so that you can get to the heart of the issue. Respond with extremely concise, direct justifications and evidence-based conclusions. User questions are direct and concise.\n", + "```\n", + "\n", + "We press the \"Push to Hub\" button and wait for the data to be generated. This takes a few hours and we end up with a dataset with 5000 examples, which is the maximum number of examples we can generate in a single run. You can scale this by deploying a private instance of the Synthetic Data Generator. \n", + "\n", + "\n", + "\n", + "The data is pushed to Argilla too so we recommend inspecting and validating the the data before finetuning the actual model. We applied some basic filters and transformations to the data to make it more suitable for fine-tuning.\n", + "\n", + "## Fine-tune the model\n", + "\n", + "We will use TRL to fine-tune the model. It is part of the Hugging Face ecosystem and works seamlessly on top of datasets generated by the synthetic data generator without needing to do any data transformations.\n", + "\n", + "### Load the model\n", + "\n", + "We will first load the model and tokenizer and set up the chat format." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Import necessary libraries\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from datasets import load_dataset\n", + "from trl import SFTConfig, SFTTrainer, setup_chat_format\n", + "import torch\n", + "import os\n", + "\n", + "device = (\n", + " \"cuda\"\n", + " if torch.cuda.is_available()\n", + " else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", + ")\n", + "\n", + "# Load the model and tokenizer\n", + "model_name = \"HuggingFaceTB/SmolLM2-360M\"\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " pretrained_model_name_or_path=model_name\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)\n", + "\n", + "# Set up the chat format\n", + "model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test the base model\n", + "\n", + "We will first test the base model to see how it performs on the task. During this step we will also generate a prompt for the model to respond to, to see how it performs on the task." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Device set to use mps:0\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'generated_text': 'What is the primary function of mitochondria within a cell?\\n\\nMitochondria are the powerhouses of the cell. They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe mitochondria are the powerhouses of the cell. 
They are responsible for the production of ATP (adenosine triphosphate) and the energy required for cellular processes.\\n\\nWhat is the function of the mitochondria in the cell?\\n\\nThe'}]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import pipeline\n", + "\n", + "prompt = \"What is the primary function of mitochondria within a cell?\"\n", + "\n", + "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer, device=device)\n", + "pipe(prompt, max_new_tokens=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the dataset\n", + "\n", + "For fine-tuning, we need to load the dataset and tokenize it. We will use the `synthetic-concise-reasoning-sft-filtered` dataset that we generated in the previous step." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 4133/4133 [00:00<00:00, 18478.53 examples/s]\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"argilla/synthetic-concise-reasoning-sft-filtered\")\n", + "def tokenize_function(examples):\n", + " examples[\"text\"] = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": examples[\"prompt\"].strip()}, {\"role\": \"assistant\", \"content\": examples[\"completion\"].strip()}], tokenize=False)\n", + " return examples\n", + "ds = ds.map(tokenize_function)\n", + "ds = ds.shuffle()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fine-tune the model\n", + "\n", + "We will now fine-tune the model. We will use the `SFTTrainer` from the `trl` library to fine-tune the model. We will use a batch size of 4 and a learning rate of 5e-5. We will also use the `use_mps_device` flag to use the MPS device if available." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n", + "\n", + "# Configure the SFTTrainer\n", + "sft_config = SFTConfig(\n", + " output_dir=\"./sft_output\",\n", + " num_train_epochs=1,\n", + " per_device_train_batch_size=4, # Set according to your GPU memory capacity\n", + " learning_rate=5e-5, # Common starting point for fine-tuning\n", + " logging_steps=100, # Frequency of logging training metrics\n", + " use_mps_device= True if device == \"mps\" else False,\n", + " hub_model_id=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\", # Set a unique name for your model\n", + " push_to_hub=True,\n", + ")\n", + "\n", + "# Initialize the SFTTrainer\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=sft_config,\n", + " train_dataset=ds[\"train\"],\n", + " tokenizer=tokenizer,\n", + ")\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "# {'loss': 1.4498, 'grad_norm': 2.3919131755828857, 'learning_rate': 4e-05, 'epoch': 0.1}\n", + "# {'loss': 1.362, 'grad_norm': 1.6650595664978027, 'learning_rate': 3e-05, 'epoch': 0.19}\n", + "# {'loss': 1.3778, 'grad_norm': 1.4778285026550293, 'learning_rate': 2e-05, 'epoch': 0.29}\n", + "# {'loss': 1.3735, 'grad_norm': 2.1424977779388428, 'learning_rate': 1e-05, 'epoch': 0.39}\n", + "# {'loss': 1.3512, 'grad_norm': 2.3498542308807373, 'learning_rate': 0.0, 'epoch': 0.48}\n", + "# {'train_runtime': 1911.514, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.262, 'train_loss': 1.3828572998046875, 'epoch': 0.48}\n", + "```\n", + "\n", + "For the example, we did not use a specific validation set but we can see the loss is decreasing, so we assume the model is generalsing well to the training data. To get a better understanding of the model's performance, let's test it again with the same prompt.\n", + "\n", + "### Run inference\n", + "\n", + "We can now run inference with [the fine-tuned model](https://huggingface.co/argilla/SmolLM2-360M-synthetic-concise-reasoning/blob/main/README.md)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Device set to use mps\n" + ] + }, + { + "data": { + "text/plain": [ + "'The primary function of mitochondria is to generate energy for the cell. 
They are organelles found in eukaryotic cells that convert nutrients into ATP (adenosine triphosphate), which is the primary source of energy for cellular processes.\\nMitochondria are responsible for:\\n\\nEnergy production: Mitochondria produce ATP through a process called oxidative phosphorylation, which involves the transfer of electrons from food molecules to oxygen.\\nEnergy storage: Mitochondria store energy in the form of adenosine triphosphate (ATP), which is used by the cell for various cellular processes.\\nCellular respiration: Mitochondria also participate in cellular respiration, a'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"What is the primary function of mitochondria within a cell?\"\n", + "\n", + "generator = pipeline(\n", + " \"text-generation\",\n", + " model=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\",\n", + " device=\"mps\",\n", + ")\n", + "generator(\n", + " [{\"role\": \"user\", \"content\": prompt}], max_new_tokens=128, return_full_text=False\n", + ")[0][\"generated_text\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "We have fine-tuned a SmolLM2 model on a synthetic dataset generated from a large language model. We have seen that the model performs well on the task and that the synthetic data is a great way to generate diverse and representative data for supervised fine-tuning. \n", + "\n", + "In practice, you would likely want to spend more time on the data quality and fine-tuning the model but the flow shows the Synthetic Data Generator is a great tool to generate synthetic data for any task.\n", + "\n", + "Overall, I think it is pretty cool for a couple of hours of generation and fine-tuning on consumer hardware.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 68d064d4418d4e7bf76ca6f5ed9b567b645e59ee Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 8 Jan 2025 11:15:44 +0100 Subject: [PATCH 09/15] add uv run option to examples --- examples/hf-serverless_deployment.py | 7 ++++++- examples/ollama_deployment.py | 7 ++++++- examples/openai_deployment.py | 8 +++++++- examples/tgi_or_hf_dedicated.py | 7 ++++++- examples/vllm_deployment.py | 7 ++++++- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/hf-serverless_deployment.py b/examples/hf-serverless_deployment.py index 561602c..0ccdd65 100644 --- a/examples/hf-serverless_deployment.py +++ b/examples/hf-serverless_deployment.py @@ -1,4 +1,9 @@ -# pip install synthetic-dataset-generator +# /// script +# requires-python = ">=3.11,<3.12" +# dependencies = [ +# "synthetic-dataset-generator", +# ] +# /// import os from synthetic_dataset_generator import launch diff --git a/examples/ollama_deployment.py b/examples/ollama_deployment.py index a1a4d61..9c3f7e3 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama_deployment.py @@ -1,4 +1,9 @@ -# pip install synthetic-dataset-generator +# /// script +# requires-python = ">=3.11,<3.12" +# dependencies = [ +# "synthetic-dataset-generator", +# ] +# /// # ollama serve # ollama run qwen2.5:32b-instruct-q5_K_S 
import os diff --git a/examples/openai_deployment.py b/examples/openai_deployment.py index 6c8617f..a3a1bf8 100644 --- a/examples/openai_deployment.py +++ b/examples/openai_deployment.py @@ -1,4 +1,10 @@ -# pip install synthetic-dataset-generator +# /// script +# requires-python = ">=3.11,<3.12" +# dependencies = [ +# "synthetic-dataset-generator", +# ] +# /// + import os from synthetic_dataset_generator import launch diff --git a/examples/tgi_or_hf_dedicated.py b/examples/tgi_or_hf_dedicated.py index 3c59bef..0c726db 100644 --- a/examples/tgi_or_hf_dedicated.py +++ b/examples/tgi_or_hf_dedicated.py @@ -1,4 +1,9 @@ -# pip install synthetic-dataset-generator +# /// script +# requires-python = ">=3.11,<3.12" +# dependencies = [ +# "synthetic-dataset-generator", +# ] +# /// import os from synthetic_dataset_generator import launch diff --git a/examples/vllm_deployment.py b/examples/vllm_deployment.py index 9024469..bdaee63 100644 --- a/examples/vllm_deployment.py +++ b/examples/vllm_deployment.py @@ -1,4 +1,9 @@ -# pip install synthetic-dataset-generator +# /// script +# requires-python = ">=3.11,<3.12" +# dependencies = [ +# "synthetic-dataset-generator", +# ] +# /// # vllm serve Qwen/Qwen2.5-1.5B-Instruct import os From 70abf20fdf4101353fd57049d85088ebbac8fe37 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 8 Jan 2025 11:15:57 +0100 Subject: [PATCH 10/15] chore add vllm to dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fdec71c..52eb2c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ readme = "README.md" license = {text = "Apache 2"} dependencies = [ - "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama", + "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama", "gradio[oauth]>=5.4.0,<6.0.0", "transformers>=4.44.2,<5.0.0", "sentence-transformers>=3.2.0,<4.0.0", From 6f3d06e1b268cdfdeb2d087586edf4b786969013 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 8 Jan 2025 11:17:47 +0100 Subject: [PATCH 11/15] fix returning duplicate labels --- src/synthetic_dataset_generator/apps/textcat.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/synthetic_dataset_generator/apps/textcat.py b/src/synthetic_dataset_generator/apps/textcat.py index b2dcab7..bc3bce1 100644 --- a/src/synthetic_dataset_generator/apps/textcat.py +++ b/src/synthetic_dataset_generator/apps/textcat.py @@ -186,13 +186,15 @@ def _validate_labels(x): if isinstance(x, str): # single label return [x.lower().strip()] elif isinstance(x, list): # multiple labels - return [ - label.lower().strip() - for label in x - if label.lower().strip() in labels - ] + return list( + set( + label.lower().strip() + for label in x + if label.lower().strip() in labels + ) + ) else: - return [random.choice(labels)] + return list(set([random.choice(labels)])) dataframe["labels"] = dataframe["labels"].apply(_validate_labels) dataframe = dataframe[dataframe["labels"].notna()] From 9b64ead884c648ff2d7dac69203d654ed7e4e5d2 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 8 Jan 2025 11:27:26 +0100 Subject: [PATCH 12/15] update file naming --- ...gilla_deployment.py => argilla-deployment.py} | 0 
examples/enforce_mapgie_template copy.py | 14 -------------- examples/fine-tune-modernbert-classifier.ipynb | 2 +- ...ated.py => hf-dedicated-or-tgi-deployment.py} | 0 ...deployment.py => hf-serverless-deployment.py} | 0 ...ollama_deployment.py => ollama-deployment.py} | 2 +- examples/ollama_local.py | 15 --------------- ...openai_deployment.py => openai-deployment.py} | 0 examples/openai_local.py | 16 ---------------- .../{vllm_deployment.py => vllm-deployment.py} | 0 10 files changed, 2 insertions(+), 47 deletions(-) rename examples/{argilla_deployment.py => argilla-deployment.py} (100%) delete mode 100644 examples/enforce_mapgie_template copy.py rename examples/{tgi_or_hf_dedicated.py => hf-dedicated-or-tgi-deployment.py} (100%) rename examples/{hf-serverless_deployment.py => hf-serverless-deployment.py} (100%) rename examples/{ollama_deployment.py => ollama-deployment.py} (90%) delete mode 100644 examples/ollama_local.py rename examples/{openai_deployment.py => openai-deployment.py} (100%) delete mode 100644 examples/openai_local.py rename examples/{vllm_deployment.py => vllm-deployment.py} (100%) diff --git a/examples/argilla_deployment.py b/examples/argilla-deployment.py similarity index 100% rename from examples/argilla_deployment.py rename to examples/argilla-deployment.py diff --git a/examples/enforce_mapgie_template copy.py b/examples/enforce_mapgie_template copy.py deleted file mode 100644 index 6907d91..0000000 --- a/examples/enforce_mapgie_template copy.py +++ /dev/null @@ -1,14 +0,0 @@ -# /// script -# requires-python = ">=3.11,<3.12" -# dependencies = [ -# "synthetic-dataset-generator", -# ] -# /// -import os - -from synthetic_dataset_generator import launch - -os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "my_custom_template" -os.environ["MODEL"] = "google/gemma-2-9b-it" - -launch() diff --git a/examples/fine-tune-modernbert-classifier.ipynb b/examples/fine-tune-modernbert-classifier.ipynb index 47b4ef3..8bd1cba 100644 --- a/examples/fine-tune-modernbert-classifier.ipynb +++ b/examples/fine-tune-modernbert-classifier.ipynb @@ -530,7 +530,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/examples/tgi_or_hf_dedicated.py b/examples/hf-dedicated-or-tgi-deployment.py similarity index 100% rename from examples/tgi_or_hf_dedicated.py rename to examples/hf-dedicated-or-tgi-deployment.py diff --git a/examples/hf-serverless_deployment.py b/examples/hf-serverless-deployment.py similarity index 100% rename from examples/hf-serverless_deployment.py rename to examples/hf-serverless-deployment.py diff --git a/examples/ollama_deployment.py b/examples/ollama-deployment.py similarity index 90% rename from examples/ollama_deployment.py rename to examples/ollama-deployment.py index 9c3f7e3..bd32be1 100644 --- a/examples/ollama_deployment.py +++ b/examples/ollama-deployment.py @@ -10,7 +10,7 @@ from synthetic_dataset_generator import launch -# os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface +os.environ["HF_TOKEN"] = "hf_..." 
# push the data to huggingface os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/" # ollama base url os.environ["MODEL"] = "qwen2.5:32b-instruct-q5_K_S" # model id os.environ["TOKENIZER_ID"] = "Qwen/Qwen2.5-32B-Instruct" # tokenizer id diff --git a/examples/ollama_local.py b/examples/ollama_local.py deleted file mode 100644 index bf8535c..0000000 --- a/examples/ollama_local.py +++ /dev/null @@ -1,15 +0,0 @@ -# /// script -# requires-python = ">=3.11,<3.12" -# dependencies = [ -# "synthetic-dataset-generator", -# ] -# /// -import os - -from synthetic_dataset_generator import launch - -assert os.getenv("HF_TOKEN") # push the data to huggingface -os.environ["BASE_URL"] = "http://127.0.0.1:11434/v1/" -os.environ["MODEL"] = "llama3.1" - -launch() diff --git a/examples/openai_deployment.py b/examples/openai-deployment.py similarity index 100% rename from examples/openai_deployment.py rename to examples/openai-deployment.py diff --git a/examples/openai_local.py b/examples/openai_local.py deleted file mode 100644 index 9ab58ef..0000000 --- a/examples/openai_local.py +++ /dev/null @@ -1,16 +0,0 @@ -# /// script -# requires-python = ">=3.11,<3.12" -# dependencies = [ -# "synthetic-dataset-generator", -# ] -# /// -import os - -from synthetic_dataset_generator import launch - -assert os.getenv("HF_TOKEN") # push the data to huggingface -os.environ["BASE_URL"] = "https://api.openai.com/v1/" -os.environ["API_KEY"] = os.getenv("OPENAI_API_KEY") -os.environ["MODEL"] = "gpt-4o" - -launch() diff --git a/examples/vllm_deployment.py b/examples/vllm-deployment.py similarity index 100% rename from examples/vllm_deployment.py rename to examples/vllm-deployment.py From 7b8f6471e4ec604af5e58b40640e494bc9ba9292 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 8 Jan 2025 11:35:40 +0100 Subject: [PATCH 13/15] update text w.r.t. Magpie deployment. --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index b87c18b..7b9ac8f 100644 --- a/README.md +++ b/README.md @@ -91,8 +91,7 @@ Optionally, you can use different API providers and models. - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`. - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`. - -SFT and Chat Data generation is only supported with Hugging Face Inference Endpoints , and you can set the following environment variables use it with models other than Llama3 and Qwen2. +SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables. - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`. - `MAGPIE_PRE_QUERY_TEMPLATE`: Enforce setting the pre-query template for Magpie, which is only supported with Hugging Face Inference Endpoints. `llama3` and `qwen2` are supported out of the box and will use `"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"` and `"<|im_start|>user\n"`, respectively. For other models, you can pass a custom pre-query template string. 
From 53056239d2a529ccfe6cbdf1a0738ad2c57b71d2 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 10 Jan 2025 18:31:43 +0100 Subject: [PATCH 14/15] Update distilabel dependency to use the 'develop' branch for improved support and compatibility --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 52eb2c8..f7c1792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ readme = "README.md" license = {text = "Apache 2"} dependencies = [ - "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@feat/add-magpie-support-llama-cpp-ollama", + "distilabel[argilla,hf-inference-endpoints,hf-transformers,instructor,llama-cpp,ollama,openai,outlines,vllm] @ git+https://github.com/argilla-io/distilabel.git@develop", "gradio[oauth]>=5.4.0,<6.0.0", "transformers>=4.44.2,<5.0.0", "sentence-transformers>=3.2.0,<4.0.0", From ce401b1090a3afeab2cbe01d6920df0ffe03e0f8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 10 Jan 2025 18:32:34 +0100 Subject: [PATCH 15/15] Fix typo in README.md regarding video link for Synthetic Data Generator --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7b9ac8f..7d53aa1 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ hf_oauth_scopes: ## Introduction -Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also wathh the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action. +Synthetic Data Generator is a tool that allows you to create high-quality datasets for training and fine-tuning language models. It leverages the power of distilabel and LLMs to generate synthetic data tailored to your specific needs. [The announcement blog](https://huggingface.co/blog/synthetic-data-generator) goes over a practical example of how to use it but you can also watch the [video](https://www.youtube.com/watch?v=nXjVtnGeEss) to see it in action. Supported Tasks: