Skip to content

Commit

Permalink
Merge pull request #38 from epinzur/fix_bug3
Browse files Browse the repository at this point in the history
fix query bug
  • Loading branch information
epinzur authored Jun 27, 2024
2 parents 6c9c092 + 2a8229b commit eca79d3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ragulate/pipelines/query_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,23 +151,23 @@ def update_progress(self, query_change: int = 0):
self._finished_feedbacks = done

def get_provider(self) -> LLMProvider:
provider_name = self.provider_name.lower()
llm_provider = self.llm_provider.lower()
model_name = self.model_name

if provider_name == "openai":
if llm_provider == "openai":
return OpenAI(model_engine=model_name)
elif provider_name == "azureopenai":
elif llm_provider == "azureopenai":
return AzureOpenAI(deployment_name=model_name)
elif provider_name == "bedrock":
elif llm_provider == "bedrock":
return Bedrock(model_id=model_name)
elif provider_name == "litellm":
elif llm_provider == "litellm":
return LiteLLM(model_engine=model_name)
elif provider_name == "Langchain":
elif llm_provider == "Langchain":
return Langchain(model_engine=model_name)
elif provider_name == "huggingface":
elif llm_provider == "huggingface":
return Huggingface(name=model_name)
else:
raise ValueError(f"Unsupported provider: {provider_name}")
raise ValueError(f"Unsupported provider: {llm_provider}")

def query(self):
query_method = self.get_method()
Expand Down

0 comments on commit eca79d3

Please sign in to comment.