diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 2f02dd132..7d32d7817 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -37,11 +37,10 @@ import { GenerateRequest, GenerationCommonConfigSchema, MessageData, - ModelAction, ModelArgument, ModelMiddleware, - ModelReference, Part, + resolveModel, } from './model.js'; import { ExecutablePrompt } from './prompt.js'; import { resolveTools, ToolArgument, toToolDefinition } from './tool.js'; @@ -138,56 +137,6 @@ export async function toGenerateRequest( return out; } -interface ResolvedModel { - modelAction: ModelAction; - config?: z.infer; - version?: string; -} - -async function resolveModel( - registry: Registry, - options: GenerateOptions -): Promise { - let model = options.model; - let out: ResolvedModel; - let modelId: string; - - if (!model) { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: 'Must supply a `model` to `generate()` calls.', - }); - } - if (typeof model === 'string') { - modelId = model; - out = { modelAction: await registry.lookupAction(`/model/${model}`) }; - } else if (model.hasOwnProperty('__action')) { - modelId = (model as ModelAction).__action.name; - out = { modelAction: model as ModelAction }; - } else { - const ref = model as ModelReference; - modelId = ref.name; - out = { - modelAction: (await registry.lookupAction( - `/model/${ref.name}` - )) as ModelAction, - config: { - ...ref.config, - }, - version: ref.version, - }; - } - - if (!out.modelAction) { - throw new GenkitError({ - status: 'NOT_FOUND', - message: `Model ${modelId} not found`, - }); - } - - return out; -} - export class GenerationResponseError extends GenkitError { detail: { response: GenerateResponse; @@ -286,7 +235,7 @@ export async function generate< ): Promise>> { const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const resolvedModel = await resolveModel(registry, resolvedOptions); + const resolvedModel = await resolveModel(registry, resolvedOptions.model); const tools = await toolsToActionRefs(registry, resolvedOptions.tools); diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index a43badd04..513a24b6b 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -20,6 +20,7 @@ import { runWithStreamingCallback, z, } from '@genkit-ai/core'; +import { logger } from '@genkit-ai/core/logging'; import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; @@ -39,13 +40,13 @@ import { GenerateResponseData, MessageData, MessageSchema, - ModelAction, Part, + resolveModel, Role, ToolDefinitionSchema, ToolResponsePart, } from '../model.js'; -import { lookupToolByName, ToolAction, toToolDefinition } from '../tool.js'; +import { resolveTools, ToolAction, toToolDefinition } from '../tool.js'; export const GenerateUtilParamSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ @@ -104,37 +105,15 @@ async function generate( rawRequest: z.infer, middleware?: Middleware[] ): Promise { - const model = (await registry.lookupAction( - `/model/${rawRequest.model}` - )) as ModelAction; - if (!model) { - throw new Error(`Model ${rawRequest.model} not found`); - } + const { modelAction: model } = await resolveModel(registry, rawRequest.model); if (model.__action.metadata?.model.stage === 'deprecated') { - console.warn( + logger.warn( `${clc.bold(clc.yellow('Warning:'))} ` + `Model '${model.__action.name}' is deprecated and may be removed in a future release.` ); } - let tools: ToolAction[] | undefined; - if (rawRequest.tools?.length) { - if (!model.__action.metadata?.model.supports?.tools) { - throw new Error( - `Model ${rawRequest.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` - ); - } - tools = await Promise.all( - rawRequest.tools.map(async (toolRef) => { - if (typeof toolRef === 'string') { - return lookupToolByName(registry, toolRef as string); - } else if (toolRef.name) { - return lookupToolByName(registry, toolRef.name); - } - throw `Unable to resolve tool ${JSON.stringify(toolRef)}`; - }) - ); - } + const tools = await resolveTools(registry, rawRequest.tools); const resolvedFormat = rawRequest.output?.format ? await resolveFormat(registry, rawRequest.output?.format) diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 8c50a9d00..b4caa795c 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -17,6 +17,7 @@ import { Action, defineAction, + GenkitError, getStreamingCallback, Middleware, StreamingCallback, @@ -479,3 +480,54 @@ function getPartCounts(parts: Part[]): PartCounts { export type ModelArgument< CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, > = ModelAction | ModelReference | string; + +export interface ResolvedModel< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> { + modelAction: ModelAction; + config?: z.infer; + version?: string; +} + +export async function resolveModel( + registry: Registry, + model: ModelArgument | undefined +): Promise> { + let out: ResolvedModel; + let modelId: string; + + if (!model) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: 'Must supply a `model` to `generate()` calls.', + }); + } + if (typeof model === 'string') { + modelId = model; + out = { modelAction: await registry.lookupAction(`/model/${model}`) }; + } else if (model.hasOwnProperty('__action')) { + modelId = (model as ModelAction).__action.name; + out = { modelAction: model as ModelAction }; + } else { + const ref = model as ModelReference; + modelId = ref.name; + out = { + modelAction: (await registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction, + config: { + ...ref.config, + }, + version: ref.version, + }; + } + + if (!out.modelAction) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Model ${modelId} not found`, + }); + } + + return out; +} diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index cadbbecd9..bfeb37efd 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -95,7 +95,11 @@ export function asTool( export async function resolveTools< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(registry: Registry, tools: ToolArgument[] = []): Promise { +>(registry: Registry, tools?: ToolArgument[]): Promise { + if (!tools || tools.length === 0) { + return []; + } + return await Promise.all( tools.map(async (ref): Promise => { if (typeof ref === 'string') { diff --git a/js/testapps/format-tester/src/tools.ts b/js/testapps/format-tester/src/tools.ts new file mode 100644 index 000000000..80ed923af --- /dev/null +++ b/js/testapps/format-tester/src/tools.ts @@ -0,0 +1,54 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; +import { genkit, z } from 'genkit'; + +const ai = genkit({ plugins: [googleAI()], model: gemini15Flash }); + +const lookupUsers = ai.defineTool( + { + name: 'lookupUsers', + description: 'use this tool to list users', + outputSchema: z.array(z.object({ name: z.string(), id: z.number() })), + }, + async () => [ + { id: 123, name: 'Michael Bleigh' }, + { id: 456, name: 'Pavel Jbanov' }, + { id: 789, name: 'Chris Gill' }, + { id: 1122, name: 'Marissa Christy' }, + ] +); + +async function main() { + const { stream } = await ai.generateStream({ + prompt: + 'use the lookupUsers tool and generate silly nicknames for each, then generate 50 fake users in the same format. return a JSON array.', + output: { + format: 'json', + schema: z.array( + z.object({ id: z.number(), name: z.string(), nickname: z.string() }) + ), + }, + tools: [lookupUsers], + }); + + for await (const chunk of stream) { + console.log('raw:', chunk); + console.log('output:', chunk.output); + } +} +main();