
refactor wrapper functions
luc committed Mar 23, 2024
1 parent cbf54d5 commit e0a5958
Showing 3 changed files with 30 additions and 46 deletions.
genai_impact/tracers/anthropic_tracer.py (24 changes: 9 additions & 15 deletions)
@@ -19,10 +19,7 @@ class Message(_Message):
     impacts: Impacts
 
 
-def anthropic_chat_wrapper(
-    wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any  # noqa: ARG001
-) -> Message:
-    response = wrapped(*args, **kwargs)
+def compute_impacts_and_return_response(response: Any) -> Message:
     model = models.find_model(provider="anthropic", model_name=response.model)
     if model is None:
         # TODO: Replace with proper logging
@@ -36,21 +33,18 @@ def anthropic_chat_wrapper(
     return Message(**response.model_dump(), impacts=impacts)
 
 
+def anthropic_chat_wrapper(
+    wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any  # noqa: ARG001
+) -> Message:
+    response = wrapped(*args, **kwargs)
+    return compute_impacts_and_return_response(response)
+
+
 async def anthropic_async_chat_wrapper(
     wrapped: Callable, instance: _AsyncAnthropic, args: Any, kwargs: Any  # noqa: ARG001
 ) -> Message:
     response = await wrapped(*args, **kwargs)
-    model = models.find_model(provider="anthropic", model_name=response.model)
-    if model is None:
-        # TODO: Replace with proper logging
-        print(f"Could not find model `{response.model}` for anthropic provider.")
-        return response
-    output_tokens = response.usage.output_tokens
-    model_size = model.active_parameters or model.active_parameters_range
-    impacts = compute_llm_impact(
-        model_parameter_count=model_size, output_token_count=output_tokens
-    )
-    return Message(**response.model_dump(), impacts=impacts)
+    return compute_impacts_and_return_response(response)
 
 
 class AnthropicInstrumentor:
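
Aside: the (wrapped, instance, args, kwargs) signature used by all of these wrappers follows the wrapt convention, which suggests the (elided) AnthropicInstrumentor registers them with wrapt.wrap_function_wrapper. A minimal sketch of that wiring, assuming wrapt, with the patched module and method paths as illustrative guesses rather than the repository's actual values:

import wrapt


class AnthropicInstrumentor:
    def instrument(self) -> None:
        # Route the sync client through anthropic_chat_wrapper, which now
        # delegates impact computation to compute_impacts_and_return_response.
        wrapt.wrap_function_wrapper(
            "anthropic.resources.messages",  # assumed patch target
            "Messages.create",
            anthropic_chat_wrapper,
        )
        # Patch the async client the same way.
        wrapt.wrap_function_wrapper(
            "anthropic.resources.messages",
            "AsyncMessages.create",
            anthropic_async_chat_wrapper,
        )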
genai_impact/tracers/mistralai_tracer.py (29 changes: 13 additions & 16 deletions)
@@ -21,10 +21,7 @@ class ChatCompletionResponse(_ChatCompletionResponse):
     impacts: Impacts
 
 
-def mistralai_chat_wrapper(
-    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
-) -> ChatCompletionResponse:
-    response = wrapped(*args, **kwargs)
+def compute_impacts_and_return_response(response: Any) -> ChatCompletionResponse:
     model = models.find_model(provider="mistralai", model_name=response.model)
     if model is None:
         # TODO: Replace with proper logging
@@ -38,21 +35,21 @@ def mistralai_chat_wrapper(
     return ChatCompletionResponse(**response.model_dump(), impacts=impacts)
 
 
+def mistralai_chat_wrapper(
+    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
+) -> ChatCompletionResponse:
+    response = wrapped(*args, **kwargs)
+    return compute_impacts_and_return_response(response)
+
+
 async def mistralai_async_chat_wrapper(
-    wrapped: Callable, instance: _MistralAsyncClient, args: Any, kwargs: Any  # noqa: ARG001
+    wrapped: Callable,
+    instance: _MistralAsyncClient,
+    args: Any,
+    kwargs: Any,  # noqa: ARG001
 ) -> ChatCompletionResponse:
     response = await wrapped(*args, **kwargs)
-    model = models.find_model(provider="mistralai", model_name=response.model)
-    if model is None:
-        # TODO: Replace with proper logging
-        print(f"Could not find model `{response.model}` for mistralai provider.")
-        return response
-    output_tokens = response.usage.completion_tokens
-    model_size = model.active_parameters or model.active_parameters_range
-    impacts = compute_llm_impact(
-        model_parameter_count=model_size, output_token_count=output_tokens
-    )
-    return ChatCompletionResponse(**response.model_dump(), impacts=impacts)
+    return compute_impacts_and_return_response(response)
 
 
 class MistralAIInstrumentor:
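
The three compute_impacts_and_return_response helpers are identical apart from the provider name, the response class, and the token-usage field (output_tokens for Anthropic, completion_tokens for Mistral and OpenAI). A possible next step, sketched here with illustrative names that are not part of this commit, would hoist them into one shared helper; the sketch assumes the models.find_model and compute_llm_impact helpers each tracer already imports:

from typing import Any, Callable, Type, TypeVar

R = TypeVar("R")


def compute_impacts_and_wrap(
    response: Any,
    provider: str,
    response_cls: Type[R],
    get_output_tokens: Callable[[Any], int],
) -> Any:
    # Same lookup-and-enrich logic as the per-provider helpers, with the
    # provider-specific pieces passed in as parameters.
    model = models.find_model(provider=provider, model_name=response.model)
    if model is None:
        # TODO: Replace with proper logging
        print(f"Could not find model `{response.model}` for {provider} provider.")
        return response
    model_size = model.active_parameters or model.active_parameters_range
    impacts = compute_llm_impact(
        model_parameter_count=model_size,
        output_token_count=get_output_tokens(response),
    )
    return response_cls(**response.model_dump(), impacts=impacts)

The mistralai helper would then reduce to a one-liner:

    return compute_impacts_and_wrap(
        response, "mistralai", ChatCompletionResponse,
        lambda r: r.usage.completion_tokens,
    )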
genai_impact/tracers/openai_tracer.py (23 changes: 8 additions & 15 deletions)
@@ -13,10 +13,7 @@ class ChatCompletion(_ChatCompletion):
     impacts: Impacts
 
 
-def openai_chat_wrapper(
-    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
-) -> ChatCompletion:
-    response = wrapped(*args, **kwargs)
+def compute_impacts_and_return_response(response: Any) -> ChatCompletion:
     model = models.find_model(provider="openai", model_name=response.model)
     if model is None:
         # TODO: Replace with proper logging
@@ -29,6 +26,11 @@ def openai_chat_wrapper(
     )
     return ChatCompletion(**response.model_dump(), impacts=impacts)
 
+def openai_chat_wrapper(
+    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
+) -> ChatCompletion:
+    response = wrapped(*args, **kwargs)
+    return compute_impacts_and_return_response(response)
 
 async def openai_async_chat_wrapper(
     wrapped: Callable,
@@ -37,17 +39,8 @@ async def openai_async_chat_wrapper(
     kwargs: Any,  # noqa: ARG001
 ) -> ChatCompletion:
     response = await wrapped(*args, **kwargs)
-    model = models.find_model(provider="openai", model_name=response.model)
-    if model is None:
-        # TODO: Replace with proper logging
-        print(f"Could not find model `{response.model}` for openai provider.")
-        return response
-    output_tokens = response.usage.completion_tokens
-    model_size = model.active_parameters or model.active_parameters_range
-    impacts = compute_llm_impact(
-        model_parameter_count=model_size, output_token_count=output_tokens
-    )
-    return ChatCompletion(**response.model_dump(), impacts=impacts)
+    return compute_impacts_and_return_response(response)
+
 
 
 class OpenAIInstrumentor:
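
For context on what these wrappers buy the caller: once the instrumentor has patched the client, chat responses come back as the enriched ChatCompletion subclass carrying an impacts field. A hedged usage sketch; the public entry point and the instrument() method name are assumptions, not taken from this diff:

from openai import OpenAI

from genai_impact.tracers.openai_tracer import OpenAIInstrumentor

OpenAIInstrumentor().instrument()  # assumed method name

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)
# The patched create() now runs openai_chat_wrapper, so the response is the
# ChatCompletion subclass defined above and exposes the computed impacts.
print(response.impacts)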
