From 85b97c47dad6346c8e3a11ad52c1fde7894feee4 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 17 Dec 2024 07:59:34 +0100 Subject: [PATCH] fix model validation when using ollama --- src/synthetic_dataset_generator/apps/chat.py | 141 +++++++++---------- src/synthetic_dataset_generator/constants.py | 8 +- 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/src/synthetic_dataset_generator/apps/chat.py b/src/synthetic_dataset_generator/apps/chat.py index b2d0c8a..280e607 100644 --- a/src/synthetic_dataset_generator/apps/chat.py +++ b/src/synthetic_dataset_generator/apps/chat.py @@ -527,77 +527,76 @@ def hide_pipeline_code_visibility(): label="Distilabel Pipeline Code", ) - load_btn.click( - fn=generate_system_prompt, - inputs=[dataset_description], - outputs=[system_prompt], - show_progress=True, - ).then( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) - - btn_apply_to_sample_dataset.click( - fn=generate_sample_dataset, - inputs=[system_prompt, num_turns], - outputs=[dataframe], - show_progress=True, - ) + load_btn.click( + fn=generate_system_prompt, + inputs=[dataset_description], + outputs=[system_prompt], + show_progress=True, + ).then( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) - btn_push_to_hub.click( - fn=validate_argilla_user_workspace_dataset, - inputs=[repo_name], - outputs=[success_message], - show_progress=True, - ).then( - fn=validate_push_to_hub, - inputs=[org_name, repo_name], - outputs=[success_message], - show_progress=True, - ).success( - fn=hide_success_message, - outputs=[success_message], - show_progress=True, - ).success( - fn=hide_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - show_progress=True, - ).success( - fn=push_dataset, - inputs=[ - org_name, - repo_name, - system_prompt, - num_turns, - num_rows, - private, - temperature, - pipeline_code, - ], - outputs=[success_message], - show_progress=True, - ).success( - fn=show_success_message, - inputs=[org_name, repo_name], - outputs=[success_message], - ).success( - fn=generate_pipeline_code, - inputs=[system_prompt, num_turns, num_rows, temperature], - outputs=[pipeline_code], - ).success( - fn=show_pipeline_code_visibility, - inputs=[], - outputs=[pipeline_code_ui], - ) - gr.on( - triggers=[clear_btn_part.click, clear_btn_full.click], - fn=lambda _: ("", "", 1, _get_dataframe()), - inputs=[dataframe], - outputs=[dataset_description, system_prompt, num_turns, dataframe], - ) + btn_apply_to_sample_dataset.click( + fn=generate_sample_dataset, + inputs=[system_prompt, num_turns], + outputs=[dataframe], + show_progress=True, + ) + btn_push_to_hub.click( + fn=validate_argilla_user_workspace_dataset, + inputs=[repo_name], + outputs=[success_message], + show_progress=True, + ).then( + fn=validate_push_to_hub, + inputs=[org_name, repo_name], + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_success_message, + outputs=[success_message], + show_progress=True, + ).success( + fn=hide_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + show_progress=True, + ).success( + fn=push_dataset, + inputs=[ + org_name, + repo_name, + system_prompt, + num_turns, + num_rows, + private, + temperature, + pipeline_code, + ], + outputs=[success_message], + show_progress=True, + ).success( + fn=show_success_message, + inputs=[org_name, repo_name], + outputs=[success_message], + ).success( + fn=generate_pipeline_code, + inputs=[system_prompt, num_turns, num_rows, temperature], + outputs=[pipeline_code], + ).success( + fn=show_pipeline_code_visibility, + inputs=[], + outputs=[pipeline_code_ui], + ) + gr.on( + triggers=[clear_btn_part.click, clear_btn_full.click], + fn=lambda _: ("", "", 1, _get_dataframe()), + inputs=[dataframe], + outputs=[dataset_description, system_prompt, num_turns, dataframe], + ) + app.load(fn=get_org_dropdown, outputs=[org_name]) app.load(fn=swap_visibility, outputs=main_ui) - app.load(fn=get_org_dropdown, outputs=[org_name]) diff --git a/src/synthetic_dataset_generator/constants.py b/src/synthetic_dataset_generator/constants.py index 22a9342..0081d9e 100644 --- a/src/synthetic_dataset_generator/constants.py +++ b/src/synthetic_dataset_generator/constants.py @@ -46,10 +46,14 @@ raise ValueError( f"MAGPIE_PRE_QUERY_TEMPLATE must be either {llama_options} or {qwen_options}." ) -elif MODEL.lower() in llama_options: +elif MODEL.lower() in llama_options or any( + option in MODEL.lower() for option in llama_options +): SFT_AVAILABLE = True MAGPIE_PRE_QUERY_TEMPLATE = "llama3" -elif MODEL.lower() in qwen_options: +elif MODEL.lower() in qwen_options or any( + option in MODEL.lower() for option in qwen_options +): SFT_AVAILABLE = True MAGPIE_PRE_QUERY_TEMPLATE = "qwen2" else: