Skip to content

Commit

Permalink
use liteLLM in the generator
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Apr 12, 2024
1 parent a6596a3 commit 4716e19
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"jinja2",
"openai",
"litellm",
]

[project.urls]
Expand Down
26 changes: 10 additions & 16 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import openai
from typing import cast

from litellm import completion, ModelResponse
from jinja2 import nodes
from jinja2.ext import Extension

CHAT_MODELS = [
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
]
DEFAULT_MODEL = "gpt-3.5-turbo"


Expand Down Expand Up @@ -45,13 +42,10 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
To tweak the prompt used to generate content, change the variable `messages` .
"""
content = openai.ChatCompletion.create(
model=model_name,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": text},
],
temperature=0.5,
)["choices"][0]["message"]["content"]

return content
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": text},
]

response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
return response["choices"][0]["message"]["content"]
1 change: 0 additions & 1 deletion tests/test_run_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ def test_run_prompt_process():
p = Prompt.from_template("run_prompt_process.jinja")
env.extensions["banks.extensions.generate.GenerateExtension"]._generate = mock.MagicMock(return_value="foo bar baz")

# print(p.text({"topic": "climate change"}))
assert p.text({"topic": "climate change"}) == "FOO BAR BAZ"

0 comments on commit 4716e19

Please sign in to comment.