From 2a8229b8af02ae48f6f87be0d65fc22bcfdb7555 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Thu, 27 Jun 2024 15:13:25 +0200 Subject: [PATCH] fix query bug --- ragulate/pipelines/query_pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ragulate/pipelines/query_pipeline.py b/ragulate/pipelines/query_pipeline.py index 4476fbf..f290f7a 100644 --- a/ragulate/pipelines/query_pipeline.py +++ b/ragulate/pipelines/query_pipeline.py @@ -151,23 +151,23 @@ def update_progress(self, query_change: int = 0): self._finished_feedbacks = done def get_provider(self) -> LLMProvider: - provider_name = self.provider_name.lower() + llm_provider = self.llm_provider.lower() model_name = self.model_name - if provider_name == "openai": + if llm_provider == "openai": return OpenAI(model_engine=model_name) - elif provider_name == "azureopenai": + elif llm_provider == "azureopenai": return AzureOpenAI(deployment_name=model_name) - elif provider_name == "bedrock": + elif llm_provider == "bedrock": return Bedrock(model_id=model_name) - elif provider_name == "litellm": + elif llm_provider == "litellm": return LiteLLM(model_engine=model_name) - elif provider_name == "Langchain": + elif llm_provider == "Langchain": return Langchain(model_engine=model_name) - elif provider_name == "huggingface": + elif llm_provider == "huggingface": return Huggingface(name=model_name) else: - raise ValueError(f"Unsupported provider: {provider_name}") + raise ValueError(f"Unsupported provider: {llm_provider}") def query(self): query_method = self.get_method()