Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace custom SSE reader with source for System.Net.ServerSentEvents #33

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.ServerSentEvents;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -31,7 +33,7 @@ public override IAsyncEnumerator<StreamingUpdate> GetAsyncEnumerator(Cancellatio

private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingUpdateCollection _enumerable;
Expand All @@ -44,7 +46,7 @@ private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator<Streaming
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IAsyncEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
Expand Down Expand Up @@ -84,7 +86,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
Expand All @@ -104,7 +106,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingUpdate>.MoveNextAsync()
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
private async Task<IAsyncEnumerator<SseItem<byte[]>>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
Expand All @@ -115,7 +117,7 @@ private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
IAsyncEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync();
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

Expand Down
3 changes: 2 additions & 1 deletion src/Custom/Assistants/Streaming/StreamingUpdate.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Net.ServerSentEvents;
using System.Text.Json;

namespace OpenAI.Assistants;
Expand Down Expand Up @@ -38,7 +39,7 @@ internal StreamingUpdate(StreamingUpdateReason updateKind)
UpdateKind = updateKind;
}

internal static IEnumerable<StreamingUpdate> FromEvent(ServerSentEvent sseItem)
internal static IEnumerable<StreamingUpdate> FromEvent(SseItem<byte[]> sseItem)
{
StreamingUpdateReason updateKind = StreamingUpdateReasonExtensions.FromSseEventLabel(sseItem.EventType);
using JsonDocument dataDocument = JsonDocument.Parse(sseItem.Data);
Expand Down
11 changes: 6 additions & 5 deletions src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;

#nullable enable

Expand All @@ -30,7 +31,7 @@ public override IEnumerator<StreamingUpdate> GetEnumerator()

private sealed class StreamingUpdateEnumerator : IEnumerator<StreamingUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<ClientResult> _getResult;
private readonly StreamingUpdateCollection _enumerable;
Expand All @@ -42,7 +43,7 @@ private sealed class StreamingUpdateEnumerator : IEnumerator<StreamingUpdate>
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IEnumerator<ServerSentEvent>? _events;
private IEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingUpdate>? _updates;

private StreamingUpdate? _current;
Expand Down Expand Up @@ -81,7 +82,7 @@ public bool MoveNext()

if (_events.MoveNext())
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
Expand All @@ -101,7 +102,7 @@ public bool MoveNext()
return false;
}

private IEnumerator<ServerSentEvent> CreateEventEnumerator()
private IEnumerator<SseItem<byte[]>> CreateEventEnumerator()
{
ClientResult result = _getResult();
PipelineResponse response = result.GetRawResponse();
Expand All @@ -112,7 +113,7 @@ private IEnumerator<ServerSentEvent> CreateEventEnumerator()
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

ServerSentEventEnumerable enumerable = new(response.ContentStream);
IEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate();
return enumerable.GetEnumerator();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -32,7 +33,7 @@ public override IAsyncEnumerator<StreamingChatCompletionUpdate> GetAsyncEnumerat

private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator<StreamingChatCompletionUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<Task<ClientResult>> _getResultAsync;
private readonly AsyncStreamingChatCompletionUpdateCollection _enumerable;
Expand All @@ -45,7 +46,7 @@ private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator<Strea
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IAsyncEnumerator<ServerSentEvent>? _events;
private IAsyncEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingChatCompletionUpdate>? _updates;

private StreamingChatCompletionUpdate? _current;
Expand Down Expand Up @@ -85,7 +86,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingChatCompletionUpdate>.MoveNextAs

if (await _events.MoveNextAsync().ConfigureAwait(false))
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
Expand All @@ -106,7 +107,7 @@ async ValueTask<bool> IAsyncEnumerator<StreamingChatCompletionUpdate>.MoveNextAs
return false;
}

private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync()
private async Task<IAsyncEnumerator<SseItem<byte[]>>> CreateEventEnumeratorAsync()
{
ClientResult result = await _getResultAsync().ConfigureAwait(false);
PipelineResponse response = result.GetRawResponse();
Expand All @@ -117,7 +118,7 @@ private async Task<IAsyncEnumerator<ServerSentEvent>> CreateEventEnumeratorAsync
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

AsyncServerSentEventEnumerable enumerable = new(response.ContentStream);
IAsyncEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync();
return enumerable.GetAsyncEnumerator(_cancellationToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.ServerSentEvents;
using System.Text.Json;

#nullable enable
Expand Down Expand Up @@ -31,7 +32,7 @@ public override IEnumerator<StreamingChatCompletionUpdate> GetEnumerator()

private sealed class StreamingChatUpdateEnumerator : IEnumerator<StreamingChatCompletionUpdate>
{
private const string _terminalData = "[DONE]";
private static ReadOnlySpan<byte> TerminalData => "[DONE]"u8;

private readonly Func<ClientResult> _getResult;
private readonly StreamingChatCompletionUpdateCollection _enumerable;
Expand All @@ -43,7 +44,7 @@ private sealed class StreamingChatUpdateEnumerator : IEnumerator<StreamingChatCo
// // get _updates from sse event
// foreach (var update in _updates) { ... }
// }
private IEnumerator<ServerSentEvent>? _events;
private IEnumerator<SseItem<byte[]>>? _events;
private IEnumerator<StreamingChatCompletionUpdate>? _updates;

private StreamingChatCompletionUpdate? _current;
Expand Down Expand Up @@ -82,7 +83,7 @@ public bool MoveNext()

if (_events.MoveNext())
{
if (_events.Current.Data == _terminalData)
if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData))
{
_current = default;
return false;
Expand All @@ -103,7 +104,7 @@ public bool MoveNext()
return false;
}

private IEnumerator<ServerSentEvent> CreateEventEnumerator()
private IEnumerator<SseItem<byte[]>> CreateEventEnumerator()
{
ClientResult result = _getResult();
PipelineResponse response = result.GetRawResponse();
Expand All @@ -114,7 +115,7 @@ private IEnumerator<ServerSentEvent> CreateEventEnumerator()
throw new InvalidOperationException("Unable to create result from response with null ContentStream");
}

ServerSentEventEnumerable enumerable = new(response.ContentStream);
IEnumerable<SseItem<byte[]>> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate();
return enumerable.GetEnumerator();
}

Expand Down
7 changes: 7 additions & 0 deletions src/OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
<NoWarn>$(NoWarn),0169</NoWarn>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<!-- Allow use of unsafe code, for System.Net.ServerSentEvents polyfill on netstandard2.0
TODO https://github.com/openai/openai-dotnet/issues/41: Remove once polyfill for
System.Net.ServerSentEvents is removed in favor of referencing the package -->
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
</PropertyGroup>

<PropertyGroup Condition="'$(GITHUB_ACTIONS)' == 'true'">
<!-- Normalize stored file paths in symbols when in a CI build. -->
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
Expand Down
82 changes: 0 additions & 82 deletions src/Utility/AsyncServerSentEventEnumerable.cs

This file was deleted.

24 changes: 0 additions & 24 deletions src/Utility/ServerSentEvent.cs

This file was deleted.

Loading
Loading