Skip to content

Commit

Permalink
Reduce duplicate codepaths for resolving models and tools. (#1206)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Nov 7, 2024
1 parent 4ec5a4f commit ae2f1e8
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 81 deletions.
55 changes: 2 additions & 53 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ import {
GenerateRequest,
GenerationCommonConfigSchema,
MessageData,
ModelAction,
ModelArgument,
ModelMiddleware,
ModelReference,
Part,
resolveModel,
} from './model.js';
import { ExecutablePrompt } from './prompt.js';
import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';
Expand Down Expand Up @@ -138,56 +137,6 @@ export async function toGenerateRequest(
return out;
}

interface ResolvedModel<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> {
modelAction: ModelAction;
config?: z.infer<CustomOptions>;
version?: string;
}

async function resolveModel(
registry: Registry,
options: GenerateOptions
): Promise<ResolvedModel> {
let model = options.model;
let out: ResolvedModel;
let modelId: string;

if (!model) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'Must supply a `model` to `generate()` calls.',
});
}
if (typeof model === 'string') {
modelId = model;
out = { modelAction: await registry.lookupAction(`/model/${model}`) };
} else if (model.hasOwnProperty('__action')) {
modelId = (model as ModelAction).__action.name;
out = { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
modelId = ref.name;
out = {
modelAction: (await registry.lookupAction(
`/model/${ref.name}`
)) as ModelAction,
config: {
...ref.config,
},
version: ref.version,
};
}

if (!out.modelAction) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Model ${modelId} not found`,
});
}

return out;
}

export class GenerationResponseError extends GenkitError {
detail: {
response: GenerateResponse;
Expand Down Expand Up @@ -286,7 +235,7 @@ export async function generate<
): Promise<GenerateResponse<z.infer<O>>> {
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const resolvedModel = await resolveModel(registry, resolvedOptions);
const resolvedModel = await resolveModel(registry, resolvedOptions.model);

const tools = await toolsToActionRefs(registry, resolvedOptions.tools);

Expand Down
33 changes: 6 additions & 27 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
runWithStreamingCallback,
z,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
Expand All @@ -39,13 +40,13 @@ import {
GenerateResponseData,
MessageData,
MessageSchema,
ModelAction,
Part,
resolveModel,
Role,
ToolDefinitionSchema,
ToolResponsePart,
} from '../model.js';
import { lookupToolByName, ToolAction, toToolDefinition } from '../tool.js';
import { resolveTools, ToolAction, toToolDefinition } from '../tool.js';

export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
Expand Down Expand Up @@ -104,37 +105,15 @@ async function generate(
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
): Promise<GenerateResponseData> {
const model = (await registry.lookupAction(
`/model/${rawRequest.model}`
)) as ModelAction;
if (!model) {
throw new Error(`Model ${rawRequest.model} not found`);
}
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
if (model.__action.metadata?.model.stage === 'deprecated') {
console.warn(
logger.warn(
`${clc.bold(clc.yellow('Warning:'))} ` +
`Model '${model.__action.name}' is deprecated and may be removed in a future release.`
);
}

let tools: ToolAction[] | undefined;
if (rawRequest.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
throw new Error(
`Model ${rawRequest.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
);
}
tools = await Promise.all(
rawRequest.tools.map(async (toolRef) => {
if (typeof toolRef === 'string') {
return lookupToolByName(registry, toolRef as string);
} else if (toolRef.name) {
return lookupToolByName(registry, toolRef.name);
}
throw `Unable to resolve tool ${JSON.stringify(toolRef)}`;
})
);
}
const tools = await resolveTools(registry, rawRequest.tools);

const resolvedFormat = rawRequest.output?.format
? await resolveFormat(registry, rawRequest.output?.format)
Expand Down
52 changes: 52 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import {
Action,
defineAction,
GenkitError,
getStreamingCallback,
Middleware,
StreamingCallback,
Expand Down Expand Up @@ -479,3 +480,54 @@ function getPartCounts(parts: Part[]): PartCounts {
export type ModelArgument<
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
> = ModelAction<CustomOptions> | ModelReference<CustomOptions> | string;

export interface ResolvedModel<
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> {
modelAction: ModelAction;
config?: z.infer<CustomOptions>;
version?: string;
}

export async function resolveModel<C extends z.ZodTypeAny = z.ZodTypeAny>(
registry: Registry,
model: ModelArgument<C> | undefined
): Promise<ResolvedModel<C>> {
let out: ResolvedModel<C>;
let modelId: string;

if (!model) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'Must supply a `model` to `generate()` calls.',
});
}
if (typeof model === 'string') {
modelId = model;
out = { modelAction: await registry.lookupAction(`/model/${model}`) };
} else if (model.hasOwnProperty('__action')) {
modelId = (model as ModelAction).__action.name;
out = { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
modelId = ref.name;
out = {
modelAction: (await registry.lookupAction(
`/model/${ref.name}`
)) as ModelAction,
config: {
...ref.config,
},
version: ref.version,
};
}

if (!out.modelAction) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Model ${modelId} not found`,
});
}

return out;
}
6 changes: 5 additions & 1 deletion js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ export function asTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export async function resolveTools<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(registry: Registry, tools: ToolArgument[] = []): Promise<ToolAction[]> {
>(registry: Registry, tools?: ToolArgument[]): Promise<ToolAction[]> {
if (!tools || tools.length === 0) {
return [];
}

return await Promise.all(
tools.map(async (ref): Promise<ToolAction> => {
if (typeof ref === 'string') {
Expand Down
54 changes: 54 additions & 0 deletions js/testapps/format-tester/src/tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
import { genkit, z } from 'genkit';

const ai = genkit({ plugins: [googleAI()], model: gemini15Flash });

const lookupUsers = ai.defineTool(
{
name: 'lookupUsers',
description: 'use this tool to list users',
outputSchema: z.array(z.object({ name: z.string(), id: z.number() })),
},
async () => [
{ id: 123, name: 'Michael Bleigh' },
{ id: 456, name: 'Pavel Jbanov' },
{ id: 789, name: 'Chris Gill' },
{ id: 1122, name: 'Marissa Christy' },
]
);

async function main() {
const { stream } = await ai.generateStream({
prompt:
'use the lookupUsers tool and generate silly nicknames for each, then generate 50 fake users in the same format. return a JSON array.',
output: {
format: 'json',
schema: z.array(
z.object({ id: z.number(), name: z.string(), nickname: z.string() })
),
},
tools: [lookupUsers],
});

for await (const chunk of stream) {
console.log('raw:', chunk);
console.log('output:', chunk.output);
}
}
main();

0 comments on commit ae2f1e8

Please sign in to comment.