diff --git a/ragulate/cli_commands/query.py b/ragulate/cli_commands/query.py
index 6853199..88f7759 100644
--- a/ragulate/cli_commands/query.py
+++ b/ragulate/cli_commands/query.py
@@ -17,14 +17,14 @@ def setup_query(subparsers):
     )
     query_parser.add_argument(
         "-s",
-        "--script_path",
+        "--script",
         type=str,
         help="The path to the python script that contains the query method",
         required=True,
     )
     query_parser.add_argument(
         "-m",
-        "--method-name",
+        "--method",
         type=str,
         help="The name of the method in the script to run query",
         required=True,
@@ -90,12 +90,34 @@ def setup_query(subparsers):
         ),
         action="store_true",
     )
+    query_parser.add_argument(
+        "--provider",
+        type=str,
+        help="The name of the LLM Provider to use for Evaluation.",
+        choices=[
+            "OpenAI",
+            "AzureOpenAI",
+            "Bedrock",
+            "LiteLLM",
+            "Langchain",
+            "Huggingface",
+        ],
+        default="OpenAI",
+    )
+    query_parser.add_argument(
+        "--model",
+        type=str,
+        help=(
+            "The name or id of the LLM model or deployment to use for Evaluation. "
+            "Generally used in combination with the --provider param."
+        ),
+    )
     query_parser.set_defaults(func=lambda args: call_query(**vars(args)))
 
 def call_query(
     name: str,
-    script_path: str,
-    method_name: str,
+    script: str,
+    method: str,
     var_name: List[str],
     var_value: List[str],
     dataset: List[str],
@@ -103,6 +125,8 @@ def call_query(
     sample: float,
     seed: int,
     restart: bool,
+    provider: str,
+    model: str,
     **kwargs,
 ):
     if sample <= 0.0 or sample > 1.0:
@@ -124,12 +148,14 @@ def call_query(
 
     query_pipeline = QueryPipeline(
         recipe_name=name,
-        script_path=script_path,
-        method_name=method_name,
+        script_path=script,
+        method_name=method,
         ingredients=ingredients,
         datasets=datasets,
         sample_percent=sample,
         random_seed=seed,
         restart_pipeline=restart,
+        llm_provider=provider,
+        model_name=model,
     )
 
     query_pipeline.query()
diff --git a/ragulate/pipelines/query_pipeline.py b/ragulate/pipelines/query_pipeline.py
index 432b6d7..9124839 100644
--- a/ragulate/pipelines/query_pipeline.py
+++ b/ragulate/pipelines/query_pipeline.py
@@ -5,7 +5,15 @@
 from tqdm import tqdm
 
 from trulens_eval import Tru, TruChain
-from trulens_eval.feedback.provider import OpenAI
+from trulens_eval.feedback.provider import (
+    AzureOpenAI,
+    Bedrock,
+    Huggingface,
+    Langchain,
+    LiteLLM,
+    OpenAI,
+)
+from trulens_eval.feedback.provider.base import LLMProvider
 from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus
 
 from ragulate.datasets import BaseDataset
@@ -47,7 +55,9 @@ def __init__(
         datasets: List[BaseDataset],
         sample_percent: float = 1.0,
         random_seed: Optional[int] = None,
-        restart_pipeline: bool = False,
+        restart_pipeline: Optional[bool] = False,
+        llm_provider: Optional[str] = "OpenAI",
+        model_name: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(
@@ -61,6 +71,8 @@ def __init__(
         self.sample_percent = sample_percent
         self.random_seed = random_seed
         self.restart_pipeline = restart_pipeline
+        self.llm_provider = llm_provider
+        self.model_name = model_name
 
         # Set up the signal handler for SIGINT (Ctrl-C)
         signal.signal(signal.SIGINT, self.signal_handler)
@@ -136,11 +148,30 @@ def update_progress(self, query_change: int = 0):
 
         self._finished_feedbacks = done
 
+    def get_provider(self) -> LLMProvider:
+        provider_name = self.llm_provider.lower()
+        model_name = self.model_name
+
+        if provider_name == "openai":
+            return OpenAI(model_engine=model_name)
+        elif provider_name == "azureopenai":
+            return AzureOpenAI(deployment_name=model_name)
+        elif provider_name == "bedrock":
+            return Bedrock(model_id=model_name)
+        elif provider_name == "litellm":
+            return LiteLLM(model_engine=model_name)
+        elif provider_name == "langchain":
+            return Langchain(model_engine=model_name)
+        elif provider_name == "huggingface":
+            return Huggingface(name=model_name)
+        else:
+            raise ValueError(f"Unsupported provider: {provider_name}")
+
     def query(self):
         query_method = self.get_method()
         pipeline = query_method(**self.ingredients)
 
-        llm_provider = OpenAI(model_engine="gpt-3.5-turbo")
+        llm_provider = self.get_provider()
 
         feedbacks = Feedbacks(llm_provider=llm_provider, pipeline=pipeline)
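Example invocation exercising the new --provider and --model flags (a sketch only: the recipe, script, method, and deployment names here are hypothetical, and the -n/--name and --dataset flags are assumed from the existing CLI):

    ragulate query -n example_recipe -s experiment.py -m query_pipeline \
        --dataset example_dataset --provider AzureOpenAI --model my-gpt4-deployment

When --provider is omitted it defaults to OpenAI, matching the previous hard-coded behavior; --model is optional and is passed through to the selected trulens_eval provider constructor.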