Skip to content

Commit

Permalink
Merge pull request #3 from argilla-io/feat/avoid-usage-withou-argilla
Browse files Browse the repository at this point in the history
feat enable usage without argilla
  • Loading branch information
davidberenstein1957 authored Dec 3, 2024
2 parents 0a0f99c + 2cf2cd7 commit da59bd9
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 22 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ hf_oauth_scopes:
<img alt="CI" src="https://img.shields.io/pypi/v/synthetic-dataset-generator.svg?style=flat-round&logo=pypi&logoColor=white">
</a>
<a href="https://pepy.tech/project/synthetic-dataset-generator">
<img alt="CI" src="https://static.pepy.tech/personalized-badge/argilla?period=month&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads/month">
<img alt="CI" src="https://static.pepy.tech/personalized-badge/synthetic-dataset-generator?period=month&units=international_system&left_color=grey&right_color=blue&left_text=pypi%20downloads/month">
</a>
<a href="https://huggingface.co/spaces/argilla/synthetic-data-generator?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg"/>
Expand Down Expand Up @@ -80,16 +80,25 @@ pip install synthetic-dataset-generator

### Environment Variables

- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained).
- `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained).

Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:

- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.

## Quick Start
## Quickstart

```bash
python app.py
```

### Argilla integration

Argilla is an open-source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).

![Argilla integration](https://huggingface.co/spaces/argilla/synthetic-data-generator/resolve/main/assets/argilla.png)

## Custom synthetic data generation?

Each pipeline is based on distilabel, so you can easily change the LLM or the pipeline steps.
Expand Down
Binary file added assets/argilla.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions src/distilabel_dataset_generator/apps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,27 @@ def get_success_message_row() -> gr.Markdown:

def show_success_message(org_name, repo_name) -> gr.Markdown:
client = get_argilla_client()
if client is None:
return gr.Markdown(
value="""
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
<p style="margin-top: 0.5em;">
The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks. Your dataset is now available at:
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
https://huggingface.co/datasets/{org_name}/{repo_name}
</a>
</p>
<p style="margin-top: 1em; font-size: 0.9em; color: #333;">
By configuring an `ARGILLA_API_URL` and `ARGILLA_API_KEY` you can curate the dataset in Argilla.
Unfamiliar with Argilla? Here are some docs to help you get started:
<br>• <a href="https://docs.argilla.io/latest/getting_started/quickstart/" target="_blank">How to get started with Argilla</a>
<br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
<br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
</p>
</div>
"""
)
argilla_api_url = client.api_url
return gr.Markdown(
value=f"""
Expand Down
25 changes: 14 additions & 11 deletions src/distilabel_dataset_generator/apps/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
extract_column_names,
get_argilla_client,
get_org_dropdown,
pad_or_truncate_list,
process_columns,
swap_visibility,
pad_or_truncate_list,
)


Expand Down Expand Up @@ -334,8 +334,10 @@ def push_dataset(
push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
try:
progress(0.1, desc="Setting up user and workspace")
client = get_argilla_client()
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
client = get_argilla_client()
if client is None:
return ""
if eval_type == "ultrafeedback":
num_generations = len((dataframe["generations"][0]))
fields = [
Expand Down Expand Up @@ -580,6 +582,7 @@ def push_dataset(
def show_pipeline_code_visibility():
    """Return a component update that reveals the pipeline-code accordion."""
    visible_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: visible_accordion}


def hide_pipeline_code_visibility():
    """Return a component update that collapses/hides the pipeline-code accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}

Expand Down Expand Up @@ -708,15 +711,15 @@ def hide_pipeline_code_visibility():
visible=False,
) as pipeline_code_ui:
code = generate_pipeline_code(
repo_id=search_in.value,
aspects=aspects_instruction_response.value,
instruction_column=instruction_instruction_response,
response_columns=response_instruction_response,
prompt_template=prompt_template.value,
structured_output=structured_output.value,
num_rows=num_rows.value,
eval_type=eval_type.value,
)
repo_id=search_in.value,
aspects=aspects_instruction_response.value,
instruction_column=instruction_instruction_response,
response_columns=response_instruction_response,
prompt_template=prompt_template.value,
structured_output=structured_output.value,
num_rows=num_rows.value,
eval_type=eval_type.value,
)
pipeline_code = gr.Code(
value=code,
language="python",
Expand Down
4 changes: 3 additions & 1 deletion src/distilabel_dataset_generator/apps/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,10 @@ def push_dataset(
push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
try:
progress(0.1, desc="Setting up user and workspace")
client = get_argilla_client()
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
client = get_argilla_client()
if client is None:
return ""
if "messages" in dataframe.columns:
settings = rg.Settings(
fields=[
Expand Down
16 changes: 9 additions & 7 deletions src/distilabel_dataset_generator/apps/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres
labels = data["labels"]
return system_prompt, labels

def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):

def generate_sample_dataset(
system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()
):
dataframe = generate_dataset(
system_prompt=system_prompt,
difficulty=difficulty,
Expand Down Expand Up @@ -138,11 +141,7 @@ def generate_dataset(
# create final dataset
distiset_results = []
for result in labeller_results:
record = {
key: result[key]
for key in ["labels", "text"]
if key in result
}
record = {key: result[key] for key in ["labels", "text"] if key in result}
distiset_results.append(record)

dataframe = pd.DataFrame(distiset_results)
Expand Down Expand Up @@ -212,13 +211,16 @@ def push_dataset(
push_dataset_to_hub(
dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
)

dataframe = dataframe[
(dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
]
try:
progress(0.1, desc="Setting up user and workspace")
client = get_argilla_client()
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
client = get_argilla_client()
if client is None:
return ""
labels = get_preprocess_labels(labels)
settings = rg.Settings(
fields=[
Expand Down

0 comments on commit da59bd9

Please sign in to comment.