Skip to content

Commit

Permalink
Introduce VertexAIModel base class, add documentation, and respond to…
Browse files Browse the repository at this point in the history
… other comments
  • Loading branch information
dlarocque committed Jan 2, 2025
1 parent ec35231 commit 180f091
Show file tree
Hide file tree
Showing 18 changed files with 552 additions and 463 deletions.
145 changes: 35 additions & 110 deletions common/api-review/vertexai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,16 +323,14 @@ export interface GenerativeContentBlob {
}

// @public
export class GenerativeModel {
export class GenerativeModel extends VertexAIModel {
constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions);
countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
model: string;
// (undocumented)
requestOptions?: RequestOptions;
// (undocumented)
safetySettings: SafetySetting[];
Expand Down Expand Up @@ -432,77 +430,53 @@ export enum HarmSeverity {
HARM_SEVERITY_NEGLIGIBLE = "HARM_SEVERITY_NEGLIGIBLE"
}

// @public (undocumented)
// @public
export enum ImagenAspectRatio {
// (undocumented)
CLASSIC_LANDSCAPE = "4:3",
LANDSCAPE_16x9 = "16:9",
// (undocumented)
CLASSIC_PORTRAIT = "3:4",
LANDSCAPE_3x4 = "3:4",
// (undocumented)
PORTRAIT = "9:16",
PORTRAIT_4x3 = "4:3",
// (undocumented)
SQUARE = "1:1",
PORTRAIT_9x16 = "9:16",
// (undocumented)
WIDESCREEN = "16:9"
SQUARE = "1:1"
}

// Warning: (ae-incompatible-release-tags) The symbol "ImagenGCSImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
//
// @public
export interface ImagenGCSImage extends ImagenImage {
export interface ImagenGCSImage {
gcsURI: string;
mimeType: string;
}

// @public (undocumented)
// @public
export interface ImagenGCSImageResponse {
// (undocumented)
filteredReason?: string;
// (undocumented)
images: ImagenGCSImage[];
}

// @public (undocumented)
// @public
export interface ImagenGenerationConfig {
// (undocumented)
aspectRatio?: ImagenAspectRatio;
// (undocumented)
negativePrompt?: string;
// (undocumented)
numberOfImages?: number;
}

// Warning: (ae-internal-missing-underscore) The name "ImagenImage" should be prefixed with an underscore because the declaration is marked as @internal
//
// @internal
export interface ImagenImage {
// (undocumented)
mimeType: string;
}

// @public (undocumented)
export interface ImagenImageFormat {
// @public
export class ImagenImageFormat {
// (undocumented)
compressionQuality?: number;
static jpeg(compressionQuality: number): ImagenImageFormat;
// (undocumented)
mimeType: string;
static png(): ImagenImageFormat;
}

// @public (undocumented)
export interface ImagenImageReponse {
// (undocumented)
filteredReason?: string;
// Warning: (ae-incompatible-release-tags) The symbol "images" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
//
// (undocumented)
images: ImagenImage[];
}

// Warning: (ae-incompatible-release-tags) The symbol "ImagenInlineImage" is marked as @public, but its signature references "ImagenImage" which is marked as @internal
//
// @public
export interface ImagenInlineImage extends ImagenImage {
// (undocumented)
export interface ImagenInlineImage {
bytesBase64Encoded: string;
mimeType: string;
}

// @public
Expand All @@ -512,63 +486,43 @@ export interface ImagenInlineImageResponse {
}

// @public
export class ImagenModel {
export class ImagenModel extends VertexAIModel {
constructor(vertexAI: VertexAI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
generateImages(prompt: string, imagenRequestOptions?: ImagenGenerationConfig): Promise<ImagenInlineImageResponse>;
generateImagesGCS(prompt: string, gcsURI: string, imagenRequestOptions?: ImagenGenerationConfig): Promise<ImagenGCSImageResponse>;
readonly modelConfig: ImagenModelConfig;
// (undocumented)
model: string;
}
readonly requestOptions?: RequestOptions | undefined;
}

// @public (undocumented)
// @public
export interface ImagenModelConfig {
// (undocumented)
addWatermark?: boolean;
// (undocumented)
imageFormat?: ImagenImageFormat;
// (undocumented)
safetySettings?: ImagenSafetySettings;
}

// @public
export interface ImagenModelParams extends ImagenModelConfig {
// (undocumented)
model: string;
}

// @public (undocumented)
// @public
export enum ImagenPersonFilterLevel {
// (undocumented)
ALLOW_ADULT = "allow_adult",
// (undocumented)
ALLOW_ALL = "allow_all",
// (undocumented)
BLOCK_ALL = "dont_allow"
}

// Warning: (ae-internal-missing-underscore) The name "ImagenRequestConfig" should be prefixed with an underscore because the declaration is marked as @internal
//
// @internal
export interface ImagenRequestConfig extends ImagenModelConfig, ImagenGenerationConfig {
// (undocumented)
gcsURI?: string;
// (undocumented)
prompt: string;
}

// @public (undocumented)
// @public
export enum ImagenSafetyFilterLevel {
// (undocumented)
BLOCK_LOW_AND_ABOVE = "block_low_and_above",
// (undocumented)
BLOCK_MEDIUM_AND_ABOVE = "block_medium_and_above",
// (undocumented)
BLOCK_NONE = "block_none",
// (undocumented)
BLOCK_ONLY_HIGH = "block_only_high"
}

// @public (undocumented)
// @public
export interface ImagenSafetySettings {
personFilterLevel?: ImagenPersonFilterLevel;
safetyFilterLevel?: ImagenSafetyFilterLevel;
Expand All @@ -592,9 +546,6 @@ export class IntegerSchema extends Schema {
constructor(schemaParams?: SchemaParams);
}

// @public
export function jpeg(compressionQuality: number): ImagenImageFormat;

// @public
export interface ModelParams extends BaseParams {
// (undocumented)
Expand Down Expand Up @@ -638,37 +589,9 @@ export interface ObjectSchemaInterface extends SchemaInterface {
// @public
export type Part = TextPart | InlineDataPart | FunctionCallPart | FunctionResponsePart | FileDataPart;

// @public
export function png(): ImagenImageFormat;

// @public
export const POSSIBLE_ROLES: readonly ["user", "model", "function", "system"];

// Warning: (ae-internal-missing-underscore) The name "PredictRequestBody" should be prefixed with an underscore because the declaration is marked as @internal
//
// @internal
export interface PredictRequestBody {
// (undocumented)
instances: [
{
prompt: string;
}
];
// (undocumented)
parameters: {
sampleCount: number;
aspectRatio: string;
mimeType: string;
compressionQuality?: number;
negativePrompt?: string;
storageUri?: string;
addWatermark?: boolean;
safetyFilterLevel?: string;
personGeneration?: string;
includeRaiReason: boolean;
};
}

// @public
export interface PromptFeedback {
// (undocumented)
Expand Down Expand Up @@ -696,14 +619,6 @@ export interface RetrievedContextAttribution {
// @public
export type Role = (typeof POSSIBLE_ROLES)[number];

// @public (undocumented)
export interface SafetyAttributes {
// (undocumented)
categories: string[];
// (undocumented)
scores: number[];
}

// @public
export interface SafetyRating {
// (undocumented)
Expand Down Expand Up @@ -902,6 +817,16 @@ export const enum VertexAIErrorCode {
RESPONSE_ERROR = "response-error"
}

// @public
export class VertexAIModel {
// @internal
protected constructor(vertexAI: VertexAI, modelName: string);
// (undocumented)
protected _apiSettings: ApiSettings;
readonly model: string;
static normalizeModelName(modelName: string): string;
}

// @public
export interface VertexAIOptions {
// (undocumented)
Expand Down
2 changes: 1 addition & 1 deletion packages/vertexai/src/api.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ describe('Top level API', () => {
);
}
});
it('getGenerativeModel gets an ImagenModel', () => {
it('getImagenModel gets an ImagenModel', () => {
const genModel = getImagenModel(fakeVertexAI, { model: 'my-model' });
expect(genModel).to.be.an.instanceOf(ImagenModel);
expect(genModel.model).to.equal('publishers/google/models/my-model');
Expand Down
10 changes: 3 additions & 7 deletions packages/vertexai/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,12 @@ import {
VertexAIErrorCode
} from './types';
import { VertexAIError } from './errors';
import { GenerativeModel } from './models/generative-model';
import { ImagenModel, jpeg, png } from './models/imagen-model';
import { VertexAIModel, GenerativeModel, ImagenModel } from './models';

export { ChatSession } from './methods/chat-session';
export * from './requests/schema-builder';

export { jpeg, png };

export { GenerativeModel, ImagenModel };

export { ImagenImageFormat } from './requests/imagen-image-format';
export { VertexAIModel, GenerativeModel, ImagenModel };
export { VertexAIError };

declare module '@firebase/component' {
Expand Down
22 changes: 0 additions & 22 deletions packages/vertexai/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,6 @@ const fakeVertexAI: VertexAI = {
};

describe('GenerativeModel', () => {
it('handles plain model name', () => {
const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
expect(genModel.model).to.equal('publishers/google/models/my-model');
});
it('handles models/ prefixed model name', () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'models/my-model'
});
expect(genModel.model).to.equal('publishers/google/models/my-model');
});
it('handles full model name', () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'publishers/google/models/my-model'
});
expect(genModel.model).to.equal('publishers/google/models/my-model');
});
it('handles prefixed tuned model name', () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'tunedModels/my-model'
});
expect(genModel.model).to.equal('tunedModels/my-model');
});
it('passes params through to generateContent', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
Expand Down
50 changes: 4 additions & 46 deletions packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,22 @@ import {
SafetySetting,
StartChatParams,
Tool,
ToolConfig,
VertexAIErrorCode
ToolConfig
} from '../types';
import { VertexAIError } from '../errors';
import { ChatSession } from '../methods/chat-session';
import { countTokens } from '../methods/count-tokens';
import {
formatGenerateContentInput,
formatSystemInstruction
} from '../requests/request-helpers';
import { VertexAI } from '../public-types';
import { ApiSettings } from '../types/internal';
import { VertexAIService } from '../service';
import { VertexAIModel } from './vertexai-model';

/**
* Class for generative model APIs.
* @public
*/
export class GenerativeModel {
private _apiSettings: ApiSettings;
model: string;
export class GenerativeModel extends VertexAIModel {
generationConfig: GenerationConfig;
safetySettings: SafetySetting[];
requestOptions?: RequestOptions;
Expand All @@ -66,44 +61,7 @@ export class GenerativeModel {
modelParams: ModelParams,
requestOptions?: RequestOptions
) {
if (!vertexAI.app?.options?.apiKey) {
throw new VertexAIError(
VertexAIErrorCode.NO_API_KEY,
`The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`
);
} else if (!vertexAI.app?.options?.projectId) {
throw new VertexAIError(
VertexAIErrorCode.NO_PROJECT_ID,
`The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`
);
} else {
this._apiSettings = {
apiKey: vertexAI.app.options.apiKey,
project: vertexAI.app.options.projectId,
location: vertexAI.location
};
if ((vertexAI as VertexAIService).appCheck) {
this._apiSettings.getAppCheckToken = () =>
(vertexAI as VertexAIService).appCheck!.getToken();
}

if ((vertexAI as VertexAIService).auth) {
this._apiSettings.getAuthToken = () =>
(vertexAI as VertexAIService).auth!.getToken();
}
}
if (modelParams.model.includes('/')) {
if (modelParams.model.startsWith('models/')) {
// Add "publishers/google" if the user is only passing in 'models/model-name'.
this.model = `publishers/google/${modelParams.model}`;
} else {
// Any other custom format (e.g. tuned models) must be passed in correctly.
this.model = modelParams.model;
}
} else {
// If path is not included, assume it's a non-tuned model.
this.model = `publishers/google/models/${modelParams.model}`;
}
super(vertexAI, modelParams.model);
this.generationConfig = modelParams.generationConfig || {};
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
Expand Down
Loading

0 comments on commit 180f091

Please sign in to comment.