From 180f091d6bd8704ce1f1fa4553e417bf0b19c3cb Mon Sep 17 00:00:00 2001 From: Daniel La Rocque Date: Thu, 2 Jan 2025 12:38:21 -0600 Subject: [PATCH] Introduce VertexAIModel base class, add documentation, and respond to other comments --- common/api-review/vertexai.api.md | 145 ++++------------- packages/vertexai/src/api.test.ts | 2 +- packages/vertexai/src/api.ts | 10 +- .../src/models/generative-model.test.ts | 22 --- .../vertexai/src/models/generative-model.ts | 50 +----- .../vertexai/src/models/imagen-model.test.ts | 66 +------- packages/vertexai/src/models/imagen-model.ts | 151 +++++++---------- packages/vertexai/src/models/index.ts | 3 + .../src/models/vertexai-model.test.ts | 103 ++++++++++++ .../vertexai/src/models/vertexai-model.ts | 94 +++++++++++ .../src/requests/imagen-image-format.ts | 48 ++++++ .../src/requests/request-helpers.test.ts | 2 +- .../vertexai/src/requests/request-helpers.ts | 12 +- .../vertexai/src/requests/response-helpers.ts | 4 +- .../vertexai/src/types/imagen/internal.ts | 85 +++++++++- .../vertexai/src/types/imagen/requests.ts | 152 +++++++++--------- .../vertexai/src/types/imagen/responses.ts | 61 ++++--- packages/vertexai/src/types/responses.ts | 5 - 18 files changed, 552 insertions(+), 463 deletions(-) create mode 100644 packages/vertexai/src/models/index.ts create mode 100644 packages/vertexai/src/models/vertexai-model.test.ts create mode 100644 packages/vertexai/src/models/vertexai-model.ts create mode 100644 packages/vertexai/src/requests/imagen-image-format.ts diff --git a/common/api-review/vertexai.api.md b/common/api-review/vertexai.api.md index aa35a3f024f..5b5c00b122e 100644 --- a/common/api-review/vertexai.api.md +++ b/common/api-review/vertexai.api.md @@ -323,7 +323,7 @@ export interface GenerativeContentBlob { } // @public -export class GenerativeModel { +export class GenerativeModel extends VertexAIModel { constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions); countTokens(request: 
CountTokensRequest | string | Array): Promise; generateContent(request: GenerateContentRequest | string | Array): Promise; @@ -331,8 +331,6 @@ export class GenerativeModel { // (undocumented) generationConfig: GenerationConfig; // (undocumented) - model: string; - // (undocumented) requestOptions?: RequestOptions; // (undocumented) safetySettings: SafetySetting[]; @@ -432,77 +430,53 @@ export enum HarmSeverity { HARM_SEVERITY_NEGLIGIBLE = "HARM_SEVERITY_NEGLIGIBLE" } -// @public (undocumented) +// @public export enum ImagenAspectRatio { // (undocumented) - CLASSIC_LANDSCAPE = "4:3", + LANDSCAPE_16x9 = "16:9", // (undocumented) - CLASSIC_PORTRAIT = "3:4", + LANDSCAPE_3x4 = "3:4", // (undocumented) - PORTRAIT = "9:16", + PORTRAIT_4x3 = "4:3", // (undocumented) - SQUARE = "1:1", + PORTRAIT_9x16 = "9:16", // (undocumented) - WIDESCREEN = "16:9" + SQUARE = "1:1" } -// Warning: (ae-incompatible-release-tags) The symbol "ImagenGCSImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal -// // @public -export interface ImagenGCSImage extends ImagenImage { +export interface ImagenGCSImage { gcsURI: string; + mimeType: string; } -// @public (undocumented) +// @public export interface ImagenGCSImageResponse { - // (undocumented) filteredReason?: string; - // (undocumented) images: ImagenGCSImage[]; } -// @public (undocumented) +// @public export interface ImagenGenerationConfig { - // (undocumented) aspectRatio?: ImagenAspectRatio; - // (undocumented) negativePrompt?: string; - // (undocumented) numberOfImages?: number; } -// Warning: (ae-internal-missing-underscore) The name "ImagenImage" should be prefixed with an underscore because the declaration is marked as @internal -// -// @internal -export interface ImagenImage { - // (undocumented) - mimeType: string; -} - -// @public (undocumented) -export interface ImagenImageFormat { +// @public +export class ImagenImageFormat { // (undocumented) compressionQuality?: number; + static 
jpeg(compressionQuality: number): ImagenImageFormat; // (undocumented) mimeType: string; + static png(): ImagenImageFormat; } -// @public (undocumented) -export interface ImagenImageReponse { - // (undocumented) - filteredReason?: string; - // Warning: (ae-incompatible-release-tags) The symbol "images" is marked as @public, but its signature references "ImagenImage" which is marked as @internal - // - // (undocumented) - images: ImagenImage[]; -} - -// Warning: (ae-incompatible-release-tags) The symbol "ImagenInlineImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal -// // @public -export interface ImagenInlineImage extends ImagenImage { - // (undocumented) +export interface ImagenInlineImage { bytesBase64Encoded: string; + mimeType: string; } // @public @@ -512,63 +486,43 @@ export interface ImagenInlineImageResponse { } // @public -export class ImagenModel { +export class ImagenModel extends VertexAIModel { constructor(vertexAI: VertexAI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined); generateImages(prompt: string, imagenRequestOptions?: ImagenGenerationConfig): Promise; generateImagesGCS(prompt: string, gcsURI: string, imagenRequestOptions?: ImagenGenerationConfig): Promise; + readonly modelConfig: ImagenModelConfig; // (undocumented) - model: string; - } + readonly requestOptions?: RequestOptions | undefined; +} -// @public (undocumented) +// @public export interface ImagenModelConfig { - // (undocumented) addWatermark?: boolean; - // (undocumented) imageFormat?: ImagenImageFormat; - // (undocumented) safetySettings?: ImagenSafetySettings; } // @public export interface ImagenModelParams extends ImagenModelConfig { - // (undocumented) model: string; } -// @public (undocumented) +// @public export enum ImagenPersonFilterLevel { - // (undocumented) ALLOW_ADULT = "allow_adult", - // (undocumented) ALLOW_ALL = "allow_all", - // (undocumented) BLOCK_ALL = "dont_allow" } -// Warning: 
(ae-internal-missing-underscore) The name "ImagenRequestConfig" should be prefixed with an underscore because the declaration is marked as @internal -// -// @internal -export interface ImagenRequestConfig extends ImagenModelConfig, ImagenGenerationConfig { - // (undocumented) - gcsURI?: string; - // (undocumented) - prompt: string; -} - -// @public (undocumented) +// @public export enum ImagenSafetyFilterLevel { - // (undocumented) BLOCK_LOW_AND_ABOVE = "block_low_and_above", - // (undocumented) BLOCK_MEDIUM_AND_ABOVE = "block_medium_and_above", - // (undocumented) BLOCK_NONE = "block_none", - // (undocumented) BLOCK_ONLY_HIGH = "block_only_high" } -// @public (undocumented) +// @public export interface ImagenSafetySettings { personFilterLevel?: ImagenPersonFilterLevel; safetyFilterLevel?: ImagenSafetyFilterLevel; @@ -592,9 +546,6 @@ export class IntegerSchema extends Schema { constructor(schemaParams?: SchemaParams); } -// @public -export function jpeg(compressionQuality: number): ImagenImageFormat; - // @public export interface ModelParams extends BaseParams { // (undocumented) @@ -638,37 +589,9 @@ export interface ObjectSchemaInterface extends SchemaInterface { // @public export type Part = TextPart | InlineDataPart | FunctionCallPart | FunctionResponsePart | FileDataPart; -// @public -export function png(): ImagenImageFormat; - // @public export const POSSIBLE_ROLES: readonly ["user", "model", "function", "system"]; -// Warning: (ae-internal-missing-underscore) The name "PredictRequestBody" should be prefixed with an underscore because the declaration is marked as @internal -// -// @internal -export interface PredictRequestBody { - // (undocumented) - instances: [ - { - prompt: string; - } - ]; - // (undocumented) - parameters: { - sampleCount: number; - aspectRatio: string; - mimeType: string; - compressionQuality?: number; - negativePrompt?: string; - storageUri?: string; - addWatermark?: boolean; - safetyFilterLevel?: string; - personGeneration?: string; - 
includeRaiReason: boolean; - }; -} - // @public export interface PromptFeedback { // (undocumented) @@ -696,14 +619,6 @@ export interface RetrievedContextAttribution { // @public export type Role = (typeof POSSIBLE_ROLES)[number]; -// @public (undocumented) -export interface SafetyAttributes { - // (undocumented) - categories: string[]; - // (undocumented) - scores: number[]; -} - // @public export interface SafetyRating { // (undocumented) @@ -902,6 +817,16 @@ export const enum VertexAIErrorCode { RESPONSE_ERROR = "response-error" } +// @public +export class VertexAIModel { + // @internal + protected constructor(vertexAI: VertexAI, modelName: string); + // (undocumented) + protected _apiSettings: ApiSettings; + readonly model: string; + static normalizeModelName(modelName: string): string; +} + // @public export interface VertexAIOptions { // (undocumented) diff --git a/packages/vertexai/src/api.test.ts b/packages/vertexai/src/api.test.ts index c9432d2a7ea..c1b2635ce70 100644 --- a/packages/vertexai/src/api.test.ts +++ b/packages/vertexai/src/api.test.ts @@ -129,7 +129,7 @@ describe('Top level API', () => { ); } }); - it('getGenerativeModel gets an ImagenModel', () => { + it('getImagenModel gets an ImagenModel', () => { const genModel = getImagenModel(fakeVertexAI, { model: 'my-model' }); expect(genModel).to.be.an.instanceOf(ImagenModel); expect(genModel.model).to.equal('publishers/google/models/my-model'); diff --git a/packages/vertexai/src/api.ts b/packages/vertexai/src/api.ts index 07154356435..6ad61798c4f 100644 --- a/packages/vertexai/src/api.ts +++ b/packages/vertexai/src/api.ts @@ -28,16 +28,12 @@ import { VertexAIErrorCode } from './types'; import { VertexAIError } from './errors'; -import { GenerativeModel } from './models/generative-model'; -import { ImagenModel, jpeg, png } from './models/imagen-model'; +import { VertexAIModel, GenerativeModel, ImagenModel } from './models'; export { ChatSession } from './methods/chat-session'; export * from 
'./requests/schema-builder'; - -export { jpeg, png }; - -export { GenerativeModel, ImagenModel }; - +export { ImagenImageFormat } from './requests/imagen-image-format'; +export { VertexAIModel, GenerativeModel, ImagenModel }; export { VertexAIError }; declare module '@firebase/component' { diff --git a/packages/vertexai/src/models/generative-model.test.ts b/packages/vertexai/src/models/generative-model.test.ts index e03f39e8a83..c2dbdfac75c 100644 --- a/packages/vertexai/src/models/generative-model.test.ts +++ b/packages/vertexai/src/models/generative-model.test.ts @@ -37,28 +37,6 @@ const fakeVertexAI: VertexAI = { }; describe('GenerativeModel', () => { - it('handles plain model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); - expect(genModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles models/ prefixed model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'models/my-model' - }); - expect(genModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles full model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'publishers/google/models/my-model' - }); - expect(genModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles prefixed tuned model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'tunedModels/my-model' - }); - expect(genModel.model).to.equal('tunedModels/my-model'); - }); it('passes params through to generateContent', async () => { const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model', diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index e719529967c..b4cf464f025 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -33,10 +33,8 @@ import { SafetySetting, StartChatParams, Tool, - ToolConfig, - 
VertexAIErrorCode + ToolConfig } from '../types'; -import { VertexAIError } from '../errors'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; import { @@ -44,16 +42,13 @@ import { formatSystemInstruction } from '../requests/request-helpers'; import { VertexAI } from '../public-types'; -import { ApiSettings } from '../types/internal'; -import { VertexAIService } from '../service'; +import { VertexAIModel } from './vertexai-model'; /** * Class for generative model APIs. * @public */ -export class GenerativeModel { - private _apiSettings: ApiSettings; - model: string; +export class GenerativeModel extends VertexAIModel { generationConfig: GenerationConfig; safetySettings: SafetySetting[]; requestOptions?: RequestOptions; @@ -66,44 +61,7 @@ export class GenerativeModel { modelParams: ModelParams, requestOptions?: RequestOptions ) { - if (!vertexAI.app?.options?.apiKey) { - throw new VertexAIError( - VertexAIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.` - ); - } else if (!vertexAI.app?.options?.projectId) { - throw new VertexAIError( - VertexAIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. 
Firebase VertexAI requires this field to contain a valid project ID.` - ); - } else { - this._apiSettings = { - apiKey: vertexAI.app.options.apiKey, - project: vertexAI.app.options.projectId, - location: vertexAI.location - }; - if ((vertexAI as VertexAIService).appCheck) { - this._apiSettings.getAppCheckToken = () => - (vertexAI as VertexAIService).appCheck!.getToken(); - } - - if ((vertexAI as VertexAIService).auth) { - this._apiSettings.getAuthToken = () => - (vertexAI as VertexAIService).auth!.getToken(); - } - } - if (modelParams.model.includes('/')) { - if (modelParams.model.startsWith('models/')) { - // Add "publishers/google" if the user is only passing in 'models/model-name'. - this.model = `publishers/google/${modelParams.model}`; - } else { - // Any other custom format (e.g. tuned models) must be passed in correctly. - this.model = modelParams.model; - } - } else { - // If path is not included, assume it's a non-tuned model. - this.model = `publishers/google/models/${modelParams.model}`; - } + super(vertexAI, modelParams.model); this.generationConfig = modelParams.generationConfig || {}; this.safetySettings = modelParams.safetySettings || []; this.tools = modelParams.tools; diff --git a/packages/vertexai/src/models/imagen-model.test.ts b/packages/vertexai/src/models/imagen-model.test.ts index 909bff3dea2..a0bc857a53b 100644 --- a/packages/vertexai/src/models/imagen-model.test.ts +++ b/packages/vertexai/src/models/imagen-model.test.ts @@ -44,68 +44,6 @@ const fakeVertexAI: VertexAI = { }; describe('ImagenModel', () => { - it('handles plain model name', () => { - const imagenModel = new ImagenModel(fakeVertexAI, { model: 'my-model' }); - expect(imagenModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles models/ prefixed model name', () => { - const imagenModel = new ImagenModel(fakeVertexAI, { - model: 'models/my-model' - }); - expect(imagenModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles full model 
name', () => { - const imagenModel = new ImagenModel(fakeVertexAI, { - model: 'publishers/google/models/my-model' - }); - expect(imagenModel.model).to.equal('publishers/google/models/my-model'); - }); - it('handles prefixed tuned model name', () => { - const imagenModel = new ImagenModel(fakeVertexAI, { - model: 'tunedModels/my-model' - }); - expect(imagenModel.model).to.equal('tunedModels/my-model'); - }); - it('throws if not passed an api key', () => { - const fakeVertexAI: VertexAI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - projectId: 'my-project' - } - }, - location: 'us-central1' - }; - try { - new ImagenModel(fakeVertexAI, { - model: 'my-model' - }); - } catch (e) { - expect((e as VertexAIError).code).to.equal(VertexAIErrorCode.NO_API_KEY); - } - }); - it('throws if not passed a project ID', () => { - const fakeVertexAI: VertexAI = { - app: { - name: 'DEFAULT', - automaticDataCollectionEnabled: true, - options: { - apiKey: 'key' - } - }, - location: 'us-central1' - }; - try { - new ImagenModel(fakeVertexAI, { - model: 'my-model' - }); - } catch (e) { - expect((e as VertexAIError).code).to.equal( - VertexAIErrorCode.NO_PROJECT_ID - ); - } - }); it('generateImages makes a request to predict with default parameters', async () => { const imagenModel = new ImagenModel(fakeVertexAI, { model: 'my-model' @@ -190,7 +128,7 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt, { numberOfImages: 4, - aspectRatio: ImagenAspectRatio.WIDESCREEN, + aspectRatio: ImagenAspectRatio.LANDSCAPE_16x9, negativePrompt: 'do not hallucinate' }); expect(makeRequestStub).to.be.calledWith( @@ -237,7 +175,7 @@ describe('ImagenModel', () => { expect((e as VertexAIError).code).to.equal(VertexAIErrorCode.FETCH_ERROR); expect((e as VertexAIError).message).to.include('400'); expect((e as VertexAIError).message).to.include( - "Image generation failed with the following 
error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback. Support codes: 42876398" + "Image generation failed with the following error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback." ); } finally { restore(); diff --git a/packages/vertexai/src/models/imagen-model.ts b/packages/vertexai/src/models/imagen-model.ts index 17d1e1cfb2c..3434e1dd99d 100644 --- a/packages/vertexai/src/models/imagen-model.ts +++ b/packages/vertexai/src/models/imagen-model.ts @@ -15,102 +15,85 @@ * limitations under the License. */ -import { VertexAIError } from '../errors'; import { VertexAI } from '../public-types'; import { Task, makeRequest } from '../requests/request'; import { createPredictRequestBody } from '../requests/request-helpers'; import { handlePredictResponse } from '../requests/response-helpers'; -import { VertexAIService } from '../service'; import { ImagenGCSImage, ImagenGCSImageResponse, - ImagenImageFormat, ImagenGenerationConfig, ImagenInlineImage, RequestOptions, - VertexAIErrorCode, ImagenModelParams, ImagenInlineImageResponse, ImagenModelConfig } from '../types'; -import { ApiSettings } from '../types/internal'; +import { VertexAIModel } from './vertexai-model'; /** * Class for Imagen model APIs. + * + * This class provides methods for generating images using the Imagen model. + * You can generate images inline as base64-encoded strings, or directly to + * Google Cloud Storage (GCS). 
+ * + * @example + * ```javascript + * const imagen = new ImagenModel(vertexAI, { + * model: 'imagen-3.0-generate-001' + * }); + * + * const response = await imagen.generateImages('A photo of a cat'); + * console.log(response.images[0].bytesBase64Encoded); + * ``` + * * @public */ -export class ImagenModel { - model: string; - private _apiSettings: ApiSettings; - private modelConfig: ImagenModelConfig; +export class ImagenModel extends VertexAIModel { + /** + * Model-level configurations to use when using Imagen. + */ + readonly modelConfig: ImagenModelConfig; /** + * Constructs a new instance of the {@link ImagenModel} class. * - * @param vertexAI - * @param modelParams - * @param requestOptions + * @param vertexAI - An instance of the Vertex AI in Firebase SDK. + * @param modelParams - Parameters to use when making Imagen requests. + * @param requestOptions - Additional options to use when making requests. + * + * @throws If the `apiKey` or `projectId` fields are missing in your + * Firebase config. */ constructor( vertexAI: VertexAI, modelParams: ImagenModelParams, - private requestOptions?: RequestOptions + readonly requestOptions?: RequestOptions ) { const { model, ...modelConfig } = modelParams; + super(vertexAI, model); this.modelConfig = modelConfig; - if (model.includes('/')) { - if (model.startsWith('models/')) { - // Add "publishers/google" if the user is only passing in 'models/model-name'. - this.model = `publishers/google/${model}`; - } else { - // Any other custom format (e.g. tuned models) must be passed in correctly. - this.model = model; - } - } else { - // If path is not included, assume it's a non-tuned model. - this.model = `publishers/google/models/${model}`; - } - - if (!vertexAI.app?.options?.apiKey) { - throw new VertexAIError( - VertexAIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. 
Firebase VertexAI requires this field to contain a valid API key.` - ); - } else if (!vertexAI.app?.options?.projectId) { - throw new VertexAIError( - VertexAIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.` - ); - } else { - this._apiSettings = { - apiKey: vertexAI.app.options.apiKey, - project: vertexAI.app.options.projectId, - location: vertexAI.location - }; - if ((vertexAI as VertexAIService).appCheck) { - this._apiSettings.getAppCheckToken = () => - (vertexAI as VertexAIService).appCheck!.getToken(); - } - - if ((vertexAI as VertexAIService).auth) { - this._apiSettings.getAuthToken = () => - (vertexAI as VertexAIService).auth!.getToken(); - } - } } /** - * Generates images using the Imagen model and returns them as base64-encoded strings. + * Generates images using the Imagen model and returns them as + * base64-encoded strings. * - * @param prompt The text prompt used to generate the images. - * @param imagenRequestOptions Configuration options for the Imagen generation request. + * @param prompt - The text prompt used to generate the images. + * @param imagenRequestOptions - Configuration options for the Imagen + * generation request. * See {@link ImagenGenerationConfig}. - * @returns A promise that resolves to an {@link ImagenInlineImageResponse} object containing the generated images. + * @returns A promise that resolves to an {@link ImagenInlineImageResponse} + * object containing the generated images. * - * @throws If the request fails or if the prompt is blocked, throws a {@link VertexAIError}. + * @throws If the request to generate images fails. This happens if the + * prompt is blocked. * * @remarks - * If one or more images are filtered, the returned object will have a defined `filteredReason` property. - * If all images are filtered, the `images` array will be empty, and no error will be thrown. 
+ * If one or more images are filtered, the returned object will have a + * defined `filteredReason` property. If all images are filtered, the + * `images` array will be empty, and no error will be thrown. */ async generateImages( prompt: string, @@ -133,19 +116,23 @@ export class ImagenModel { } /** - * Generates images using the Imagen model and returns them as base64-encoded strings. + * Generates images to Google Cloud Storage (GCS) using the Imagen model. * - * @param prompt The text prompt used to generate the images. - * @param gcsURI The GCS URI where the images should be stored. - * @param imagenRequestOptions Configuration options for the Imagen generation request. - * See {@link ImagenGenerationConfig}. - * @returns A promise that resolves to an {@link ImagenGCSImageResponse} object containing the generated images. + * @param prompt - The text prompt used to generate the images. + * @param gcsURI - The GCS URI where the images should be stored. + * This should be a directory. For example, `gs://my-bucket/my-directory/`. + * @param imagenRequestOptions - Configuration options for the Imagen + * generation request. See {@link ImagenGenerationConfig}. + * @returns A promise that resolves to an {@link ImagenGCSImageResponse} + * object containing the URLs of the generated images. * - * @throws If the request fails or if the prompt is blocked, throws a {@link VertexAIError}. + * @throws If the request to generate images fails. This happens if + * the prompt is blocked. * * @remarks - * If one or more images are filtered, the returned object will have a defined `filteredReason` property. - * If all images are filtered, the `images` array will be empty, and no error will be thrown. + * If one or more images are filtered due to safety reasons, the returned object + * will have a defined `filteredReason` property. If all images are filtered, + * the `images` array will be empty, and no error will be thrown.
*/ async generateImagesGCS( prompt: string, @@ -169,31 +156,3 @@ export class ImagenModel { return handlePredictResponse(response); } } - -/** - * Creates an {@link ImagenImageFormat} for a JPEG image, to be included in an {@link ImagenModelParams}. - * - * @param compressionQuality The level of compression. - * @returns {@link ImagenImageFormat} - * - * @public - */ -export function jpeg(compressionQuality: number): ImagenImageFormat { - return { - mimeType: 'image/jpeg', - compressionQuality - }; -} - -/** - * Creates an {@link ImageImageFormat} for a PNG image, to be included in a {@link ImagenModelParams}. - * - * @returns {@link ImageImageFormat} - * - * @public - */ -export function png(): ImagenImageFormat { - return { - mimeType: 'image/png' - }; -} diff --git a/packages/vertexai/src/models/index.ts b/packages/vertexai/src/models/index.ts new file mode 100644 index 00000000000..a6ada0d894c --- /dev/null +++ b/packages/vertexai/src/models/index.ts @@ -0,0 +1,3 @@ +export * from './vertexai-model'; +export * from './generative-model'; +export * from './imagen-model'; \ No newline at end of file diff --git a/packages/vertexai/src/models/vertexai-model.test.ts b/packages/vertexai/src/models/vertexai-model.test.ts new file mode 100644 index 00000000000..6b5cdd48e2a --- /dev/null +++ b/packages/vertexai/src/models/vertexai-model.test.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import { use, expect } from 'chai'; +import { VertexAI, VertexAIErrorCode } from '../public-types'; +import sinonChai from 'sinon-chai'; +import { VertexAIModel } from './vertexai-model'; +import { VertexAIError } from '../errors'; + +use(sinonChai); + +/** + * A class that extends VertexAIModel that allows us to test the protected constructor. + */ +class TestModel extends VertexAIModel { + /* eslint-disable @typescript-eslint/no-useless-constructor */ + constructor( + vertexAI: VertexAI, + modelName: string + ) { + super(vertexAI, modelName); + } +} + +const fakeVertexAI: VertexAI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project' + } + }, + location: 'us-central1' +}; + +describe('VertexAIModel', () => { + it('handles plain model name', () => { + const testModel = new TestModel(fakeVertexAI, 'my-model'); + expect(testModel.model).to.equal('publishers/google/models/my-model'); + }); + it('handles models/ prefixed model name', () => { + const testModel = new TestModel(fakeVertexAI, 'models/my-model'); + expect(testModel.model).to.equal('publishers/google/models/my-model'); + }); + it('handles full model name', () => { + const testModel = new TestModel(fakeVertexAI, 'publishers/google/models/my-model'); + expect(testModel.model).to.equal('publishers/google/models/my-model'); + }); + it('handles prefixed tuned model name', () => { + const testModel = new TestModel(fakeVertexAI, 'tunedModels/my-model'); + expect(testModel.model).to.equal('tunedModels/my-model'); + }); + it('throws if not passed an api key', () => { + const fakeVertexAI: VertexAI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + projectId: 'my-project' + } + }, + location: 'us-central1' + }; + try { + new TestModel(fakeVertexAI, 'my-model'); + } catch (e) { + expect((e as VertexAIError).code).to.equal(VertexAIErrorCode.NO_API_KEY); + } + }); + it('throws if not passed a project ID', 
() => { + const fakeVertexAI: VertexAI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key' + } + }, + location: 'us-central1' + }; + try { + new TestModel(fakeVertexAI, 'my-model'); + } catch (e) { + expect((e as VertexAIError).code).to.equal( + VertexAIErrorCode.NO_PROJECT_ID + ); + } + }); +}); diff --git a/packages/vertexai/src/models/vertexai-model.ts b/packages/vertexai/src/models/vertexai-model.ts new file mode 100644 index 00000000000..0a05eb6fea6 --- /dev/null +++ b/packages/vertexai/src/models/vertexai-model.ts @@ -0,0 +1,94 @@ +import { VertexAIError } from "../errors"; +import { VertexAI, VertexAIErrorCode } from "../public-types"; +import { VertexAIService } from "../service"; +import { ApiSettings } from "../types/internal"; + +/** + * Base class for Vertex AI in Firebase model APIs. + * + * @public + */ +export class VertexAIModel { + /** + * The fully qualified model resource name to use for generating images + * (e.g. `publishers/google/models/imagen-3.0-generate-001`). + */ + readonly model: string; + + protected _apiSettings: ApiSettings; + + /** + * Constructs a new instance of the {@link VertexAIModel} class. + * + * This constructor should only be called from subclasses that provide + * a model API. + * + * @param vertexAI - An instance of the Vertex AI in Firebase SDK. + * @param modelName - The name of the model being used. It can be in one of the following formats: + * - `my-model` (short name, will resolve to `publishers/google/models/my-model`) + * - `models/my-model` (will resolve to `publishers/google/models/my-model`) + * - `publishers/my-publisher/models/my-model` (fully qualified model name) + * + * @throws If the `apiKey` or `projectId` fields are missing in your + * Firebase config. 
+ * + * @internal + */ + protected constructor( + vertexAI: VertexAI, + modelName: string + ) { + this.model = VertexAIModel.normalizeModelName(modelName); + + if (!vertexAI.app?.options?.apiKey) { + throw new VertexAIError( + VertexAIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.` + ); + } else if (!vertexAI.app?.options?.projectId) { + throw new VertexAIError( + VertexAIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.` + ); + } else { + this._apiSettings = { + apiKey: vertexAI.app.options.apiKey, + project: vertexAI.app.options.projectId, + location: vertexAI.location + }; + if ((vertexAI as VertexAIService).appCheck) { + this._apiSettings.getAppCheckToken = () => + (vertexAI as VertexAIService).appCheck!.getToken(); + } + + if ((vertexAI as VertexAIService).auth) { + this._apiSettings.getAuthToken = () => + (vertexAI as VertexAIService).auth!.getToken(); + } + } + } + + /** + * Normalizes the given model name to a fully qualified model resource name. + * + * @param modelName - The model name to normalize. + * @returns The fully qualified model resource name. + */ + static normalizeModelName(modelName: string): string { + let model: string; + if (modelName.includes('/')) { + if (modelName.startsWith('models/')) { + // Add 'publishers/google' if the user is only passing in 'models/model-name'. + model = `publishers/google/${modelName}`; + } else { + // Any other custom format (e.g. tuned models) must be passed in correctly. + model = modelName; + } + } else { + // If path is not included, assume it's a non-tuned model. 
+ model = `publishers/google/models/${modelName}`; + } + + return model; + } +} \ No newline at end of file diff --git a/packages/vertexai/src/requests/imagen-image-format.ts b/packages/vertexai/src/requests/imagen-image-format.ts new file mode 100644 index 00000000000..6b7c84d8f2b --- /dev/null +++ b/packages/vertexai/src/requests/imagen-image-format.ts @@ -0,0 +1,48 @@ +/** + * Defines the image format for images output by Imagen. + * + * Use this class to specify the desired format (JPEG or PNG) and compression quality + * for images generated by Imagen. This is typically included as part of + * {@link ImagenModelParams}. + * + * @example + * ```javascript + * const imagenModelParams = { + * // ... other ImagenModelParams + * imageFormat: ImagenImageFormat.jpeg(75) // JPEG with a compression level of 75. + * } + * ``` + * + * @public + */ +export class ImagenImageFormat { + mimeType: string; + compressionQuality?: number; + + private constructor() { + this.mimeType = "image/png"; + } + + /** + * Creates an ImagenImageFormat for a JPEG image. + * + * @param compressionQuality - The level of compression (a number between 0 and 100). + * @returns ImagenImageFormat + * + * @public + */ + static jpeg(compressionQuality: number): ImagenImageFormat { + return { mimeType: "image/jpeg", compressionQuality }; + } + + /** + * Creates an ImagenImageFormat for a PNG image. 
+ * + * @returns ImagenImageFormat + * + * @public + */ + static png(): ImagenImageFormat { + return { mimeType: "image/png" }; + } +} \ No newline at end of file diff --git a/packages/vertexai/src/requests/request-helpers.test.ts b/packages/vertexai/src/requests/request-helpers.test.ts index 9582f384a2c..2317457aca4 100644 --- a/packages/vertexai/src/requests/request-helpers.test.ts +++ b/packages/vertexai/src/requests/request-helpers.test.ts @@ -239,7 +239,7 @@ describe('request formatting methods', () => { const addWatermark = true; const numberOfImages = 4; const negativePrompt = 'do not hallucinate'; - const aspectRatio = ImagenAspectRatio.WIDESCREEN; + const aspectRatio = ImagenAspectRatio.LANDSCAPE_16x9; const body = createPredictRequestBody({ prompt, numberOfImages, diff --git a/packages/vertexai/src/requests/request-helpers.ts b/packages/vertexai/src/requests/request-helpers.ts index fa4bdf402d7..831bb6a26a5 100644 --- a/packages/vertexai/src/requests/request-helpers.ts +++ b/packages/vertexai/src/requests/request-helpers.ts @@ -18,13 +18,13 @@ import { Content, GenerateContentRequest, - PredictRequestBody, Part, VertexAIErrorCode, ImagenAspectRatio, - ImagenRequestConfig } from '../types'; import { VertexAIError } from '../errors'; +import { ImagenImageFormat } from './imagen-image-format'; +import { ImagenRequestConfig, PredictRequestBody } from '../types/internal'; export function formatSystemInstruction( input?: string | Part | Content @@ -53,12 +53,10 @@ export function formatNewContent( newParts = [{ text: request }]; } else { for (const elem of request) { - // This throws an error if request is not iterable if (typeof elem === 'string') { newParts.push({ text: elem }); } else { - // We assume this is a Part, but it could be anything. 
- newParts.push(elem); // This could be + newParts.push(elem); } } } @@ -118,8 +116,6 @@ export function formatGenerateContentInput( if ((params as GenerateContentRequest).contents) { formattedRequest = params as GenerateContentRequest; } else { - // Array or string - // ... or something else const content = formatNewContent(params as string | Array); formattedRequest = { contents: [content] }; } @@ -140,7 +136,7 @@ export function formatGenerateContentInput( export function createPredictRequestBody({ prompt, gcsURI, - imageFormat = { mimeType: 'image/png' }, + imageFormat = ImagenImageFormat.png(), addWatermark, safetySettings, numberOfImages = 1, diff --git a/packages/vertexai/src/requests/response-helpers.ts b/packages/vertexai/src/requests/response-helpers.ts index 6b82463dfa1..acdc6e741d1 100644 --- a/packages/vertexai/src/requests/response-helpers.ts +++ b/packages/vertexai/src/requests/response-helpers.ts @@ -201,9 +201,9 @@ export function formatBlockErrorMessage( } /** - * Convert a generic successful fetch {@link Response} body to an Imagen response object + * Convert a generic successful fetch response body to an Imagen response object * that can be returned to the user. This converts the REST APIs response format to our - * representation of a response. + * APIs representation of a response. * * @internal */ diff --git a/packages/vertexai/src/types/imagen/internal.ts b/packages/vertexai/src/types/imagen/internal.ts index 1171df81278..ba7ee2c8f49 100644 --- a/packages/vertexai/src/types/imagen/internal.ts +++ b/packages/vertexai/src/types/imagen/internal.ts @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -// Internal Imagen types +import { ImagenGenerationConfig, ImagenModelConfig } from "./requests"; /** * A response from the REST API is expected to look like this in the success case: @@ -43,12 +43,91 @@ */ export interface ImagenResponseInternal { predictions?: Array<{ - // Defined if the prediction was not filtered + /** + * The MIME type of the generated image. + */ mimeType?: string; + /** + * The image data encoded as a base64 string. + */ bytesBase64Encoded?: string; + gcsUri?: string; - // Defined if the prediction was filtered, and there is no image + /** + * The reason why the image was filtered. + */ raiFilteredReason?: string; }>; } + +/** + * The parameters to be sent in the request body of the HTTP call + * to the Vertex AI backend. + * + * We need a separate internal-only interface for this because the REST + * API expects different parameter names than what we show to our users. + * + * This interface should be populated from the ImagenGenerationConfig that + * the user defines. + * + * Sample request body JSON: + * { + * "instances": [ + * { + * "prompt": "Portrait of a golden retriever on a beach." + * } + * ], + * "parameters": { + * "mimeType": "image/png", + * "safetyFilterLevel": "block_low_and_above", + * "personGeneration": "allow_all", + * "sampleCount": 2, + * "includeRaiReason": true, + * "aspectRatio": "9:16" + * } + * } + * + * @internal + */ +export interface PredictRequestBody { + instances: [ + { + prompt: string; + } + ]; + parameters: { + sampleCount: number; // Maps to numberOfImages + aspectRatio: string; + mimeType: string; + compressionQuality?: number; + negativePrompt?: string; + storageUri?: string; // Maps to gcsURI + addWatermark?: boolean; + safetyFilterLevel?: string; + personGeneration?: string; // Maps to personFilterLevel + includeRaiReason: boolean; + }; +} + +/** + * Contains all possible REST API parameters. 
+ * This is the intersection of the model-level (`ImagenModelParams`), + * request-level (`ImagenGenerationConfig`) configurations, along with + * the other required parameters prompt and gcsURI (for GCS generation only). + * + * @internal + */ +export interface ImagenRequestConfig + extends ImagenModelConfig, + ImagenGenerationConfig { + /** + * The text prompt used to generate the images. + */ + prompt: string; + /** + * The Google Cloud Storage (GCS) URI where the images should be stored + * (for GCS requests only). + */ + gcsURI?: string; +} \ No newline at end of file diff --git a/packages/vertexai/src/types/imagen/requests.ts b/packages/vertexai/src/types/imagen/requests.ts index d87723a31c7..a9a9c664b45 100644 --- a/packages/vertexai/src/types/imagen/requests.ts +++ b/packages/vertexai/src/types/imagen/requests.ts @@ -15,128 +15,130 @@ * limitations under the License. */ +import { ImagenImageFormat } from '../../requests/imagen-image-format'; + +/** + * Parameters for configuring an {@link ImagenModel}. + * + * @public + */ export interface ImagenModelParams extends ImagenModelConfig { + /** + * The Imagen model to use for generating images. + * For example: `imagen-3.0-generate-001`. + */ model: string; } +/** + * Model-level configuration options for Imagen. + * + * @public + */ export interface ImagenModelConfig { + /** + * The image format of the generated images. + */ imageFormat?: ImagenImageFormat; + /** + * Whether to add a watermark to generated images. + */ addWatermark?: boolean; + /** + * Safety settings for filtering inappropriate content. + */ safetySettings?: ImagenSafetySettings; } -export interface ImagenGenerationConfig { - numberOfImages?: number; // Default to 1. Possible values are [1...4] - negativePrompt?: string; // Default to null - aspectRatio?: ImagenAspectRatio; // Default to "1:1" -} - /** - * Contains all possible REST API paramaters. 
- * This is the intersection of the model-level (`ImagenModelParams`), - * request-level (`ImagenGenerationConfig`) configurations, along with - * the other required parameters prompt and gcsURI (for GCS generation only). + * Request-level configuration options for generating images with Imagen. * - * @internal + * @public */ -export interface ImagenRequestConfig - extends ImagenModelConfig, - ImagenGenerationConfig { - prompt: string; - gcsURI?: string; -} - -export interface ImagenImageFormat { - mimeType: string; // image/png, or image/jpeg, default image/png - compressionQuality?: number; // 0-100, default 75. Only for image/jpeg +export interface ImagenGenerationConfig { + /** + * The number of images to generate. Must be between 1 and 4. Defaults to 1. + */ + numberOfImages?: number; + /** + * A text prompt describing what should not be included in the image. + */ + negativePrompt?: string; + /** + * The aspect ratio of the generated images. Defaults to `1:1`. + */ + aspectRatio?: ImagenAspectRatio; } /** + * Safety filter levels for Imagen. + * * @public */ export enum ImagenSafetyFilterLevel { + /** + * Block images with low or higher safety severity. + */ BLOCK_LOW_AND_ABOVE = 'block_low_and_above', + /** + * Block images with medium or higher safety severity. + */ BLOCK_MEDIUM_AND_ABOVE = 'block_medium_and_above', + /** + * Block images with high safety severity. + */ BLOCK_ONLY_HIGH = 'block_only_high', + /** + * Do not block any images based on safety. + */ BLOCK_NONE = 'block_none' } /** + * Person filter levels for Imagen. + * * @public */ export enum ImagenPersonFilterLevel { + /** + * Do not allow any person generation. + */ BLOCK_ALL = 'dont_allow', + /** + * Allow only adults in generated images. + */ ALLOW_ADULT = 'allow_adult', + /** + * Allow all person generation. + */ ALLOW_ALL = 'allow_all' } /** + * Safety settings for Imagen. 
+ * * @public */ export interface ImagenSafetySettings { /** - * Safety filter level + * The safety filter level to use. */ safetyFilterLevel?: ImagenSafetyFilterLevel; /** - * Generate people. + * The person filter level to use. */ personFilterLevel?: ImagenPersonFilterLevel; } -export enum ImagenAspectRatio { - SQUARE = '1:1', - CLASSIC_PORTRAIT = '3:4', - CLASSIC_LANDSCAPE = '4:3', - WIDESCREEN = '16:9', - PORTRAIT = '9:16' -} - /** - * The parameters to be sent in the request body of the HTTP call - * to the Vertex AI backend. - * - * We need a seperate internal-only interface for this because the REST - * API expects different parameter names than what we show to our users. - * - * This interface should be populated from the {@link ImagenGenerationConfig} that - * the user defines. - * - * Sample request body JSON: - * { - * "instances": [ - * { - * "prompt": "Portrait of a golden retriever on a beach." - * } - * ], - * "parameters": { - * "mimeType": "image/png", - * "safetyFilterLevel": "block_low_and_above", - * "personGeneration": "allow_all", - * "sampleCount": 2, - * "includeRaiReason": true, - * "aspectRatio": "9:16" - * } - * } + * Aspect ratios for Imagen images. 
* - * @internal + * @public */ -export interface PredictRequestBody { - instances: [ - { - prompt: string; - } - ]; - parameters: { - sampleCount: number; // maps to numberOfImages - aspectRatio: string; - mimeType: string; - compressionQuality?: number; - negativePrompt?: string; - storageUri?: string; - addWatermark?: boolean; - safetyFilterLevel?: string; - personGeneration?: string; - includeRaiReason: boolean; - }; +export enum ImagenAspectRatio { + SQUARE = '1:1', + LANDSCAPE_3x4 = '3:4', + PORTRAIT_4x3 = '4:3', + LANDSCAPE_16x9 = '16:9', + PORTRAIT_9x16 = '9:16' } diff --git a/packages/vertexai/src/types/imagen/responses.ts b/packages/vertexai/src/types/imagen/responses.ts index 620f49be4b9..a05e3fda0f7 100644 --- a/packages/vertexai/src/types/imagen/responses.ts +++ b/packages/vertexai/src/types/imagen/responses.ts @@ -16,58 +16,73 @@ */ /** - * Base class for types of images that the Imagen Model can return. - * - * @internal - */ -export interface ImagenImage { - mimeType: string; -} - -/** - * Image generated by Imagen to inline bytes. + * An image generated by Imagen, represented as inline bytes. * * @public */ -export interface ImagenInlineImage extends ImagenImage { +export interface ImagenInlineImage { + /** + * The MIME type of the image. + */ + mimeType: string; + /** + * The image data encoded as a base64 string. + */ bytesBase64Encoded: string; } /** - * Image generated by Imagen, stored in Google Cloud Storage (GCS). + * An image generated by Imagen, stored in Google Cloud Storage (GCS). * * @public */ -export interface ImagenGCSImage extends ImagenImage { +export interface ImagenGCSImage { /** - * The Google Cloud Storage (GCS) URI at which the generated image is stored. + * The MIME type of the image. + */ + mimeType: string; + /** + * The Google Cloud Storage (GCS) URI where the image is stored. */ gcsURI: string; } /** - * Imagen image response. + * The response from a request to generate images to inline bytes. 
* * @public */ export interface ImagenInlineImageResponse { /** - * The images generated by Imagen. If all images were filtered, this will be empty. + * The images generated by Imagen. + * If all images were filtered out due to safety reasons, this array will be empty. */ images: ImagenInlineImage[]; /** - * The reason the missing images were filtered. - * For the mappings of error codes to reasons, see {@link https://cloud.google.com/vertex-ai/generative-ai/docs/image/responsible-ai-imagen#safety-categories}. + * The reason why any images were filtered. This field is only present if one + * or more images were filtered. + * For the mappings of error codes to reasons, see + * {@link https://cloud.google.com/vertex-ai/generative-ai/docs/image/responsible-ai-imagen#safety-categories}. */ filteredReason?: string; } +/** + * The response from a request to generate images to Google Cloud Storage (GCS). + * + * @public + */ export interface ImagenGCSImageResponse { + /** + * The images generated by Imagen. + * If all images were filtered due to safety reasons, this array will be empty. + */ images: ImagenGCSImage[]; - filteredReason?: string; -} - -export interface ImagenImageReponse { - images: ImagenImage[]; + /** + * The reason why any images were filtered. This field is only present if one + * or more images were filtered. + * For the mappings of error codes to reasons, see + * {@link https://cloud.google.com/vertex-ai/generative-ai/docs/image/responsible-ai-imagen#safety-categories}. + */ filteredReason?: string; } diff --git a/packages/vertexai/src/types/responses.ts b/packages/vertexai/src/types/responses.ts index e2a442821da..83cd4366f12 100644 --- a/packages/vertexai/src/types/responses.ts +++ b/packages/vertexai/src/types/responses.ts @@ -46,11 +46,6 @@ export interface GenerateContentStreamResult { response: Promise; } -export interface SafetyAttributes { - categories: string[]; - scores: number[]; -} - /** * Response object wrapped with helper methods. *