Skip to content

Commit

Permalink
prepare repo for new message sse decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
dleviminzi committed Mar 8, 2024
1 parent f0f9cec commit 0b0a722
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
6 changes: 3 additions & 3 deletions complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ type CompletionResponse struct {
// StreamingCompletionResponse contains the server sent events decoder, the response body from the request, and a
// cancel function to enforce a timeout.
type StreamingCompletionResponse struct {
decoder *SSEDecoder
decoder *CompletionSSEDecoder
body io.ReadCloser
cancel context.CancelFunc
}

// Decode is a method for CompleteStreamResponse that returns the next event
// from the server-sent events decoder, or an error if one occurred.
func (c StreamingCompletionResponse) Decode() (*Event, error) {
func (c StreamingCompletionResponse) Decode() (*CompletionEvent, error) {
return c.decoder.Decode()
}

Expand Down Expand Up @@ -146,5 +146,5 @@ func (c *Client) StreamingCompletionRequest(ctx context.Context, payload Complet
return nil, fmt.Errorf("%s: %s", errorResponse.Error.Type, errorResponse.Error.Message)
}

return &StreamingCompletionResponse{NewSSEDecoder(res.Body), res.Body, cancel}, nil
return &StreamingCompletionResponse{NewCompletionSSEDecoder(res.Body), res.Body, cancel}, nil
}
File renamed without changes.
28 changes: 14 additions & 14 deletions sse_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,39 @@ import (
"strings"
)

// ResponseData represents the data payload in a Server-Sent Events (SSE) message.
type ResponseData struct {
// CompletionEventData represents the data payload in a Server-Sent Events (SSE) message.
type CompletionEventData struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Stop string `json:"stop"`
LogID string `json:"log_id"`
}

// Event represents a single Server-Sent Event. It includes the event type, data, ID, and retry fields.
type Event struct {
// CompletionEvent represents a single Server-Sent CompletionEvent. It includes the event type, data, ID, and retry fields.
type CompletionEvent struct {
Event string
Data *ResponseData
Data *CompletionEventData
ID string
Retry int
}

// SSEDecoder is a decoder for Server-Sent Events. It maintains a buffer reader and the current event being processed.
type SSEDecoder struct {
currentEvent Event
// CompletionSSEDecoder is a decoder for Server-Sent Events. It maintains a buffer reader and the current event being processed.
type CompletionSSEDecoder struct {
currentEvent CompletionEvent
Reader *bufio.Reader
}

// NewSSEDecoder initializes a new SSEDecoder with the provided reader.
func NewSSEDecoder(r io.Reader) *SSEDecoder {
return &SSEDecoder{
// NewCompletionSSEDecoder initializes a new SSEDecoder with the provided reader.
func NewCompletionSSEDecoder(r io.Reader) *CompletionSSEDecoder {
return &CompletionSSEDecoder{
Reader: bufio.NewReader(r),
}
}

// Decode reads from the buffered reader line by line, parses Server-Sent Events and sets fields on the current event.
// It returns the complete event when encountering an empty line, and nil otherwise. It will return EOF when nothing is left.
func (d *SSEDecoder) Decode() (*Event, error) {
func (d *CompletionSSEDecoder) Decode() (*CompletionEvent, error) {
line, err := d.Reader.ReadString('\n')
if err != nil {
return nil, err
Expand All @@ -55,7 +55,7 @@ func (d *SSEDecoder) Decode() (*Event, error) {
}

ev := d.currentEvent
d.currentEvent = Event{ID: ev.ID} // preserve LastEventID for the next event
d.currentEvent = CompletionEvent{ID: ev.ID} // preserve LastEventID for the next event
return &ev, nil
}

Expand All @@ -79,7 +79,7 @@ func (d *SSEDecoder) Decode() (*Event, error) {
case "event":
d.currentEvent.Event = fieldValue
case "data":
var data ResponseData
var data CompletionEventData
err := json.Unmarshal([]byte(fieldValue), &data)
if err != nil {
return nil, fmt.Errorf("error decoding data field: %w", err)
Expand Down
14 changes: 7 additions & 7 deletions sse_decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func TestDecode(t *testing.T) {
name string
input string
wantErr bool
wantEv *Event
wantEv *CompletionEvent
}{
{
name: "has : prefix",
Expand Down Expand Up @@ -42,26 +42,26 @@ func TestDecode(t *testing.T) {
name: "id field",
input: "id: testID\n\r",
wantErr: false,
wantEv: &Event{ID: "testID"},
wantEv: &CompletionEvent{ID: "testID"},
},
{
name: "event field",
input: "event: testEvent\n\r",
wantErr: false,
wantEv: &Event{Event: "testEvent"},
wantEv: &CompletionEvent{Event: "testEvent"},
},
{
name: "retry field",
input: "retry: 5\n\r",
wantErr: false,
wantEv: &Event{Retry: 5},
wantEv: &CompletionEvent{Retry: 5},
},
{
name: "data field",
input: "data: {\"completion\":\"testCompletion\",\"stop_reason\":\"testReason\",\"model\":\"testModel\",\"stop\":\"testStop\",\"log_id\":\"testLogId\"}\n\r",
wantErr: false,
wantEv: &Event{
Data: &ResponseData{
wantEv: &CompletionEvent{
Data: &CompletionEventData{
Completion: "testCompletion",
StopReason: "testReason",
Model: "testModel",
Expand Down Expand Up @@ -92,7 +92,7 @@ func TestDecode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := strings.NewReader(tt.input + "\n")
dec := NewSSEDecoder(r)
dec := NewCompletionSSEDecoder(r)

ev, err := dec.Decode()
if !tt.wantErr {
Expand Down

0 comments on commit 0b0a722

Please sign in to comment.