Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable dynamic retrieval for Google Search Retrieval grounding #474

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/models/test/models_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
HarmBlockThreshold,
HarmCategory,
HarmProbability,
Mode,
RequestOptions,
SafetyRating,
SafetySetting,
Expand Down Expand Up @@ -181,7 +182,10 @@ const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [
const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL: GoogleSearchRetrievalTool[] = [
{
googleSearchRetrieval: {
disableAttribution: false,
dynamicRetrievalConfig: {
dynamicThreshold: 0.5,
mode: Mode.MODE_DYNAMIC,
},
},
},
];
Expand Down Expand Up @@ -332,7 +336,7 @@ describe('GenerativeModel startChat', () => {
history: TEST_USER_CHAT_MESSAGE,
});
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await chat.sendMessage(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand All @@ -357,7 +361,7 @@ describe('GenerativeModel startChat', () => {
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
});
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await chat.sendMessage(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down Expand Up @@ -550,7 +554,7 @@ describe('GenerativeModelPreview startChat', () => {
history: TEST_USER_CHAT_MESSAGE,
});
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await chat.sendMessage(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand All @@ -575,7 +579,7 @@ describe('GenerativeModelPreview startChat', () => {
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
});
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await chat.sendMessage(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down Expand Up @@ -1161,7 +1165,7 @@ describe('GenerativeModel generateContent', () => {
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
};
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await model.generateContent(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down Expand Up @@ -1638,7 +1642,7 @@ describe('GenerativeModelPreview generateContent', () => {
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
};
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await model.generateContent(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down Expand Up @@ -1942,7 +1946,7 @@ describe('GenerativeModel generateContentStream', () => {
};
spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult);
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await model.generateContent(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down Expand Up @@ -2259,7 +2263,7 @@ describe('GenerativeModelPreview generateContentStream', () => {
};
spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult);
const expectedBody =
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}';
'{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}';
await model.generateContent(req);
// @ts-ignore
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
Expand Down
19 changes: 17 additions & 2 deletions src/types/content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -969,11 +969,26 @@ export declare interface Retrieval {
disableAttribution?: boolean;
}

export enum Mode {
MODE_UNSPECIFIED = 'MODE_UNSPECIFIED',
MODE_DYNAMIC = 'MODE_DYNAMIC',
}

/** Describes the options to customize dynamic retrieval. */
export declare interface DynamicRetrievalConfig {
/** Optional. The threshold to be used in dynamic retrieval. If not set, a system default value is used. */
dynamicThreshold?: number;
/** The mode of the predictor to be used in dynamic retrieval. */
mode?: Mode;
}

/**
* Tool to retrieve public web data for grounding, powered by Google.
*/
// eslint-disable-next-line @typescript-eslint/no-empty-interface
export declare interface GoogleSearchRetrieval {}
export declare interface GoogleSearchRetrieval {
/** Specifies the dynamic retrieval configuration for the given source. */
dynamicRetrievalConfig?: DynamicRetrievalConfig;
}

/**
* Retrieve from Vertex AI Search datastore for grounding.
Expand Down
6 changes: 5 additions & 1 deletion system_test/end_to_end_sample_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
VertexAI,
GenerateContentResponseHandler,
GoogleApiError,
Mode,
} from '../src';
import {FunctionDeclarationSchemaType} from '../src/types';

Expand Down Expand Up @@ -87,7 +88,10 @@ const TOOLS_WITH_FUNCTION_DECLARATION: FunctionDeclarationsTool[] = [
const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL: GoogleSearchRetrievalTool[] = [
{
googleSearchRetrieval: {
disableAttribution: false,
dynamicRetrievalConfig: {
dynamicThreshold: 0.2,
mode: Mode.MODE_DYNAMIC,
},
},
},
];
Expand Down