Skip to content

Commit

Permalink
feat(weave): run it n times for o1 models (#3286)
Browse files Browse the repository at this point in the history
* run it n times for o1 models

* pr comments
  • Loading branch information
jwlee64 authored Dec 18, 2024
1 parent 5fd95bd commit 36c17ab
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,54 @@ def lite_llm_completion(

# This allows us to drop params that are not supported by the LLM provider
litellm.drop_params = True

if supports_n_times(inputs.model) or inputs.n == 1:
try:
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

# o1 models with n > 1
results = []
try:
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
# get n results
for i in range(inputs.n or 1):
results.append(
litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
)
except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

final_result = results[0]
for idx, result in enumerate(results):
if idx != 0:
# append choices
final_result.choices.append(result.choices[0])

# sum usage
final_result.usage = sum_dict_leaves(
[result.usage.model_dump() for result in results]
)

return tsi.CompletionsCreateRes(response=final_result.model_dump())


def get_bedrock_credentials(
model_name: str,
Expand Down Expand Up @@ -88,3 +122,25 @@ def get_bedrock_credentials(
)

return aws_access_key_id, aws_secret_access_key, aws_region_name


# Substrings that identify models for which n > 1 is not supported in a
# single completion request.
NO_N_TIMES_MODEL_NAMES = ("o1-mini", "o1-preview", "o1")


def supports_n_times(model_name: str) -> bool:
    """Return whether ``model_name`` can be asked for n > 1 completions at once.

    The check is a plain substring match: if the model name contains any of
    the strings in ``NO_N_TIMES_MODEL_NAMES``, we don't support n > 1 and
    ``False`` is returned; otherwise ``True``.

    Args:
        model_name: The model identifier to check.

    Returns:
        ``True`` when none of the listed substrings occur in ``model_name``.
    """
    return not any(x in model_name for x in NO_N_TIMES_MODEL_NAMES)


# copied from weave/trace/weave_client.py
def sum_dict_leaves(dicts: list[dict]) -> dict:
    """Sum the matching leaves of several, possibly nested, dictionaries.

    Each input dictionary is walked key by key. Values that are themselves
    dictionaries are merged recursively; every other non-``None`` value is
    added to the running total stored under the same key. ``None`` values
    are skipped, so a key whose value is ``None`` in every input does not
    appear in the result.

    Args:
        dicts: Dictionaries to combine. They may or may not contain nested
            dictionaries.

    Returns:
        A new dictionary holding the per-key sums.
    """
    result: dict = {}
    for d in dicts:
        for k, v in d.items():
            if isinstance(v, dict):
                # Fold the nested dict into whatever has been accumulated
                # for this key so far (an empty dict on first sight).
                result[k] = sum_dict_leaves([result.get(k, {}), v])
            elif v is not None:
                # Leaves start from 0 the first time a key is seen.
                result[k] = result.get(k, 0) + v
    return result

0 comments on commit 36c17ab

Please sign in to comment.