diff --git a/src/core.ts b/src/core.ts
index 6875045..4e476b7 100644
--- a/src/core.ts
+++ b/src/core.ts
@@ -1,4 +1,5 @@
 import { VERSION } from './version';
+import { Stream } from './lib/streaming';
 import {
   GroqError,
   APIError,
@@ -38,6 +39,19 @@ type APIResponseProps = {
 async function defaultParseResponse<T>(props: APIResponseProps): Promise<T> {
   const { response } = props;
+  if (props.options.stream) {
+    debug('response', response.status, response.url, response.headers, response.body);
+
+    // Note: there is an invariant here that isn't represented in the type system
+    // that if you set `stream: true` the response type must also be `Stream<T>`
+
+    if (props.options.__streamClass) {
+      return props.options.__streamClass.fromSSEResponse(response, props.controller) as any;
+    }
+
+    return Stream.fromSSEResponse(response, props.controller) as any;
+  }
+
   // fetch refuses to read the body when the status code is 204.
   if (response.status === 204) {
     return null as T;
@@ -736,6 +750,7 @@ export type RequestOptions<Req = unknown | Record<string, unknown> | Readable> =
   idempotencyKey?: string;
 
   __binaryResponse?: boolean | undefined;
+  __streamClass?: typeof Stream;
 };
 
 // This is required so that we can determine if a given object matches the RequestOptions
@@ -756,6 +771,7 @@ const requestOptionsKeys: KeysEnum<RequestOptions> = {
   idempotencyKey: true,
 
   __binaryResponse: true,
+  __streamClass: true,
 };
 
 export const isRequestOptions = (obj: unknown): obj is RequestOptions => {
diff --git a/src/resources/chat/completions.ts b/src/resources/chat/completions.ts
index 81b6bc8..7bf78a4 100644
--- a/src/resources/chat/completions.ts
+++ b/src/resources/chat/completions.ts
@@ -3,13 +3,32 @@
 import * as Core from 'groq-sdk/core';
 import { APIResource } from 'groq-sdk/resource';
 import * as CompletionsAPI from 'groq-sdk/resources/chat/completions';
+import { Stream } from 'groq-sdk/lib/streaming';
+import { ChatCompletionChunk } from 'groq-sdk/lib/chat_completions_ext';
 
 export class Completions extends APIResource {
   /**
    * Creates a completion for a chat prompt
    */
-  create(body: CompletionCreateParams, options?: Core.RequestOptions): Core.APIPromise<ChatCompletion> {
-    return this._client.post('/openai/v1/chat/completions', { body, ...options });
+  create(
+    body: ChatCompletionCreateParamsNonStreaming,
+    options?: Core.RequestOptions,
+  ): Core.APIPromise<ChatCompletion>;
+  create(
+    body: ChatCompletionCreateParamsStreaming,
+    options?: Core.RequestOptions,
+  ): Core.APIPromise<Stream<ChatCompletionChunk>>;
+  create(
+    body: ChatCompletionCreateParamsBase,
+    options?: Core.RequestOptions,
+  ): Core.APIPromise<Stream<ChatCompletionChunk> | ChatCompletion>;
+  create(
+    body: ChatCompletionCreateParams,
+    options?: Core.RequestOptions,
+  ): Core.APIPromise<ChatCompletion> | Core.APIPromise<Stream<ChatCompletionChunk>> {
+    return this._client.post('/openai/v1/chat/completions', { body, ...options, stream: body.stream ?? false }) as
+      | Core.APIPromise<ChatCompletion>
+      | Core.APIPromise<Stream<ChatCompletionChunk>>;
+  }
 }
@@ -111,7 +130,7 @@ export namespace ChatCompletion {
   }
 }
 
-export interface CompletionCreateParams {
+export interface ChatCompletionCreateParamsBase {
   messages: Array<CompletionCreateParams.Message>;
 
   model: string;
@@ -235,3 +254,15 @@ export namespace Completions {
   export import ChatCompletion = CompletionsAPI.ChatCompletion;
   export import CompletionCreateParams = CompletionsAPI.CompletionCreateParams;
 }
+
+export interface ChatCompletionCreateParamsNonStreaming extends ChatCompletionCreateParamsBase {
+  stream?: false;
+}
+
+export interface ChatCompletionCreateParamsStreaming extends ChatCompletionCreateParamsBase {
+  stream: true;
+}
+
+export type ChatCompletionCreateParams =
+  | ChatCompletionCreateParamsNonStreaming
+  | ChatCompletionCreateParamsStreaming;
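
For orientation only (not part of the diff): a minimal usage sketch of how the new overloads are intended to be consumed, assuming the default `Groq` client export, a `GROQ_API_KEY` in the environment, and a placeholder model name. With `stream: true` the call resolves to a `Stream<ChatCompletionChunk>` that can be iterated with `for await`; otherwise it resolves to a plain `ChatCompletion`.

```ts
import Groq from 'groq-sdk';

// Placeholder client; the constructor reads GROQ_API_KEY from the environment by default.
const groq = new Groq();

async function main() {
  // Non-streaming overload: `stream` omitted, so the promise resolves to a ChatCompletion.
  const completion = await groq.chat.completions.create({
    messages: [{ role: 'user', content: 'Say hello' }],
    model: 'mixtral-8x7b-32768', // placeholder model name
  });
  console.log(completion.choices[0]?.message?.content);

  // Streaming overload: `stream: true` selects Core.APIPromise<Stream<ChatCompletionChunk>>,
  // backed by Stream.fromSSEResponse in core.ts.
  const stream = await groq.chat.completions.create({
    messages: [{ role: 'user', content: 'Say hello' }],
    model: 'mixtral-8x7b-32768',
    stream: true,
  });
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content ?? '');
  }
}

main();
```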