Skip to content

Commit

Permalink
feat: relax chatFn types
Browse files Browse the repository at this point in the history
  • Loading branch information
transitive-bullshit committed Aug 7, 2024
1 parent f4b79d6 commit f2169f1
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 10 deletions.
1 change: 1 addition & 0 deletions examples/dexter/bin/election-news-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ async function main() {
})

const chain = createAIChain({
name: 'search_news',
chatFn: chatModel.run.bind(chatModel),
tools: [perigon.functions.pick('search_news_stories'), serper],
params: {
Expand Down
10 changes: 7 additions & 3 deletions packages/core/src/create-ai-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
const functionSet = new AIFunctionSet(tools)
const schema = rawSchema ? asSchema(rawSchema, { strict }) : undefined

// TODO: support custom stopping criteria (like setting a flag in a tool call)

const defaultParams: Partial<types.ChatParams> | undefined =
schema && !functionSet.size
? {
Expand Down Expand Up @@ -190,10 +192,10 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw new AbortError(
'Function calls are not supported; expected tool call'
)
} else if (Msg.isRefusal(message)) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (Msg.isAssistant(message)) {
if (message.refusal) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (schema && schema.validate) {
if (schema && schema.validate) {
const result = schema.validate(message.content)

if (result.success) {
Expand All @@ -212,6 +214,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
throw err
}

console.warn(`Chain "${name}" error:`, err.message)

messages.push(
Msg.user(
`There was an error validating the response. Please check the error message and try again.\nError:\n${getErrorMessage(err)}`
Expand Down
55 changes: 51 additions & 4 deletions packages/core/src/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ export interface Msg {
name?: string
}

export interface LegacyMsg {
content: string | null
role: Msg.Role
function_call?: Msg.Call.Function
tool_calls?: Msg.Call.Tool[]
tool_call_id?: string
name?: string
}

/** Narrowed OpenAI Message types. */
export namespace Msg {
/** Possible roles for a message. */
Expand Down Expand Up @@ -102,8 +111,14 @@ export namespace Msg {
export type Assistant = {
role: 'assistant'
name?: string
content?: string
refusal?: string
content: string
}

/** Message with refusal reason from the assistant. */
export type Refusal = {
role: 'assistant'
name?: string
refusal: string
}

/** Message with arguments to call a function. */
Expand Down Expand Up @@ -193,6 +208,27 @@ export namespace Msg {
}
}

/**
* Create an assistant refusal message. Cleans indentation and newlines by
* default.
*/
export function refusal(
refusal: string,
opts?: {
/** Custom name for the message. */
name?: string
/** Whether to clean extra newlines and indentation. Defaults to true. */
cleanRefusal?: boolean
}
): Msg.Refusal {
const { name, cleanRefusal = true } = opts ?? {}
return {
role: 'assistant',
refusal: cleanRefusal ? cleanStringForModel(refusal) : refusal,
...(name ? { name } : {})
}
}

/** Create a function call message with argumets. */
export function funcCall(
function_call: {
Expand Down Expand Up @@ -257,21 +293,23 @@ export namespace Msg {
// @TODO
response: any
// response: ChatModel.EnrichedResponse
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
const msg = response.choices[0].message as Msg
return narrowResponseMessage(msg)
}

/** Narrow a message received from the API. It only responds with role=assistant */
export function narrowResponseMessage(
msg: Msg
): Msg.Assistant | Msg.FuncCall | Msg.ToolCall {
): Msg.Assistant | Msg.Refusal | Msg.FuncCall | Msg.ToolCall {
if (msg.content === null && msg.tool_calls != null) {
return Msg.toolCall(msg.tool_calls)
} else if (msg.content === null && msg.function_call != null) {
return Msg.funcCall(msg.function_call)
} else if (msg.content !== null && msg.content !== undefined) {
return Msg.assistant(msg.content)
} else if (msg.refusal != null) {
return Msg.refusal(msg.refusal)
} else {
// @TODO: probably don't want to error here
console.log('Invalid message', msg)
Expand All @@ -291,6 +329,10 @@ export namespace Msg {
export function isAssistant(message: Msg): message is Msg.Assistant {
return message.role === 'assistant' && message.content !== null
}
/** Check if a message is an assistant refusal message. */
export function isRefusal(message: Msg): message is Msg.Refusal {
return message.role === 'assistant' && message.refusal !== null
}
/** Check if a message is a function call message with arguments. */
export function isFuncCall(message: Msg): message is Msg.FuncCall {
return message.role === 'assistant' && message.function_call != null
Expand All @@ -312,6 +354,7 @@ export namespace Msg {
export function narrow(message: Msg.System): Msg.System
export function narrow(message: Msg.User): Msg.User
export function narrow(message: Msg.Assistant): Msg.Assistant
export function narrow(message: Msg.Assistant): Msg.Refusal
export function narrow(message: Msg.FuncCall): Msg.FuncCall
export function narrow(message: Msg.FuncResult): Msg.FuncResult
export function narrow(message: Msg.ToolCall): Msg.ToolCall
Expand All @@ -322,6 +365,7 @@ export namespace Msg {
| Msg.System
| Msg.User
| Msg.Assistant
| Msg.Refusal
| Msg.FuncCall
| Msg.FuncResult
| Msg.ToolCall
Expand All @@ -335,6 +379,9 @@ export namespace Msg {
if (isAssistant(message)) {
return message
}
if (isRefusal(message)) {
return message
}
if (isFuncCall(message)) {
return message
}
Expand Down
21 changes: 18 additions & 3 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { z } from 'zod'

import type { AIFunctionSet } from './ai-function-set'
import type { AIFunctionsProvider } from './fns'
import type { Msg } from './message'
import type { LegacyMsg, Msg } from './message'

export type { Msg } from './message'
export type { Schema } from './schema'
Expand Down Expand Up @@ -130,6 +130,10 @@ export interface ChatParams {
user?: string
}

export type LegacyChatParams = Simplify<
Omit<ChatParams, 'messages'> & { messages: LegacyMsg[] }
>

export interface ResponseFormatJSONSchema {
/**
* The name of the response format. Must be a-z, A-Z, 0-9, or contain
Expand Down Expand Up @@ -158,10 +162,21 @@ export interface ResponseFormatJSONSchema {
strict?: boolean
}

/**
* OpenAI has changed some of their types, so instead of trying to support all
* possible types, for these params, just relax them for now.
*/
export type RelaxedChatParams = Simplify<
Omit<ChatParams, 'messages' | 'response_format'> & {
messages: object[]
response_format?: { type: 'text' | 'json_object' | string }
}
>

/** An OpenAI-compatible chat completions API */
export type ChatFn = (
params: Simplify<SetOptional<ChatParams, 'model'>>
) => Promise<{ message: Msg }>
params: Simplify<SetOptional<RelaxedChatParams, 'model'>>
) => Promise<{ message: Msg | LegacyMsg }>

export type AIChainResult = string | Record<string, any>

Expand Down

0 comments on commit f2169f1

Please sign in to comment.