Skip to content

Commit

Permalink
Merge pull request #21 from dataforgoodfr/feat/async-tracers
Browse files Browse the repository at this point in the history
Feat/async tracers
  • Loading branch information
samuelrince authored Mar 24, 2024
2 parents 7b42d85 + b632c19 commit 6e54b80
Show file tree
Hide file tree
Showing 14 changed files with 417 additions and 123 deletions.
3 changes: 2 additions & 1 deletion genai_impact/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
class GenAIImpactError(Exception):
    """Base exception for all genai_impact errors."""
    pass


class TracerInitializationError(GenAIImpactError):
    """Raised when the tracer is initialized more than once."""
2 changes: 2 additions & 0 deletions genai_impact/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

class Tracer:
    """Entry point that instruments the supported GenAI client libraries.

    Holds a class-level flag so instrumentation is applied at most once per
    process; a second ``init()`` raises ``TracerInitializationError``.
    """

    # True once init() has run; prevents wrapping client methods twice.
    initialized = False

    @staticmethod
    def init() -> None:
        """Instrument all supported clients; raise if already initialized."""
        if Tracer.initialized:
            raise TracerInitializationError()
        init_instruments()
        Tracer.initialized = True


def init_instruments() -> None:
init_openai_instrumentor()
init_anthropic_instrumentor()
Expand Down
33 changes: 24 additions & 9 deletions genai_impact/tracers/anthropic_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@

try:
from anthropic import Anthropic as _Anthropic
from anthropic import AsyncAnthropic as _AsyncAnthropic
from anthropic.types import Message as _Message
except ImportError:
_Anthropic = object()
_AsyncAnthropic = object()
_Message = object()


class Message(_Message):
    """Anthropic ``Message`` response extended with computed impact metrics."""

    # Impact metrics attached by the tracer after the API call returns.
    impacts: Impacts


def anthropic_chat_wrapper(
wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any # noqa: ARG001
) -> Message:
response = wrapped(*args, **kwargs)
def compute_impacts_and_return_response(response: Any) -> Message:
model = models.find_model(provider="anthropic", model_name=response.model)
if model is None:
# TODO: Replace with proper logging
Expand All @@ -29,12 +28,25 @@ def anthropic_chat_wrapper(
output_tokens = response.usage.output_tokens
model_size = model.active_parameters or model.active_parameters_range
impacts = compute_llm_impact(
model_parameter_count=model_size,
output_token_count=output_tokens
model_parameter_count=model_size, output_token_count=output_tokens
)
return Message(**response.model_dump(), impacts=impacts)


def anthropic_chat_wrapper(
    wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any  # noqa: ARG001
) -> Message:
    """Wrap a synchronous ``Messages.create`` call.

    Forwards the call unchanged, then returns the response enriched with
    computed impact metrics.
    """
    return compute_impacts_and_return_response(wrapped(*args, **kwargs))


async def anthropic_async_chat_wrapper(
    wrapped: Callable, instance: _AsyncAnthropic, args: Any, kwargs: Any  # noqa: ARG001
) -> Message:
    """Async counterpart of the synchronous Anthropic chat wrapper.

    Awaits the wrapped ``AsyncMessages.create`` call, then returns the
    response enriched with computed impact metrics.
    """
    return compute_impacts_and_return_response(await wrapped(*args, **kwargs))


class AnthropicInstrumentor:
def __init__(self) -> None:
self.wrapped_methods = [
Expand All @@ -43,12 +55,15 @@ def __init__(self) -> None:
"name": "Messages.create",
"wrapper": anthropic_chat_wrapper,
},
{
"module": "anthropic.resources",
"name": "AsyncMessages.create",
"wrapper": anthropic_async_chat_wrapper,
},
]

def instrument(self) -> None:
for wrapper in self.wrapped_methods:
wrap_function_wrapper(
wrapper["module"],
wrapper["name"],
wrapper["wrapper"]
wrapper["module"], wrapper["name"], wrapper["wrapper"]
)
36 changes: 27 additions & 9 deletions genai_impact/tracers/mistralai_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,22 @@
from genai_impact.model_repository import models

try:
from mistralai.async_client import MistralAsyncClient as _MistralAsyncClient
from mistralai.client import MistralClient as _MistralClient
from mistralai.models.chat_completion import (
ChatCompletionResponse as _ChatCompletionResponse,
)
except ImportError:
_MistralClient = object()
_MistralAsyncClient = object()
_ChatCompletionResponse = object()


class ChatCompletionResponse(_ChatCompletionResponse):
    """Mistral chat response extended with computed impact metrics."""

    # Impact metrics attached by the tracer after the API call returns.
    impacts: Impacts


def mistralai_chat_wrapper(
wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any # noqa: ARG001
) -> ChatCompletionResponse:
response = wrapped(*args, **kwargs)
def compute_impacts_and_return_response(response: Any) -> ChatCompletionResponse:
model = models.find_model(provider="mistralai", model_name=response.model)
if model is None:
# TODO: Replace with proper logging
Expand All @@ -31,12 +30,28 @@ def mistralai_chat_wrapper(
output_tokens = response.usage.completion_tokens
model_size = model.active_parameters or model.active_parameters_range
impacts = compute_llm_impact(
model_parameter_count=model_size,
output_token_count=output_tokens
model_parameter_count=model_size, output_token_count=output_tokens
)
return ChatCompletionResponse(**response.model_dump(), impacts=impacts)


def mistralai_chat_wrapper(
    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletionResponse:
    """Wrap a synchronous ``MistralClient.chat`` call.

    Forwards the call unchanged, then returns the response enriched with
    computed impact metrics.
    """
    return compute_impacts_and_return_response(wrapped(*args, **kwargs))


async def mistralai_async_chat_wrapper(
    wrapped: Callable,
    instance: _MistralAsyncClient,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletionResponse:
    """Async counterpart of the synchronous Mistral chat wrapper.

    Awaits the wrapped ``MistralAsyncClient.chat`` call, then returns the
    response enriched with computed impact metrics.
    """
    return compute_impacts_and_return_response(await wrapped(*args, **kwargs))


class MistralAIInstrumentor:
def __init__(self) -> None:
self.wrapped_methods = [
Expand All @@ -45,12 +60,15 @@ def __init__(self) -> None:
"name": "MistralClient.chat",
"wrapper": mistralai_chat_wrapper,
},
{
"module": "mistralai.async_client",
"name": "MistralAsyncClient.chat",
"wrapper": mistralai_async_chat_wrapper,
},
]

def instrument(self) -> None:
for wrapper in self.wrapped_methods:
wrap_function_wrapper(
wrapper["module"],
wrapper["name"],
wrapper["wrapper"]
wrapper["module"], wrapper["name"], wrapper["wrapper"]
)
36 changes: 26 additions & 10 deletions genai_impact/tracers/openai_tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable

from openai.resources.chat import Completions
from openai.resources.chat import AsyncCompletions, Completions
from openai.types.chat import ChatCompletion as _ChatCompletion
from wrapt import wrap_function_wrapper

Expand All @@ -12,10 +12,7 @@ class ChatCompletion(_ChatCompletion):
impacts: Impacts


def openai_chat_wrapper(
wrapped: Callable, instance: Completions, args: Any, kwargs: Any # noqa: ARG001
) -> ChatCompletion:
response = wrapped(*args, **kwargs)
def compute_impacts_and_return_response(response: Any) -> ChatCompletion:
model = models.find_model(provider="openai", model_name=response.model)
if model is None:
# TODO: Replace with proper logging
Expand All @@ -24,12 +21,28 @@ def openai_chat_wrapper(
output_tokens = response.usage.completion_tokens
model_size = model.active_parameters or model.active_parameters_range
impacts = compute_llm_impact(
model_parameter_count=model_size,
output_token_count=output_tokens
model_parameter_count=model_size, output_token_count=output_tokens
)
return ChatCompletion(**response.model_dump(), impacts=impacts)


def openai_chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletion:
    """Wrap a synchronous ``Completions.create`` call.

    Forwards the call unchanged, then returns the response enriched with
    computed impact metrics.
    """
    return compute_impacts_and_return_response(wrapped(*args, **kwargs))


async def openai_async_chat_wrapper(
    wrapped: Callable,
    instance: AsyncCompletions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletion:
    """Async counterpart of the synchronous OpenAI chat wrapper.

    Awaits the wrapped ``AsyncCompletions.create`` call, then returns the
    response enriched with computed impact metrics.
    """
    return compute_impacts_and_return_response(await wrapped(*args, **kwargs))


class OpenAIInstrumentor:
def __init__(self) -> None:
self.wrapped_methods = [
Expand All @@ -38,12 +51,15 @@ def __init__(self) -> None:
"name": "Completions.create",
"wrapper": openai_chat_wrapper,
},
{
"module": "openai.resources.chat.completions",
"name": "AsyncCompletions.create",
"wrapper": openai_async_chat_wrapper,
},
]

def instrument(self) -> None:
for wrapper in self.wrapped_methods:
wrap_function_wrapper(
wrapper["module"],
wrapper["name"],
wrapper["wrapper"]
wrapper["module"], wrapper["name"], wrapper["wrapper"]
)
Loading

0 comments on commit 6e54b80

Please sign in to comment.