Skip to content

Commit

Permalink
remove hacky n completions (#3312)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwlee64 authored Jan 3, 2025
1 parent 0b8eaa3 commit 1c60163
Showing 1 changed file with 8 additions and 63 deletions.
71 changes: 8 additions & 63 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,53 +33,20 @@ def lite_llm_completion(
# This allows us to drop params that are not supported by the LLM provider
litellm.drop_params = True

if supports_n_times(inputs.model) or inputs.n == 1:
try:
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

# o1 models with n > 1
results = []
try:
# get n results
for i in range(inputs.n or 1):
results.append(
litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
)
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

final_result = results[0]
for idx, result in enumerate(results):
if idx != 0:
# append choices
final_result.choices.append(result.choices[0])

# sum usage
final_result.usage = sum_dict_leaves(
[result.usage.model_dump() for result in results]
)

return tsi.CompletionsCreateRes(response=final_result.model_dump())


def get_bedrock_credentials(
model_name: str,
Expand Down Expand Up @@ -122,25 +89,3 @@ def get_bedrock_credentials(
)

return aws_access_key_id, aws_secret_access_key, aws_region_name


NO_N_TIMES_MODEL_NAMES = ("o1-mini", "o1-preview", "o1")


# if the model name contains any of these strings, we don't support n > 1
def supports_n_times(model_name: str) -> bool:
return not any(x in model_name for x in NO_N_TIMES_MODEL_NAMES)


# copied from weave/trace/weave_client.py
def sum_dict_leaves(dicts: list[dict]) -> dict:
# dicts is a list of dictionaries, that may or may not
# have nested dictionaries. Sum all the leaves that match
result: dict = {}
for d in dicts:
for k, v in d.items():
if isinstance(v, dict):
result[k] = sum_dict_leaves([result.get(k, {}), v])
elif v is not None:
result[k] = result.get(k, 0) + v
return result

0 comments on commit 1c60163

Please sign in to comment.