From ad365cf4e287b2cf042d84f9b0d2b111ff2ffbd7 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 29 Apr 2025 09:04:47 -0400 Subject: [PATCH 1/2] feat(js/plugins/{googleai,vertexai}): implemented dynamic model listing for googleai, and new model lookup API for googleai & vertexai --- js/ai/src/embedder.ts | 30 +++- js/ai/src/index.ts | 3 + js/ai/src/model.ts | 29 ++++ js/genkit/src/common.ts | 3 + js/plugins/googleai/src/gemini.ts | 8 +- js/plugins/googleai/src/index.ts | 174 +++++++++++++++++++---- js/plugins/googleai/src/list-models.ts | 60 ++++++++ js/plugins/googleai/tests/gemini_test.ts | 9 +- js/plugins/vertexai/src/index.ts | 103 +++++++++++--- js/testapps/flow-simple-ai/src/index.ts | 41 +++++- 10 files changed, 404 insertions(+), 56 deletions(-) create mode 100644 js/plugins/googleai/src/list-models.ts diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index fc84ac694f..7aa86d7dd3 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Action, defineAction, z } from '@genkit-ai/core'; +import { Action, ActionMetadata, defineAction, z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; @@ -293,3 +293,31 @@ export function embedderRef< ): EmbedderReference { return { ...options }; } + +/** + * Packages embedder information into ActionMetadata object. + */ +export function embedderActionMetadata({ + name, + info, + configSchema, +}: { + name: string; + info?: EmbedderInfo; + configSchema?: z.ZodTypeAny; +}): ActionMetadata { + return { + actionType: 'embedder', + name: name, + inputJsonSchema: toJsonSchema({ schema: EmbedRequestSchema }), + outputJsonSchema: toJsonSchema({ schema: EmbedResponseSchema }), + metadata: { + embedder: { + ...info, + customOptions: configSchema + ? toJsonSchema({ schema: configSchema }) + : undefined, + }, + }, + } as ActionMetadata; +} diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index f68a2941eb..48ed7e1fd2 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -23,6 +23,7 @@ export { } from './document.js'; export { embed, + embedderActionMetadata, embedderRef, type EmbedderAction, type EmbedderArgument, @@ -65,6 +66,8 @@ export { ModelResponseSchema, PartSchema, RoleSchema, + modelActionMetadata, + modelRef, type GenerateRequest, type GenerateRequestData, type GenerateResponseChunkData, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 3facb201bb..b731b5347f 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -16,6 +16,7 @@ import { Action, + ActionMetadata, defineAction, GenkitError, getStreamingCallback, @@ -507,6 +508,34 @@ export interface ModelReference { withVersion(version: string): ModelReference; } +/** + * Packages model information into ActionMetadata object. + */ +export function modelActionMetadata({ + name, + info, + configSchema, +}: { + name: string; + info?: ModelInfo; + configSchema?: z.ZodTypeAny; +}): ActionMetadata { + return { + actionType: 'model', + name: name, + inputJsonSchema: toJsonSchema({ schema: GenerateRequestSchema }), + outputJsonSchema: toJsonSchema({ schema: GenerateResponseSchema }), + metadata: { + model: { + ...info, + customOptions: configSchema + ? toJsonSchema({ schema: configSchema }) + : undefined, + }, + }, + } as ActionMetadata; +} + /** Cretes a model reference. */ export function modelRef< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index aabebc562b..83ed928604 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -32,9 +32,12 @@ export { RoleSchema, ToolCallSchema, ToolSchema, + embedderActionMetadata, embedderRef, evaluatorRef, indexerRef, + modelActionMetadata, + modelRef, rerankerRef, retrieverRef, type DocumentData, diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 3b287f5dbc..750db37303 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -821,7 +821,7 @@ export function defineGoogleAIModel({ : name; const model: ModelReference = - SUPPORTED_GEMINI_MODELS[name] ?? + SUPPORTED_GEMINI_MODELS[apiModelName] ?? modelRef({ name: `googleai/${apiModelName}`, info: { @@ -839,7 +839,7 @@ export function defineGoogleAIModel({ }); const middleware: ModelMiddleware[] = []; - if (SUPPORTED_V1_MODELS[name]) { + if (SUPPORTED_V1_MODELS[apiModelName]) { middleware.push(simulateSystemPrompt()); } if (model.info?.supports?.media) { @@ -896,7 +896,7 @@ export function defineGoogleAIModel({ // systemInstructions to be provided as a separate input. The first // message detected with role=system will be used for systemInstructions. let systemInstruction: GeminiMessage | undefined = undefined; - if (SUPPORTED_V15_MODELS[name]) { + if (SUPPORTED_V15_MODELS[apiModelName]) { const systemMessage = messages.find((m) => m.role === 'system'); if (systemMessage) { messages.splice(messages.indexOf(systemMessage), 1); @@ -982,7 +982,7 @@ export function defineGoogleAIModel({ } as StartChatParams; const modelVersion = (request.config?.version || model.version || - name) as string; + apiModelName) as string; const cacheConfigDetails = extractCacheConfig(request); const { chatRequest: updatedChatRequest, cache } = diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index 09d89629a9..90094f2161 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -14,19 +14,30 @@ * limitations under the License. */ -import { Genkit, ModelReference } from 'genkit'; +import { + ActionMetadata, + embedderActionMetadata, + embedderRef, + EmbedderReference, + Genkit, + modelActionMetadata, + ModelReference, + z, +} from 'genkit'; +import { logger } from 'genkit/logging'; +import { modelRef } from 'genkit/model'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { ActionType } from 'genkit/registry'; +import { getApiKeyFromEnvVar } from './common.js'; import { - SUPPORTED_MODELS as EMBEDDER_MODELS, defineGoogleAIEmbedder, + SUPPORTED_MODELS as EMBEDDER_MODELS, + GeminiEmbeddingConfig, + GeminiEmbeddingConfigSchema, textEmbedding004, textEmbeddingGecko001, } from './embedder.js'; import { - GeminiConfigSchema, - SUPPORTED_V15_MODELS, - SUPPORTED_V1_MODELS, defineGoogleAIModel, gemini, gemini10Pro, @@ -40,9 +51,13 @@ import { gemini25FlashPreview0417, gemini25ProExp0325, gemini25ProPreview0325, + GeminiConfigSchema, + SUPPORTED_V15_MODELS, + SUPPORTED_V1_MODELS, type GeminiConfig, type GeminiVersionString, } from './gemini.js'; +import { listModels } from './list-models.js'; export { gemini, gemini10Pro, @@ -161,48 +176,159 @@ async function resolver( actionName: string, options?: PluginOptions ) { - // TODO: also support other actions like 'embedder' switch (actionType) { case 'model': - await resolveModel(ai, actionName, options); + resolveModel(ai, actionName, options); + break; + case 'embedder': + resolveEmbedder(ai, actionName, options); break; default: // no-op } } -async function resolveModel( +function resolveModel(ai: Genkit, actionName: string, options?: PluginOptions) { + const modelRef = gemini(actionName); + defineGoogleAIModel({ + ai, + name: modelRef.name, + apiKey: options?.apiKey, + baseUrl: options?.baseUrl, + info: { + ...modelRef.info, + label: `Google AI - ${actionName}`, + }, + debugTraces: options?.experimental_debugTraces, + }); +} + +function resolveEmbedder( ai: Genkit, actionName: string, options?: PluginOptions ) { - if (actionName.includes('gemini')) { - const modelRef = gemini(actionName); - defineGoogleAIModel({ - ai, - name: modelRef.name, - apiKey: options?.apiKey, - baseUrl: options?.baseUrl, - info: { - ...modelRef.info, - label: `Google AI - ${actionName}`, - }, - debugTraces: options?.experimental_debugTraces, - }); + defineGoogleAIEmbedder(ai, `googleai/${actionName}`, { + apiKey: options?.apiKey, + }); +} + +async function listActions(options?: PluginOptions): Promise { + const apiKey = options?.apiKey || getApiKeyFromEnvVar(); + if (!apiKey) { + // If API key is not configured we don't want to error, just return empty. + // In practice it will never actually reach this point without the API key, + // plugin initializer will fail before that. + logger.error( + 'Pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.' + ); + return []; } - // TODO: Support other models + + const models = await listModels( + options?.baseUrl || 'https://generativelanguage.googleapis.com', + apiKey + ); + + return [ + // Models + ...models + .filter((m) => m.supportedGenerationMethods.includes('generateContent')) + // Filter out deprecated + .filter((m) => !m.description || !m.description.includes('deprecated')) + .map((m) => { + const ref = gemini( + m.name.startsWith('models/') + ? m.name.substring('models/'.length) + : m.name + ); + + return modelActionMetadata({ + name: ref.name, + info: ref.info, + configSchema: GeminiConfigSchema, + }); + }), + // Embedders + ...models + .filter((m) => m.supportedGenerationMethods.includes('embedContent')) + // Filter out deprecated + .filter((m) => !m.description || !m.description.includes('deprecated')) + .map((m) => { + const name = + 'googleai/' + + (m.name.startsWith('models/') + ? m.name.substring('models/'.length) + : m.name); + + return embedderActionMetadata({ + name, + configSchema: GeminiEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Google Gen AI - ${name}`, + supports: { + input: ['text'], + }, + }, + }); + }), + ]; } /** * Google Gemini Developer API plugin. */ -export function googleAI(options?: PluginOptions): GenkitPlugin { +export function googleAIPlugin(options?: PluginOptions): GenkitPlugin { + let listActionsCache; return genkitPlugin( 'googleai', async (ai: Genkit) => await initializer(ai, options), async (ai: Genkit, actionType: ActionType, actionName: string) => - await resolver(ai, actionType, actionName, options) + await resolver(ai, actionType, actionName, options), + async () => { + if (listActionsCache) return listActionsCache; + listActionsCache = await listActions(options); + return listActionsCache; + } ); } +export type GoogleAIPlugin = { + (params?: PluginOptions): GenkitPlugin; + model( + name: GeminiVersionString, + config?: z.infer + ): ModelReference; + embedder( + name: string, + config?: GeminiEmbeddingConfig + ): EmbedderReference; +}; + +/** + * Google Gemini Developer API plugin. + */ +export const googleAI = googleAIPlugin as GoogleAIPlugin; +googleAI.model = ( + name: GeminiVersionString, + config?: GeminiConfig +): ModelReference => { + return modelRef({ + name: `googleai/${name}`, + config, + configSchema: GeminiConfigSchema, + }); +}; +googleAI.embedder = ( + name: string, + config?: GeminiEmbeddingConfig +): EmbedderReference => { + return embedderRef({ + name: `googleai/${name}`, + config, + configSchema: GeminiEmbeddingConfigSchema, + }); +}; + export default googleAI; diff --git a/js/plugins/googleai/src/list-models.ts b/js/plugins/googleai/src/list-models.ts new file mode 100644 index 0000000000..adf8711bba --- /dev/null +++ b/js/plugins/googleai/src/list-models.ts @@ -0,0 +1,60 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Gemini model definition +export interface Model { + name: string; + baseModelId: string; + version: string; + displayName: string; + description: string; + inputTokenLimit: number; + outputTokenLimit: number; + supportedGenerationMethods: string[]; + temperature: number; + maxTemperature: number; + topP: number; + topK: number; +} + +// Gemini list models response +interface ListModelsResponse { + models: Model[]; + nextPageToken?: string; +} + +/** + * List Gemini models by making an RPC call to the API. + * + * https://ai.google.dev/api/models#method:-models.list + */ +export async function listModels( + baseUrl: string, + apiKey: string +): Promise { + // We call the gemini list local models api: + const res = await fetch( + `${baseUrl}/v1beta/models?pageSize=1000&key=${apiKey}`, + { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + } + ); + const modelResponse = JSON.parse(await res.text()) as ListModelsResponse; + return modelResponse.models; +} diff --git a/js/plugins/googleai/tests/gemini_test.ts b/js/plugins/googleai/tests/gemini_test.ts index 45d2ad8732..89d916b759 100644 --- a/js/plugins/googleai/tests/gemini_test.ts +++ b/js/plugins/googleai/tests/gemini_test.ts @@ -419,13 +419,16 @@ describe('plugin', () => { describe('plugin - no env', () => { it('should throw when registering models with no apiKey and no env', async () => { + delete process.env.GEMINI_API_KEY; + delete process.env.GOOGLE_API_KEY; + delete process.env.GOOGLE_GENAI_API_KEY; const ai = genkit({ plugins: [googleAI()] }); - assert.rejects(ai.registry.initializeAllPlugins()); + await assert.rejects(ai.registry.initializeAllPlugins()); }); - it('should not throw when registering models with {apiKey: false} and no env', () => { + it('should not throw when registering models with {apiKey: false} and no env', async () => { const ai = genkit({ plugins: [googleAI({ apiKey: false })] }); - assert.doesNotReject(ai.registry.initializeAllPlugins()); + await assert.doesNotReject(ai.registry.initializeAllPlugins()); }); }); diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index aba1ecb1cd..42c8187af8 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -20,23 +20,31 @@ * @module / */ -import { Genkit } from 'genkit'; +import { + embedderRef, + EmbedderReference, + Genkit, + modelRef, + ModelReference, + z, +} from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { ActionType } from 'genkit/registry'; import { getDerivedParams } from './common/index.js'; import { PluginOptions } from './common/types.js'; import { - SUPPORTED_EMBEDDER_MODELS, defineVertexAIEmbedder, multimodalEmbedding001, + SUPPORTED_EMBEDDER_MODELS, textEmbedding004, textEmbedding005, textEmbeddingGecko003, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, + VertexEmbeddingConfig, + VertexEmbeddingConfigSchema, } from './embedder.js'; import { - SUPPORTED_GEMINI_MODELS, defineGeminiKnownModel, defineGeminiModel, gemini, @@ -51,14 +59,17 @@ import { gemini25FlashPreview0417, gemini25ProExp0325, gemini25ProPreview0325, + GeminiConfigSchema, + SUPPORTED_GEMINI_MODELS, type GeminiConfig, + type GeminiVersionString, } from './gemini.js'; import { - SUPPORTED_IMAGEN_MODELS, imagen2, imagen3, imagen3Fast, imagenModel, + SUPPORTED_IMAGEN_MODELS, } from './imagen.js'; export { type PluginOptions } from './common/types.js'; export { @@ -74,6 +85,7 @@ export { gemini25FlashPreview0417, gemini25ProExp0325, gemini25ProPreview0325, + GeminiConfigSchema, imagen2, imagen3, imagen3Fast, @@ -84,6 +96,7 @@ export { textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, type GeminiConfig, + type GeminiVersionString, }; async function initializer(ai: Genkit, options?: PluginOptions) { @@ -145,6 +158,9 @@ async function resolver( case 'model': await resolveModel(ai, actionName, options); break; + case 'embedder': + await resolveEmbedder(ai, actionName, options); + break; default: // no-op } @@ -157,28 +173,35 @@ async function resolveModel( ) { const { projectId, location, vertexClientFactory } = await getDerivedParams(options); - if (actionName.includes('gemini')) { - const modelRef = gemini(actionName); - defineGeminiModel({ - ai, - modelName: modelRef.name, - version: actionName, - modelInfo: modelRef.info, - vertexClientFactory, - options: { - projectId, - location, - }, - debugTraces: options?.experimental_debugTraces, - }); - } - // TODO: Support other models + const modelRef = gemini(actionName); + defineGeminiModel({ + ai, + modelName: modelRef.name, + version: actionName, + modelInfo: modelRef.info, + vertexClientFactory, + options: { + projectId, + location, + }, + debugTraces: options?.experimental_debugTraces, + }); +} + +async function resolveEmbedder( + ai: Genkit, + actionName: string, + options?: PluginOptions +) { + const { projectId, location, authClient } = await getDerivedParams(options); + + defineVertexAIEmbedder(ai, actionName, authClient, { projectId, location }); } /** * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. */ -export function vertexAI(options?: PluginOptions): GenkitPlugin { +function vertexAIPlugin(options?: PluginOptions): GenkitPlugin { return genkitPlugin( 'vertexai', async (ai: Genkit) => await initializer(ai, options), @@ -187,4 +210,42 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { ); } +export type VertexAIPlugin = { + (params?: PluginOptions): GenkitPlugin; + model( + name: GeminiVersionString, + config?: z.infer + ): ModelReference; + embedder( + name: string, + config?: VertexEmbeddingConfig + ): EmbedderReference; +}; + +/** + * Google Cloud Vertex AI plugin for Genkit. + * Includes Gemini and Imagen models and text embedder. + */ +export const vertexAI = vertexAIPlugin as VertexAIPlugin; +vertexAI.model = ( + name: GeminiVersionString, + config?: GeminiConfig +): ModelReference => { + return modelRef({ + name: `vertexai/${name}`, + config, + configSchema: GeminiConfigSchema, + }); +}; +vertexAI.embedder = ( + name: string, + config?: VertexEmbeddingConfig +): EmbedderReference => { + return embedderRef({ + name: `vertexai/${name}`, + config, + configSchema: VertexEmbeddingConfigSchema, + }); +}; + export default vertexAI; diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 0adfe575e1..2ba827ed6b 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -139,16 +139,36 @@ export const drawPictureFlow = ai.defineFlow( } ); -export const streamFlow = ai.defineFlow( +export const streamFlowVertex = ai.defineFlow( { - name: 'streamFlow', + name: 'streamFlowVertex', inputSchema: z.string(), outputSchema: z.string(), streamSchema: z.string(), }, async (prompt, { sendChunk }) => { const { response, stream } = ai.generateStream({ - model: gemini15Flash, + model: vertexAI.model('gemini-2.0-flash-001', { temperature: 0.77 }), + prompt, + }); + + for await (const chunk of stream) { + sendChunk(chunk.content[0].text!); + } + + return (await response).text; + } +); +export const streamFlowGemini = ai.defineFlow( + { + name: 'streamFlowGemini', + inputSchema: z.string(), + outputSchema: z.string(), + streamSchema: z.string(), + }, + async (prompt, { sendChunk }) => { + const { response, stream } = ai.generateStream({ + model: googleAI.model('gemini-2.0-flash-001', { temperature: 0.77 }), prompt, }); @@ -868,3 +888,18 @@ ai.defineFlow('geminiEnum', async (thing) => { return output; }); + +ai.defineFlow('embedders-tester', async () => { + console.log( + await ai.embed({ + content: 'hello world', + embedder: googleAI.embedder('text-embedding-004'), + }) + ); + console.log( + await ai.embed({ + content: 'hello world', + embedder: vertexAI.embedder('text-embedding-004'), + }) + ); +}); From c4cb6acc82640d764db34ed24fa27e24dbd4d0f5 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 29 Apr 2025 11:14:47 -0400 Subject: [PATCH 2/2] default to multimodal embedding --- js/plugins/googleai/src/embedder.ts | 2 +- js/plugins/vertexai/src/embedder.ts | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/js/plugins/googleai/src/embedder.ts b/js/plugins/googleai/src/embedder.ts index 48b840ccce..9921031c12 100644 --- a/js/plugins/googleai/src/embedder.ts +++ b/js/plugins/googleai/src/embedder.ts @@ -110,7 +110,7 @@ export function defineGoogleAIEmbedder( dimensions: 768, label: `Google AI - ${name}`, supports: { - input: ['text'], + input: ['text', 'image', 'video'], }, }, }); diff --git a/js/plugins/vertexai/src/embedder.ts b/js/plugins/vertexai/src/embedder.ts index 032da1c8b9..666db29792 100644 --- a/js/plugins/vertexai/src/embedder.ts +++ b/js/plugins/vertexai/src/embedder.ts @@ -247,7 +247,19 @@ export function defineVertexAIEmbedder( client: GoogleAuth, options: PluginOptions ): EmbedderAction { - const embedder = SUPPORTED_EMBEDDER_MODELS[name]; + const embedder = + SUPPORTED_EMBEDDER_MODELS[name] ?? + embedderRef({ + name: name, + configSchema: VertexEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Vertex AI - ${name}`, + supports: { + input: ['text', 'image', 'video'], + }, + }, + }); const predictClients: Record< string, PredictClient