Skip to content

Commit

Permalink
fix model validation when using ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
davidberenstein1957 committed Dec 17, 2024
1 parent d129960 commit 85b97c4
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 73 deletions.
141 changes: 70 additions & 71 deletions src/synthetic_dataset_generator/apps/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,77 +527,76 @@ def hide_pipeline_code_visibility():
label="Distilabel Pipeline Code",
)

load_btn.click(
fn=generate_system_prompt,
inputs=[dataset_description],
outputs=[system_prompt],
show_progress=True,
).then(
fn=generate_sample_dataset,
inputs=[system_prompt, num_turns],
outputs=[dataframe],
show_progress=True,
)

btn_apply_to_sample_dataset.click(
fn=generate_sample_dataset,
inputs=[system_prompt, num_turns],
outputs=[dataframe],
show_progress=True,
)
load_btn.click(
fn=generate_system_prompt,
inputs=[dataset_description],
outputs=[system_prompt],
show_progress=True,
).then(
fn=generate_sample_dataset,
inputs=[system_prompt, num_turns],
outputs=[dataframe],
show_progress=True,
)

btn_push_to_hub.click(
fn=validate_argilla_user_workspace_dataset,
inputs=[repo_name],
outputs=[success_message],
show_progress=True,
).then(
fn=validate_push_to_hub,
inputs=[org_name, repo_name],
outputs=[success_message],
show_progress=True,
).success(
fn=hide_success_message,
outputs=[success_message],
show_progress=True,
).success(
fn=hide_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
show_progress=True,
).success(
fn=push_dataset,
inputs=[
org_name,
repo_name,
system_prompt,
num_turns,
num_rows,
private,
temperature,
pipeline_code,
],
outputs=[success_message],
show_progress=True,
).success(
fn=show_success_message,
inputs=[org_name, repo_name],
outputs=[success_message],
).success(
fn=generate_pipeline_code,
inputs=[system_prompt, num_turns, num_rows, temperature],
outputs=[pipeline_code],
).success(
fn=show_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
)
gr.on(
triggers=[clear_btn_part.click, clear_btn_full.click],
fn=lambda _: ("", "", 1, _get_dataframe()),
inputs=[dataframe],
outputs=[dataset_description, system_prompt, num_turns, dataframe],
)
btn_apply_to_sample_dataset.click(
fn=generate_sample_dataset,
inputs=[system_prompt, num_turns],
outputs=[dataframe],
show_progress=True,
)

btn_push_to_hub.click(
fn=validate_argilla_user_workspace_dataset,
inputs=[repo_name],
outputs=[success_message],
show_progress=True,
).then(
fn=validate_push_to_hub,
inputs=[org_name, repo_name],
outputs=[success_message],
show_progress=True,
).success(
fn=hide_success_message,
outputs=[success_message],
show_progress=True,
).success(
fn=hide_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
show_progress=True,
).success(
fn=push_dataset,
inputs=[
org_name,
repo_name,
system_prompt,
num_turns,
num_rows,
private,
temperature,
pipeline_code,
],
outputs=[success_message],
show_progress=True,
).success(
fn=show_success_message,
inputs=[org_name, repo_name],
outputs=[success_message],
).success(
fn=generate_pipeline_code,
inputs=[system_prompt, num_turns, num_rows, temperature],
outputs=[pipeline_code],
).success(
fn=show_pipeline_code_visibility,
inputs=[],
outputs=[pipeline_code_ui],
)
gr.on(
triggers=[clear_btn_part.click, clear_btn_full.click],
fn=lambda _: ("", "", 1, _get_dataframe()),
inputs=[dataframe],
outputs=[dataset_description, system_prompt, num_turns, dataframe],
)
app.load(fn=get_org_dropdown, outputs=[org_name])
app.load(fn=swap_visibility, outputs=main_ui)
app.load(fn=get_org_dropdown, outputs=[org_name])
8 changes: 6 additions & 2 deletions src/synthetic_dataset_generator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,14 @@
raise ValueError(
f"MAGPIE_PRE_QUERY_TEMPLATE must be either {llama_options} or {qwen_options}."
)
elif MODEL.lower() in llama_options:
elif MODEL.lower() in llama_options or any(
option in MODEL.lower() for option in llama_options
):
SFT_AVAILABLE = True
MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif MODEL.lower() in qwen_options:
elif MODEL.lower() in qwen_options or any(
option in MODEL.lower() for option in qwen_options
):
SFT_AVAILABLE = True
MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
else:
Expand Down

0 comments on commit 85b97c4

Please sign in to comment.