Early exit for OpenAI's .with_streaming_response. (#1157)
* Add initial fixes and tests

* Update v1 mock server responses

* Remove commented out code

* Add escape for chat completions

* Add escape for embedding async

* Change extra_headers escape conditionals

* Fix comment

* Add skip conditions for v1.7

* Move transaction skip earlier to not record attributes

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
lrafeei and mergify[bot] committed Jun 10, 2024
1 parent 73f3197 commit 45b575f
Showing 3 changed files with 297 additions and 2 deletions.
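For context, the call pattern that now takes the early exit is the OpenAI SDK's streaming-response wrapper. A minimal sketch, assuming a configured `openai.OpenAI` client on openai >= 1.8 (the client setup here is illustrative, not part of the commit):

import openai

client = openai.OpenAI()  # assumed setup; any configured client works

# Calls routed through `.with_streaming_response.` are no longer instrumented;
# the agent returns early instead of recording LLM events for them.
with client.chat.completions.with_streaming_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
) as response:
    for line in response.iter_lines():
        pass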
16 changes: 15 additions & 1 deletion newrelic/hooks/mlmodel_openai.py
@@ -79,6 +79,11 @@ def wrap_chat_completion_sync(wrapped, instance, args, kwargs):
if not transaction:
return wrapped(*args, **kwargs)

# If `.with_streaming_response.` wrapper used, switch to streaming
# For now, we will exit and instrument this later
if (kwargs.get("extra_headers") or {}).get("X-Stainless-Raw-Response") == "stream":
return wrapped(*args, **kwargs)

settings = transaction.settings if transaction.settings is not None else global_settings()
if not settings.ai_monitoring.enabled:
return wrapped(*args, **kwargs)
@@ -213,7 +218,11 @@ def create_chat_completion_message_event(

async def wrap_embedding_async(wrapped, instance, args, kwargs):
transaction = current_transaction()
-    if not transaction or kwargs.get("stream", False):
+    if (
+        not transaction
+        or kwargs.get("stream", False)
+        or (kwargs.get("extra_headers") or {}).get("X-Stainless-Raw-Response") == "stream"
+    ):
return await wrapped(*args, **kwargs)

settings = transaction.settings if transaction.settings is not None else global_settings()
@@ -393,6 +402,11 @@ async def wrap_chat_completion_async(wrapped, instance, args, kwargs):
if not transaction:
return await wrapped(*args, **kwargs)

# If `.with_streaming_response.` wrapper used, switch to streaming
# For now, we will exit and instrument this later
if (kwargs.get("extra_headers") or {}).get("X-Stainless-Raw-Response") == "stream":
return await wrapped(*args, **kwargs)

settings = transaction.settings if transaction.settings is not None else global_settings()
if not settings.ai_monitoring.enabled:
return await wrapped(*args, **kwargs)
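The early exit keys off a sentinel header: the SDK's `with_streaming_response` wrapper injects `X-Stainless-Raw-Response: stream` into `extra_headers`, which all three hooks above check inline. A standalone sketch of that check (the helper name is hypothetical; the hooks inline the expression):

def _called_via_streaming_response(kwargs):
    # `extra_headers` may be absent or explicitly None, hence the `or {}`.
    extra_headers = kwargs.get("extra_headers") or {}
    return extra_headers.get("X-Stainless-Raw-Response") == "stream"

# Plain calls stay instrumented; streaming-response calls take the early exit.
assert not _called_via_streaming_response({})
assert _called_via_streaming_response({"extra_headers": {"X-Stainless-Raw-Response": "stream"}})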
281 changes: 281 additions & 0 deletions tests/mlmodel_openai/test_chat_completion_stream_v1.py
@@ -13,6 +13,8 @@
# limitations under the License.

import openai
import pytest
from conftest import get_openai_version # pylint: disable=E0611
from testing_support.fixtures import (
override_llm_token_callback_settings,
reset_core_stats_engine,
@@ -37,6 +39,15 @@
from newrelic.api.background_task import background_task
from newrelic.api.transaction import add_custom_attribute

# TODO: Once instrumentation support is added for `.with_streaming_response.`,
# the commented-out validator checks below can be re-enabled.

OPENAI_VERSION = get_openai_version()
SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE = pytest.mark.skipif(
OPENAI_VERSION < (1, 8), reason="OpenAI does not support .with_streaming_response. until v1.8"
)
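`get_openai_version` is imported from the suite's conftest. A plausible sketch of such a helper (hypothetical reconstruction; the real implementation lives in tests/mlmodel_openai/conftest.py) that makes the tuple comparison above work:

import openai

def get_openai_version():
    # "1.8.0" -> (1, 8, 0); Python compares tuples element-wise, so
    # (1, 7, 0) < (1, 8) holds while (1, 8, 0) < (1, 8) does not.
    return tuple(int(part) for part in openai.__version__.split(".")[:3])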


_test_openai_chat_completion_messages = (
{"role": "system", "content": "You are a scientist."},
{"role": "user", "content": "What is 212 degrees Fahrenheit converted to Celsius?"},
@@ -161,6 +172,132 @@ def test_openai_chat_completion_sync_with_llm_metadata(set_trace_info, sync_open
assert resp


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
@validate_transaction_metrics(
name="test_chat_completion_stream_v1:test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_lines",
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_lines(
set_trace_info, sync_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")

create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

with sync_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

for _ in generator.iter_lines():
pass


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
@validate_transaction_metrics(
name="test_chat_completion_stream_v1:test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_bytes",
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_bytes(
set_trace_info, sync_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")

create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

with sync_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

for _ in generator.iter_bytes():
pass


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
@validate_transaction_metrics(
name="test_chat_completion_stream_v1:test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_text",
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_sync_with_llm_metadata_with_streaming_response_text(
set_trace_info, sync_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")

create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

with sync_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

for _ in generator.iter_text():
pass


@reset_core_stats_engine()
@disabled_ai_monitoring_record_content_settings
@validate_custom_events(events_sans_content(chat_completion_recorded_events))
@@ -367,6 +504,150 @@ async def consumer():
loop.run_until_complete(consumer())


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
# @validate_custom_events(chat_completion_recorded_events)
# @validate_custom_event_count(count=4)
@validate_transaction_metrics(
"test_chat_completion_stream_v1:test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_lines",
# scoped_metrics=[("Llm/completion/OpenAI/create", 1)],
# rollup_metrics=[("Llm/completion/OpenAI/create", 1)],
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_lines(
loop, set_trace_info, async_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")
create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

async def consumer():
async with async_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

async for _ in generator.iter_lines():
pass

loop.run_until_complete(consumer())


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
# @validate_custom_events(chat_completion_recorded_events)
# @validate_custom_event_count(count=4)
@validate_transaction_metrics(
"test_chat_completion_stream_v1:test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_bytes",
# scoped_metrics=[("Llm/completion/OpenAI/create", 1)],
# rollup_metrics=[("Llm/completion/OpenAI/create", 1)],
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_bytes(
loop, set_trace_info, async_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")
create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

async def consumer():
async with async_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

async for _ in generator.iter_bytes():
pass

loop.run_until_complete(consumer())


@SKIP_IF_NO_OPENAI_WITH_STREAMING_RESPONSE
@reset_core_stats_engine()
@pytest.mark.parametrize(
"stream_set, stream_val",
[
(False, None),
(True, True),
(True, False),
],
)
# @validate_custom_events(chat_completion_recorded_events)
# @validate_custom_event_count(count=4)
@validate_transaction_metrics(
"test_chat_completion_stream_v1:test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_text",
# scoped_metrics=[("Llm/completion/OpenAI/create", 1)],
# rollup_metrics=[("Llm/completion/OpenAI/create", 1)],
# custom_metrics=[
# ("Supportability/Python/ML/OpenAI/%s" % openai.__version__, 1),
# ],
background_task=True,
)
# @validate_attributes("agent", ["llm"])
@background_task()
def test_openai_chat_completion_async_with_llm_metadata_with_streaming_response_text(
loop, set_trace_info, async_openai_client, stream_set, stream_val
):
set_trace_info()
add_custom_attribute("llm.conversation_id", "my-awesome-id")
add_custom_attribute("llm.foo", "bar")
add_custom_attribute("non_llm_attr", "python-agent")
create_dict = {
"model": "gpt-3.5-turbo",
"messages": _test_openai_chat_completion_messages,
"temperature": 0.7,
"max_tokens": 100,
}
if stream_set:
create_dict["stream"] = stream_val

async def consumer():
async with async_openai_client.chat.completions.with_streaming_response.create(**create_dict) as generator:

async for _ in generator.iter_text():
pass

loop.run_until_complete(consumer())


@reset_core_stats_engine()
@disabled_ai_monitoring_record_content_settings
@validate_custom_events(events_sans_content(chat_completion_recorded_events))
2 changes: 1 addition & 1 deletion tests/mlmodel_openai/test_embeddings_stream_v1.py
@@ -48,7 +48,7 @@ def test_openai_embedding_sync(set_trace_info, sync_openai_stream_client):
@background_task()
def test_openai_embedding_async(loop, set_trace_info, async_openai_stream_client):
"""
-    Does not instrumenting streamed embeddings.
+    Does not instrument streamed embeddings.
"""
set_trace_info()

