Add Budget Manager + Support for Anthropic, Cohere, Palm (100+ LLMs using LiteLLM) #99
base: main
Changes from all commits
```diff
@@ -2,6 +2,8 @@
 from dataclasses import asdict
 from typing import List, Dict, Callable
 import openai
+import litellm
+import uuid
 import re
 import tiktoken
 import time
```
```diff
@@ -26,6 +28,9 @@
 ENCODER = tiktoken.get_encoding("cl100k_base")

+# initialize a budget manager to control costs for gpt-4/other llms
+budget_manager = litellm.BudgetManager(project_name="stampy_chat")
+
 DEBUG_PRINT = True

 def set_debug_print(val: bool):
```
```diff
@@ -142,7 +147,7 @@ def construct_prompt(query: str, mode: str, history: List[Dict[str, str]], conte
 import time
 import json

-def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print):
+def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print, session_id: str = ""):
     try:
         # 1. Find the most relevant blocks from the Alignment Research Dataset
         yield {"state": "loading", "phase": "semantic"}
```
```diff
@@ -181,14 +186,19 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str,
         t1 = time.time()
         response = ''

-        for chunk in openai.ChatCompletion.create(
+        # check if budget exceeded for session
+        if budget_manager.get_current_cost(user=session_id) >= budget_manager.get_total_budget(session_id):
+            raise Exception(f"Exceeded the maximum budget for this session")
+
+        for chunk in litellm.completion(
             model=COMPLETIONS_MODEL,
             messages=prompt,
             max_tokens=max_tokens_completion,
             stream=True,
             temperature=0, # may or may not be a good idea
         ):
             res = chunk["choices"][0]["delta"]
+            budget_manager.update_cost(completion_obj=response, user=session_id)
             if res is not None and res.get("content") is not None:
                 response += res["content"]
                 yield {"state": "streaming", "content": res["content"]}
```

Review comment (on the budget check): is this for the number of allowed tokens or chat calls? Is it a hard total, or does it get reset every now and then? The code is run on gunicorn workers - how will that influence it, as I'm guessing litellm won't communicate across processes?
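The guard-then-stream pattern in the hunk above can be exercised without any network calls. This is a sketch under assumptions: the fake chunks imitate the OpenAI-style streaming delta format the diff indexes into, and the cost figures are arbitrary:

```python
def stream_with_budget(chunks, current_cost: float, total_budget: float):
    """Yield streamed content deltas, refusing to start if the session
    has already spent its budget (sketch of the PR's guard logic)."""
    if current_cost >= total_budget:
        raise Exception("Exceeded the maximum budget for this session")
    response = ""
    for chunk in chunks:
        # each chunk carries an incremental delta, as in streaming chat completions
        res = chunk["choices"][0]["delta"]
        if res is not None and res.get("content") is not None:
            response += res["content"]
            yield {"state": "streaming", "content": res["content"]}

# fake chunks standing in for a streamed completion
fake_chunks = [
    {"choices": [{"delta": {"content": "Hello"}}]},
    {"choices": [{"delta": {"content": " world"}}]},
]
events = list(stream_with_budget(fake_chunks, current_cost=0.0, total_budget=10.0))
```

Because this is a generator, the budget check only fires once iteration begins, which mirrors how the guard sits inside the generator function in the diff.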
```diff
@@ -225,13 +235,16 @@ def talk_to_robot_internal(index, query: str, mode: str, history: List[Dict[str,
 # convert talk_to_robot_internal from dict generator into json generator
 def talk_to_robot(index, query: str, mode: str, history: List[Dict[str, str]], k: int = STANDARD_K, log: Callable = print):
-    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log))
+    session_id = str(uuid.uuid4())
+    budget_manager.create_budget(total_budget=10, user=session_id) # init $10 budget
+    yield from (json.dumps(block) for block in talk_to_robot_internal(index, query, mode, history, k, log, session_id=session_id))

 # wayyy simplified api
 def talk_to_robot_simple(index, query: str, log: Callable = print):
     res = {'response': ''}

-    for block in talk_to_robot_internal(index, query, "default", [], log = log):
+    session_id = str(uuid.uuid4())
+    budget_manager.create_budget(total_budget=10, user=session_id) # init $10 budget
+    for block in talk_to_robot_internal(index, query, "default", [], log = log, session_id=session_id):
         if block['state'] == 'loading' and block['phase'] == 'semantic' and 'citations' in block:
             citations = {}
             for i, c in enumerate(block['citations']):
```

Review comment (on the generated `session_id`): If I understand this, then budget manager has an internal dict to count how much a given session has used? But if you're creating a new id with each call to this function, then each session will have max 1 call? I'm planning on adding session ids, as it will be needed for logging anyway, so could you do this by extracting the session id from the request params in the

Review comment (on the hardcoded `total_budget=10`): don't hardcode it - create a setting in
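The reviewer's point about per-call ids could be addressed by reusing a caller-supplied session id so the budget persists across requests. This is a hypothetical sketch: the `resolve_session_id` helper and the `"sessionId"` parameter name are assumptions, not anything in the PR:

```python
import uuid

def resolve_session_id(params: dict) -> str:
    """Return the caller-provided session id if present, otherwise mint one.

    Hypothetical helper illustrating the reviewer's suggestion; with a
    stable id, budget_manager would accumulate cost across calls instead
    of starting a fresh budget on every request.
    """
    session_id = params.get("sessionId")
    if not session_id:
        session_id = str(uuid.uuid4())
    return session_id

# repeated calls with the same params map to the same budget key
sid1 = resolve_session_id({"sessionId": "abc-123"})
sid2 = resolve_session_id({"sessionId": "abc-123"})
sid3 = resolve_session_id({})  # no id supplied: falls back to a fresh uuid4
```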
Review comment: how does the budget per user get configured? Could you add a new item to env.py so that it can be configured? Also, what would the units be? (I had a very quick glance at the litellm docs, but otherwise don't know anything about it)

Review comment: (I'll give it a proper look tomorrow)
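An env.py-style setting like the reviewers ask for could look like the sketch below. The `SESSION_BUDGET` variable name is an assumption (no such setting exists in the repo yet), and the units follow the PR's comment, i.e. dollars:

```python
import os

def get_session_budget(default: float = 10.0) -> float:
    """Read the per-session budget (in dollars) from the environment,
    falling back to the PR's hardcoded $10.

    Hypothetical setting: the SESSION_BUDGET variable name is an
    assumption for illustration.
    """
    raw = os.environ.get("SESSION_BUDGET")
    return float(raw) if raw else default

# example: configure a $25 budget via the environment
os.environ["SESSION_BUDGET"] = "25"
configured = get_session_budget()
```

The call sites would then become `budget_manager.create_budget(total_budget=get_session_budget(), user=session_id)` instead of the hardcoded `10`.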