Skip to content

Commit

Permalink
chore: add unit test automation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592614153
  • Loading branch information
sararob authored and copybara-github committed Dec 27, 2023
1 parent ce962f8 commit 47de6ae
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 51 deletions.
3 changes: 3 additions & 0 deletions .eslintignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
**/node_modules
build/
docs/
3 changes: 3 additions & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"extends": "./node_modules/gts"
}
21 changes: 21 additions & 0 deletions .github/workflows/presubmit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Presubmit CI: run unit tests on every pull request.
# NOTE(review): the scrape lost all YAML indentation; structure reconstructed
# to the conventional GitHub Actions layout for these exact keys/values.
on:
  pull_request:
name: presubmit
jobs:
  units:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Test against both active LTS lines.
        node: [18, 20]
    steps:
      - uses: actions/checkout@v4
      - uses: pnpm/action-setup@v2
        with:
          version: ^6.24.1
      # Fix: the original declared a node matrix but never consumed it, so
      # both matrix jobs ran on the runner's default Node. setup-node wires
      # matrix.node to the actual runtime used by the steps below.
      - uses: actions/setup-node@v4
        with:
          node-version: ${{ matrix.node }}
      - run: node --version
      - run: npm install
      - run: npm run test
        name: Run unit tests
        env:
          BUILD_TYPE: presubmit
          TEST_TYPE: units
3 changes: 3 additions & 0 deletions .prettierignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
**/node_modules
build/
docs/
20 changes: 20 additions & 0 deletions .prettierrc.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Adopt the Google TypeScript Style (gts) Prettier configuration wholesale.
// Project-specific overrides, if ever needed, would be listed after the spread.
module.exports = {
...require('gts/.prettierrc.json')
};
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
"type": "commonjs",
"scripts": {
"clean": "gts clean",
"compile": "tsc",
"compile": "tsc -p .",
"docs": "jsdoc -c .jsdoc.js",
"predocs-test": "npm run docs",
"docs-test": "linkinator docs",
"compile:oss": "tsc -p tsconfig.json.oss",
"fix": "gts fix",
"test": "TODO",
"test": "jasmine build/test/*.js",
"system-test": "jasmine build/system_test/*.js",
"lint": "gts lint",
"clean-js-files": "find . -type f -name \"*.js\" -exec rm -f {} +",
Expand All @@ -36,9 +36,9 @@
},
"devDependencies": {
"@types/jasmine": "^5.1.2",
"jasmine": "^5.1.0",
"@types/node": "^20.9.0",
"gts": "^5.2.0",
"jasmine": "^5.1.0",
"typescript": "~5.2.0",
"jsdoc": "^4.0.0",
"jsdoc-fresh": "^3.0.0",
Expand Down
2 changes: 1 addition & 1 deletion src/types/content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export enum HarmCategory {
}

/**
* Threshhold above which a prompt or candidate will be blocked.
* Threshold above which a prompt or candidate will be blocked.
*/
export enum HarmBlockThreshold {
// Unspecified harm block threshold.
Expand Down
8 changes: 4 additions & 4 deletions system_test/end_to_end_sample_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({
// expect(...).toBeInstanceOf(...)
describe('generateContentStream', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});

it('should should return a stream and aggregated response when passed text', async () => {
Expand Down Expand Up @@ -113,8 +113,8 @@ describe('generateContentStream', () => {
item.candidates[0],
`sys test failure on generateContentStream, for item ${item}`
);
for (const candiate of item.candidates) {
for (const part of candiate.content.parts as TextPart[]) {
for (const candidate of item.candidates) {
for (const part of candidate.content.parts as TextPart[]) {
assert(
!part.text.includes('\ufffd'),
`sys test failure on generateContentStream, for item ${item}`
Expand Down Expand Up @@ -246,7 +246,7 @@ describe('countTokens', () => {

describe('generateContentStream using models/model-id', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});

it('should should return a stream and aggregated response when passed text', async () => {
Expand Down
111 changes: 68 additions & 43 deletions src/index_test.ts → test/index_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
*/

/* tslint:disable */
import 'jasmine';

import {ChatSession, GenerativeModel, StartChatParams, VertexAI} from './index';
import * as StreamFunctions from './process_stream';
import {
ChatSession,
GenerativeModel,
StartChatParams,
VertexAI,
} from '../src/index';
import * as StreamFunctions from '../src/process_stream';
import {
CountTokensRequest,
FinishReason,
Expand All @@ -32,8 +36,8 @@ import {
SafetyRating,
SafetySetting,
StreamGenerateContentResult,
} from './types/content';
import {constants} from './util';
} from '../src/types/content';
import {constants} from '../src/util';

const PROJECT = 'test_project';
const LOCATION = 'test_location';
Expand Down Expand Up @@ -140,6 +144,13 @@ const TEST_MULTIPART_MESSAGE = [
],
},
];

const fetchResponseObj = {
status: 200,
statusText: 'OK',
headers: {'Content-Type': 'application/json'},
};

/**
* Returns a generator, used to mock the generateContentStream response
* @ignore
Expand All @@ -153,6 +164,8 @@ export async function* testGenerator(): AsyncGenerator<GenerateContentResponse>
describe('VertexAI', () => {
let vertexai: VertexAI;
let model: GenerativeModel;
let expectedStreamResult: StreamGenerateContentResult;
let fetchSpy: jasmine.Spy;

beforeEach(() => {
vertexai = new VertexAI({
Expand All @@ -161,6 +174,15 @@ describe('VertexAI', () => {
});
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const fetchResult = new Response(
JSON.stringify(expectedStreamResult),
fetchResponseObj
);
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult);
});

it('should be instantiated', () => {
Expand All @@ -175,10 +197,6 @@ describe('VertexAI', () => {
const expectedResult: GenerateContentResult = {
response: TEST_MODEL_RESPONSE,
};
const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
Expand Down Expand Up @@ -264,12 +282,11 @@ describe('VertexAI', () => {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const requestSpy = spyOn(global, 'fetch');
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
await model.generateContent(req);
expect(requestSpy.calls.allArgs()[0][0].toString()).toContain(
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(
TEST_ENDPOINT_BASE_PATH
);
});
Expand Down Expand Up @@ -300,12 +317,11 @@ describe('VertexAI', () => {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const requestSpy = spyOn(global, 'fetch');
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
await model.generateContent(req);
expect(requestSpy.calls.allArgs()[0][0].toString()).toContain(
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(
`${LOCATION}-aiplatform.googleapis.com`
);
});
Expand All @@ -325,20 +341,23 @@ describe('VertexAI', () => {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const requestSpy = spyOn(global, 'fetch');
// const fetchResult = Promise.resolve(
// new Response(JSON.stringify(expectedStreamResult),
// fetchResponseObj));
// const requestSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
await model.generateContent(reqWithEmptyConfigs);
const requestArgs = requestSpy.calls.allArgs()[0][1];
const requestArgs = fetchSpy.calls.allArgs()[0][1];
if (typeof requestArgs === 'object' && requestArgs) {
expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k');
}
});
});

describe('generateContent', () => {
it('inclues top_k when it is within 1 - 40', async () => {
it('includes top_k when it is within 1 - 40', async () => {
const reqWithEmptyConfigs: GenerateContentRequest = {
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE,
generation_config: {top_k: 1},
Expand All @@ -351,12 +370,11 @@ describe('VertexAI', () => {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const requestSpy = spyOn(global, 'fetch');
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
await model.generateContent(reqWithEmptyConfigs);
const requestArgs = requestSpy.calls.allArgs()[0][1];
const requestArgs = fetchSpy.calls.allArgs()[0][1];
if (typeof requestArgs === 'object' && requestArgs) {
expect(JSON.stringify(requestArgs['body'])).toContain('top_k');
}
Expand All @@ -379,7 +397,6 @@ describe('VertexAI', () => {
expectedStreamResult
);
const resp = await model.generateContent(req);
console.log(resp.response.candidates[0].citationMetadata, 'yoyoyo');
expect(
resp.response.candidates[0].citationMetadata?.citationSources.length
).toEqual(
Expand Down Expand Up @@ -435,25 +452,28 @@ describe('VertexAI', () => {
expect(resp).toBeInstanceOf(ChatSession);
});
});
});

describe('countTokens', () => {
it('returns the token count', async () => {
const req: CountTokensRequest = {
contents: TEST_USER_CHAT_MESSAGE,
};
const responseBody = {
totalTokens: 1,
};
const response = new Response(JSON.stringify(responseBody), {
status: 200,
statusText: 'OK',
headers: {'Content-Type': 'application/json'},
});
const responsePromise = Promise.resolve(response);
spyOn(global, 'fetch').and.returnValue(responsePromise);
const resp = await model.countTokens(req);
expect(resp).toEqual(responseBody);
describe('countTokens', () => {
it('returns the token count', async () => {
const vertexai = new VertexAI({
project: PROJECT,
location: LOCATION,
});
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
const req: CountTokensRequest = {
contents: TEST_USER_CHAT_MESSAGE,
};
const responseBody = {
totalTokens: 1,
};
const response = Promise.resolve(
new Response(JSON.stringify(responseBody), fetchResponseObj)
);
spyOn(global, 'fetch').and.returnValue(response);
const resp = await model.countTokens(req);
expect(resp).toEqual(responseBody);
});
});

Expand All @@ -462,6 +482,7 @@ describe('ChatSession', () => {
let chatSessionWithNoArgs: ChatSession;
let vertexai: VertexAI;
let model: GenerativeModel;
let expectedStreamResult: StreamGenerateContentResult;

beforeEach(() => {
vertexai = new VertexAI({project: PROJECT, location: LOCATION});
Expand All @@ -470,12 +491,16 @@ describe('ChatSession', () => {
chatSession = model.startChat({
history: TEST_USER_CHAT_MESSAGE,
});
chatSessionWithNoArgs = model.startChat();
});

it('should add the provided message to the session history', () => {
expect(chatSession.history).toEqual(TEST_USER_CHAT_MESSAGE);
expect(chatSession.history.length).toEqual(1);
chatSessionWithNoArgs = model.startChat();
expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
const fetchResult = Promise.resolve(
new Response(JSON.stringify(expectedStreamResult), fetchResponseObj)
);
spyOn(global, 'fetch').and.returnValue(fetchResult);
});

describe('sendMessage', () => {
Expand Down Expand Up @@ -510,7 +535,7 @@ describe('ChatSession', () => {
);
const resp = await chatSessionWithNoArgs.sendMessage(req);
expect(resp).toEqual(expectedResult);
expect(chatSession.history.length).toEqual(3);
expect(chatSessionWithNoArgs.history.length).toEqual(2);
});

// TODO: unbreak this test. Currently chatSession.history is saving the
Expand Down Expand Up @@ -567,7 +592,7 @@ describe('ChatSession', () => {
expect(chatSession.history[1].role).toEqual(constants.USER_ROLE);
expect(chatSession.history[2].role).toEqual(constants.MODEL_ROLE);
});
it('returns a StreamGenerateContentResponse and appends role if missiong', async () => {
it('returns a StreamGenerateContentResponse and appends role if missing', async () => {
const req = 'How are you doing today?';
const expectedResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_MISSING_ROLE),
Expand Down

0 comments on commit 47de6ae

Please sign in to comment.