Skip to content

feat(js/plugins/{googleai,vertexai}): implemented dynamic model listing for googleai, and new model lookup API for googleai & vertexai #2839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion js/ai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { Action, defineAction, z } from '@genkit-ai/core';
import { Action, ActionMetadata, defineAction, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
Expand Down Expand Up @@ -293,3 +293,31 @@ export function embedderRef<
): EmbedderReference<CustomOptionsSchema> {
return { ...options };
}

/**
* Packages embedder information into ActionMetadata object.
*/
export function embedderActionMetadata({
name,
info,
configSchema,
}: {
name: string;
info?: EmbedderInfo;
configSchema?: z.ZodTypeAny;
}): ActionMetadata {
return {
actionType: 'embedder',
name: name,
inputJsonSchema: toJsonSchema({ schema: EmbedRequestSchema }),
outputJsonSchema: toJsonSchema({ schema: EmbedResponseSchema }),
metadata: {
embedder: {
...info,
customOptions: configSchema
? toJsonSchema({ schema: configSchema })
: undefined,
},
},
} as ActionMetadata;
}
3 changes: 3 additions & 0 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export {
} from './document.js';
export {
embed,
embedderActionMetadata,
embedderRef,
type EmbedderAction,
type EmbedderArgument,
Expand Down Expand Up @@ -65,6 +66,8 @@ export {
ModelResponseSchema,
PartSchema,
RoleSchema,
modelActionMetadata,
modelRef,
type GenerateRequest,
type GenerateRequestData,
type GenerateResponseChunkData,
Expand Down
29 changes: 29 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
Action,
ActionMetadata,
defineAction,
GenkitError,
getStreamingCallback,
Expand Down Expand Up @@ -507,6 +508,34 @@ export interface ModelReference<CustomOptions extends z.ZodTypeAny> {
withVersion(version: string): ModelReference<CustomOptions>;
}

/**
* Packages model information into ActionMetadata object.
*/
export function modelActionMetadata({
name,
info,
configSchema,
}: {
name: string;
info?: ModelInfo;
configSchema?: z.ZodTypeAny;
}): ActionMetadata {
return {
actionType: 'model',
name: name,
inputJsonSchema: toJsonSchema({ schema: GenerateRequestSchema }),
outputJsonSchema: toJsonSchema({ schema: GenerateResponseSchema }),
metadata: {
model: {
...info,
customOptions: configSchema
? toJsonSchema({ schema: configSchema })
: undefined,
},
},
} as ActionMetadata;
}

/** Cretes a model reference. */
export function modelRef<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
Expand Down
3 changes: 3 additions & 0 deletions js/genkit/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ export {
RoleSchema,
ToolCallSchema,
ToolSchema,
embedderActionMetadata,
embedderRef,
evaluatorRef,
indexerRef,
modelActionMetadata,
modelRef,
rerankerRef,
retrieverRef,
type DocumentData,
Expand Down
2 changes: 1 addition & 1 deletion js/plugins/googleai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export function defineGoogleAIEmbedder(
dimensions: 768,
label: `Google AI - ${name}`,
supports: {
input: ['text'],
input: ['text', 'image', 'video'],
},
},
});
Expand Down
8 changes: 4 additions & 4 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ export function defineGoogleAIModel({
: name;

const model: ModelReference<z.ZodTypeAny> =
SUPPORTED_GEMINI_MODELS[name] ??
SUPPORTED_GEMINI_MODELS[apiModelName] ??
modelRef({
name: `googleai/${apiModelName}`,
info: {
Expand All @@ -839,7 +839,7 @@ export function defineGoogleAIModel({
});

const middleware: ModelMiddleware[] = [];
if (SUPPORTED_V1_MODELS[name]) {
if (SUPPORTED_V1_MODELS[apiModelName]) {
middleware.push(simulateSystemPrompt());
}
if (model.info?.supports?.media) {
Expand Down Expand Up @@ -896,7 +896,7 @@ export function defineGoogleAIModel({
// systemInstructions to be provided as a separate input. The first
// message detected with role=system will be used for systemInstructions.
let systemInstruction: GeminiMessage | undefined = undefined;
if (SUPPORTED_V15_MODELS[name]) {
if (SUPPORTED_V15_MODELS[apiModelName]) {
const systemMessage = messages.find((m) => m.role === 'system');
if (systemMessage) {
messages.splice(messages.indexOf(systemMessage), 1);
Expand Down Expand Up @@ -982,7 +982,7 @@ export function defineGoogleAIModel({
} as StartChatParams;
const modelVersion = (request.config?.version ||
model.version ||
name) as string;
apiModelName) as string;
const cacheConfigDetails = extractCacheConfig(request);

const { chatRequest: updatedChatRequest, cache } =
Expand Down
174 changes: 150 additions & 24 deletions js/plugins/googleai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,30 @@
* limitations under the License.
*/

import { Genkit, ModelReference } from 'genkit';
import {
ActionMetadata,
embedderActionMetadata,
embedderRef,
EmbedderReference,
Genkit,
modelActionMetadata,
ModelReference,
z,
} from 'genkit';
import { logger } from 'genkit/logging';
import { modelRef } from 'genkit/model';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import { getApiKeyFromEnvVar } from './common.js';
import {
SUPPORTED_MODELS as EMBEDDER_MODELS,
defineGoogleAIEmbedder,
SUPPORTED_MODELS as EMBEDDER_MODELS,
GeminiEmbeddingConfig,
GeminiEmbeddingConfigSchema,
textEmbedding004,
textEmbeddingGecko001,
} from './embedder.js';
import {
GeminiConfigSchema,
SUPPORTED_V15_MODELS,
SUPPORTED_V1_MODELS,
defineGoogleAIModel,
gemini,
gemini10Pro,
Expand All @@ -40,9 +51,13 @@ import {
gemini25FlashPreview0417,
gemini25ProExp0325,
gemini25ProPreview0325,
GeminiConfigSchema,
SUPPORTED_V15_MODELS,
SUPPORTED_V1_MODELS,
type GeminiConfig,
type GeminiVersionString,
} from './gemini.js';
import { listModels } from './list-models.js';
export {
gemini,
gemini10Pro,
Expand Down Expand Up @@ -161,48 +176,159 @@ async function resolver(
actionName: string,
options?: PluginOptions
) {
// TODO: also support other actions like 'embedder'
switch (actionType) {
case 'model':
await resolveModel(ai, actionName, options);
resolveModel(ai, actionName, options);
break;
case 'embedder':
resolveEmbedder(ai, actionName, options);
break;
default:
// no-op
}
}

async function resolveModel(
function resolveModel(ai: Genkit, actionName: string, options?: PluginOptions) {
const modelRef = gemini(actionName);
defineGoogleAIModel({
ai,
name: modelRef.name,
apiKey: options?.apiKey,
baseUrl: options?.baseUrl,
info: {
...modelRef.info,
label: `Google AI - ${actionName}`,
},
debugTraces: options?.experimental_debugTraces,
});
}

function resolveEmbedder(
ai: Genkit,
actionName: string,
options?: PluginOptions
) {
if (actionName.includes('gemini')) {
const modelRef = gemini(actionName);
defineGoogleAIModel({
ai,
name: modelRef.name,
apiKey: options?.apiKey,
baseUrl: options?.baseUrl,
info: {
...modelRef.info,
label: `Google AI - ${actionName}`,
},
debugTraces: options?.experimental_debugTraces,
});
defineGoogleAIEmbedder(ai, `googleai/${actionName}`, {
apiKey: options?.apiKey,
});
}

async function listActions(options?: PluginOptions): Promise<ActionMetadata[]> {
const apiKey = options?.apiKey || getApiKeyFromEnvVar();
if (!apiKey) {
// If API key is not configured we don't want to error, just return empty.
// In practice it will never actually reach this point without the API key,
// plugin initializer will fail before that.
logger.error(
'Pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.'
);
return [];
}
// TODO: Support other models

const models = await listModels(
options?.baseUrl || 'https://generativelanguage.googleapis.com',
apiKey
);

return [
// Models
...models
.filter((m) => m.supportedGenerationMethods.includes('generateContent'))
// Filter out deprecated
.filter((m) => !m.description || !m.description.includes('deprecated'))
.map((m) => {
const ref = gemini(
m.name.startsWith('models/')
? m.name.substring('models/'.length)
: m.name
);

return modelActionMetadata({
name: ref.name,
info: ref.info,
configSchema: GeminiConfigSchema,
});
}),
// Embedders
...models
.filter((m) => m.supportedGenerationMethods.includes('embedContent'))
// Filter out deprecated
.filter((m) => !m.description || !m.description.includes('deprecated'))
.map((m) => {
const name =
'googleai/' +
(m.name.startsWith('models/')
? m.name.substring('models/'.length)
: m.name);

return embedderActionMetadata({
name,
configSchema: GeminiEmbeddingConfigSchema,
info: {
dimensions: 768,
label: `Google Gen AI - ${name}`,
supports: {
input: ['text'],
},
},
});
}),
];
}

/**
* Google Gemini Developer API plugin.
*/
export function googleAI(options?: PluginOptions): GenkitPlugin {
export function googleAIPlugin(options?: PluginOptions): GenkitPlugin {
let listActionsCache;
return genkitPlugin(
'googleai',
async (ai: Genkit) => await initializer(ai, options),
async (ai: Genkit, actionType: ActionType, actionName: string) =>
await resolver(ai, actionType, actionName, options)
await resolver(ai, actionType, actionName, options),
async () => {
if (listActionsCache) return listActionsCache;
listActionsCache = await listActions(options);
return listActionsCache;
}
);
}

export type GoogleAIPlugin = {
(params?: PluginOptions): GenkitPlugin;
model(
name: GeminiVersionString,
config?: z.infer<typeof GeminiConfigSchema>
): ModelReference<typeof GeminiConfigSchema>;
embedder(
name: string,
config?: GeminiEmbeddingConfig
): EmbedderReference<typeof GeminiEmbeddingConfigSchema>;
};

/**
* Google Gemini Developer API plugin.
*/
export const googleAI = googleAIPlugin as GoogleAIPlugin;
googleAI.model = (
name: GeminiVersionString,
config?: GeminiConfig
): ModelReference<typeof GeminiConfigSchema> => {
return modelRef({
name: `googleai/${name}`,
config,
configSchema: GeminiConfigSchema,
});
};
googleAI.embedder = (
name: string,
config?: GeminiEmbeddingConfig
): EmbedderReference<typeof GeminiEmbeddingConfigSchema> => {
return embedderRef({
name: `googleai/${name}`,
config,
configSchema: GeminiEmbeddingConfigSchema,
});
};

export default googleAI;
Loading