From 21047c829babf9953d29fde2320bf2e9c587acd8 Mon Sep 17 00:00:00 2001
From: Eric Pinzur
Date: Thu, 27 Jun 2024 14:04:35 +0200
Subject: [PATCH] shorter param names

---
 ragulate/cli_commands/query.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/ragulate/cli_commands/query.py b/ragulate/cli_commands/query.py
index 1c262bd..88f7759 100644
--- a/ragulate/cli_commands/query.py
+++ b/ragulate/cli_commands/query.py
@@ -17,14 +17,14 @@ def setup_query(subparsers):
     )
     query_parser.add_argument(
         "-s",
-        "--script_path",
+        "--script",
         type=str,
         help="The path to the python script that contains the query method",
         required=True,
     )
     query_parser.add_argument(
         "-m",
-        "--method-name",
+        "--method",
         type=str,
         help="The name of the method in the script to run query",
         required=True,
@@ -91,7 +91,7 @@ def setup_query(subparsers):
         action="store_true",
     )
     query_parser.add_argument(
-        "--llm_provider",
+        "--provider",
         type=str,
         help=("The name of the LLM Provider to use for Evaluation."),
         choices=[
@@ -102,18 +102,22 @@ def setup_query(subparsers):
             "Langchain",
             "Huggingface",
         ],
+        default="OpenAI",
     )
     query_parser.add_argument(
-        "--model_name",
+        "--model",
         type=str,
-        help=("The name or id of the LLM model or deployment to use for Evaluation."),
+        help=(
+            "The name or id of the LLM model or deployment to use for Evaluation. "
+            "Generally used in combination with the --provider param."
+        ),
     )
     query_parser.set_defaults(func=lambda args: call_query(**vars(args)))

 def call_query(
     name: str,
-    script_path: str,
-    method_name: str,
+    script: str,
+    method: str,
     var_name: List[str],
     var_value: List[str],
     dataset: List[str],
@@ -121,8 +125,8 @@ def call_query(
     sample: float,
     seed: int,
     restart: bool,
-    llm_provider: str,
-    model_name: str,
+    provider: str,
+    model: str,
     **kwargs,
 ):
     if sample <= 0.0 or sample > 1.0:
@@ -144,14 +148,14 @@ def call_query(

     query_pipeline = QueryPipeline(
         recipe_name=name,
-        script_path=script_path,
-        method_name=method_name,
+        script_path=script,
+        method_name=method,
         ingredients=ingredients,
         datasets=datasets,
         sample_percent=sample,
         random_seed=seed,
         restart_pipeline=restart,
-        llm_provider=llm_provider,
-        model_name=model_name,
+        llm_provider=provider,
+        model_name=model,
     )
     query_pipeline.query()
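
Note on the rename (editorial, not part of the patch): argparse derives each parsed attribute name from the long option, so shortening "--script_path" to "--script" also changes the attribute from args.script_path to args.script. That is why call_query()'s parameters are renamed in the same commit, keeping the **vars(args) forwarding in set_defaults working. A minimal standalone sketch of that behavior follows; the prog name and argument values are illustrative, not the real ragulate parser.

import argparse

# Standalone sketch: argparse turns each long option into a same-named
# attribute (dest), so the option names and call_query()'s parameter names
# must stay in sync for **vars(args) forwarding to work.
parser = argparse.ArgumentParser(prog="ragulate query")
parser.add_argument("-s", "--script", type=str, required=True)
parser.add_argument("-m", "--method", type=str, required=True)
parser.add_argument("--provider", type=str, default="OpenAI")
parser.add_argument("--model", type=str)

# Illustrative invocation; the script, method, and model values are made up.
args = parser.parse_args(
    ["-s", "my_pipeline.py", "-m", "query", "--model", "gpt-4o"]
)

print(args.script, args.method, args.provider, args.model)
# -> my_pipeline.py query OpenAI gpt-4o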