diff --git a/weave/trace_server/llm_completion.py b/weave/trace_server/llm_completion.py index bec4361c856..e8ac86ffb33 100644 --- a/weave/trace_server/llm_completion.py +++ b/weave/trace_server/llm_completion.py @@ -32,20 +32,54 @@ def lite_llm_completion( # This allows us to drop params that are not supported by the LLM provider litellm.drop_params = True + + if supports_n_times(inputs.model) or inputs.n == 1: + try: + res = litellm.completion( + **inputs.model_dump(exclude_none=True), + api_key=api_key, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + ) + return tsi.CompletionsCreateRes(response=res.model_dump()) + except Exception as e: + error_message = str(e) + error_message = error_message.replace("litellm.", "") + return tsi.CompletionsCreateRes(response={"error": error_message}) + + # o1 models with n > 1 + results = [] try: - res = litellm.completion( - **inputs.model_dump(exclude_none=True), - api_key=api_key, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, - ) - return tsi.CompletionsCreateRes(response=res.model_dump()) + # get n results + for i in range(inputs.n or 1): + results.append( + litellm.completion( + **inputs.model_dump(exclude_none=True), + api_key=api_key, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + ) + ) except Exception as e: error_message = str(e) error_message = error_message.replace("litellm.", "") return tsi.CompletionsCreateRes(response={"error": error_message}) + final_result = results[0] + for idx, result in enumerate(results): + if idx != 0: + # append choices + final_result.choices.append(result.choices[0]) + + # sum usage + final_result.usage = sum_dict_leaves( + [result.usage.model_dump() for result in results] + ) + + return tsi.CompletionsCreateRes(response=final_result.model_dump()) + def 
# If the model name contains any of these strings, we don't support n > 1
# for it in a single request.
NO_N_TIMES_MODEL_NAMES = ("o1-mini", "o1-preview", "o1")


def supports_n_times(model_name: str) -> bool:
    """Return True when `model_name` may be called with n > 1.

    The check is a plain, case-sensitive substring test against every entry
    of NO_N_TIMES_MODEL_NAMES, so provider-prefixed names such as
    "openai/o1-mini" are matched as well.

    Args:
        model_name: The model identifier to check.

    Returns:
        False if any entry of NO_N_TIMES_MODEL_NAMES occurs in `model_name`,
        True otherwise.
    """
    return not any(x in model_name for x in NO_N_TIMES_MODEL_NAMES)


# adapted from weave/trace/weave_client.py
def sum_dict_leaves(dicts: list[dict]) -> dict:
    """Merge dictionaries by summing the numeric leaves found under equal keys.

    Nested dictionaries are merged recursively. The inputs are never mutated;
    a new dictionary is returned.

    Rules applied per key, in input order:
      * `None` values (and `None`/empty entries in `dicts`) are skipped.
      * Numeric leaves are added together.
      * Nested dicts are merged with the same rules.
      * Leaves that cannot be summed (strings, lists, ...) keep the first
        value seen instead of raising.
      * If a key is a dict in one input and a scalar in another, the kind
        seen first wins and the conflicting value is ignored.

    Args:
        dicts: Dictionaries to combine; they may contain nested dictionaries.

    Returns:
        A new dictionary holding the combined leaves.
    """
    # Local import: only this helper needs it.
    import numbers

    result: dict = {}
    for d in dicts:
        if not d:
            # Nothing to contribute (None or empty dict).
            continue
        for k, v in d.items():
            if v is None:
                continue
            if isinstance(v, dict):
                existing = result.get(k, {})
                # Only merge dict with dict; a scalar already stored under
                # this key is kept rather than crashing on `.items()`.
                if isinstance(existing, dict):
                    result[k] = sum_dict_leaves([existing, v])
            elif k not in result:
                # First value for this key: numbers seed the running total,
                # anything else is stored as-is.
                result[k] = 0 + v if isinstance(v, numbers.Number) else v
            elif isinstance(v, numbers.Number) and isinstance(
                result[k], numbers.Number
            ):
                result[k] = result[k] + v
            # Otherwise the values are not summable together; keep the first.
    return result