Skip to content

Commit

Permalink
feat: improve custom schema support in @agentic/core
Browse files Browse the repository at this point in the history
  • Loading branch information
transitive-bullshit committed Aug 8, 2024
1 parent f786ca9 commit 4abde52
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 21 deletions.
10 changes: 2 additions & 8 deletions packages/core/src/create-ai-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,8 @@ export function createAIChain<Result extends types.AIChainResult = string>({
} else if (Msg.isRefusal(message)) {
throw new AbortError(`Model refusal: ${message.refusal}`)
} else if (Msg.isAssistant(message)) {
if (schema && schema.validate) {
const result = schema.validate(message.content)

if (result.success) {
return result.data
}

throw new Error(result.error)
if (schema) {
return schema.parse(message.content)
} else {
return message.content as Result
}
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/create-ai-function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
name: string
/** Description of the function. */
description?: string
/** Zod schema for the arguments string. */
/** Zod schema for the function parameters. */
inputSchema: InputSchema
/**
* Whether or not to enable structured output generation based on the given
Expand Down
41 changes: 30 additions & 11 deletions packages/core/src/schema.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import type { z } from 'zod'

import type * as types from './types'
import { safeParseStructuredOutput } from './parse-structured-output'
import { parseStructuredOutput } from './parse-structured-output'
import { stringifyForModel } from './utils'
import { zodToJsonSchema } from './zod-to-json-schema'

/**
* Used to mark schemas so we can support both Zod and custom schemas.
*/
export const schemaSymbol = Symbol('agentic.schema')
export const validatorSymbol = Symbol('agentic.validator')

export type Schema<TData = unknown> = {
/**
Expand All @@ -18,10 +17,18 @@ export type Schema<TData = unknown> = {
readonly jsonSchema: types.JSONSchema

/**
* Optional. Validates that the structure of a value matches this schema,
* and returns a typed version of the value if it does.
* Parses the value, validates that it matches this schema, and returns a
* typed version of the value if it does. Throw an error if the value does
* not match the schema.
*/
readonly validate?: types.ValidatorFn<TData>
readonly parse: types.ParseFn<TData>

/**
* Parses the value, validates that it matches this schema, and returns a
* typed version of the value if it does. Returns an error message if the
* value does not match the schema, and will never throw an error.
*/
readonly safeParse: types.SafeParseFn<TData>

/**
* Used to mark schemas so we can support both Zod and custom schemas.
Expand All @@ -41,7 +48,7 @@ export function isSchema(value: unknown): value is Schema {
schemaSymbol in value &&
value[schemaSymbol] === true &&
'jsonSchema' in value &&
'validate' in value
'parse' in value
)
}

Expand Down Expand Up @@ -71,16 +78,28 @@ export function asSchema<TData>(
export function createSchema<TData = unknown>(
jsonSchema: types.JSONSchema,
{
validate
parse = (value) => value as TData,
safeParse
}: {
validate?: types.ValidatorFn<TData>
parse?: types.ParseFn<TData>
safeParse?: types.SafeParseFn<TData>
} = {}
): Schema<TData> {
safeParse ??= (value: unknown) => {
try {
const result = parse(value)
return { success: true, data: result }
} catch (err: any) {
return { success: false, error: err.message ?? String(err) }
}
}

return {
[schemaSymbol]: true,
_type: undefined as TData,
jsonSchema,
validate
parse,
safeParse
}
}

Expand All @@ -89,8 +108,8 @@ export function createSchemaFromZodSchema<TData>(
opts: { strict?: boolean } = {}
): Schema<TData> {
return createSchema(zodToJsonSchema(zodSchema, opts), {
validate: (value) => {
return safeParseStructuredOutput(value, zodSchema)
parse: (value) => {
return parseStructuredOutput(value, zodSchema)
}
})
}
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,5 @@ export type SafeParseResult<TData> =
error: string
}

export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>
export type ParseFn<TData> = (value: unknown) => TData
export type SafeParseFn<TData> = (value: unknown) => SafeParseResult<TData>

0 comments on commit 4abde52

Please sign in to comment.