Skip to content

Commit

Permalink
Adds context as first-class feature of Model. (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored May 8, 2024
1 parent 7b54f0f commit 0106996
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 74 deletions.
18 changes: 18 additions & 0 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ await generate({
});
```

## Retriever context

Documents from a retriever can be passed directly to `generate` to provide
grounding context:

```javascript
const docs = await companyPolicyRetriever({ query: question });

await generate({
model: geminiPro,
prompt: `Answer using the available context from company policy: ${question}`,
context: docs,
});
```

The document context is automatically appended to the content of the prompt
sent to the model.

## Message history

Genkit models support maintaining a history of the messages sent to the model
Expand Down
25 changes: 2 additions & 23 deletions docs/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ import { configureGenkit } from '@genkit-ai/core';
import { defineFlow } from '@genkit-ai/flow';
import { generate } from '@genkit-ai/ai/generate';
import { retrieve } from '@genkit-ai/ai/retriever';
import { definePrompt } from '@genkit-ai/dotprompt';
import {
devLocalRetrieverRef,
devLocalVectorstore,
Expand Down Expand Up @@ -211,31 +210,11 @@ export const ragFlow = defineFlow(
query: input,
options: { k: 3 },
});
const facts = docs.map((d) => d.text());

const promptGenerator = definePrompt(
{
name: 'bob-facts',
model: 'google-vertex/gemini-pro',
input: {
schema: z.object({
facts: z.array(z.string()),
question: z.string(),
}),
},
},
'{{#each people}}{{this}}\n\n{{/each}}\n{{question}}'
);
const prompt = await promptGenerator.generate({
input: {
facts,
question: input,
},
});

const llmResponse = await generate({
model: geminiPro,
prompt: prompt.text(),
prompt: `Answer this question: ${input}`,
context: docs,
});

const output = llmResponse.text();
Expand Down
81 changes: 45 additions & 36 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { DocumentData } from '@google-cloud/firestore';
import { z } from 'zod';
import { extractJson } from './extract.js';
import {
Expand Down Expand Up @@ -386,36 +387,38 @@ function inferRoleFromParts(parts: Part[]): Role {
}

export async function toGenerateRequest(
prompt: GenerateOptions
options: GenerateOptions
): Promise<GenerateRequest> {
const promptMessage: MessageData = { role: 'user', content: [] };
if (typeof prompt.prompt === 'string') {
promptMessage.content.push({ text: prompt.prompt });
} else if (Array.isArray(prompt.prompt)) {
promptMessage.role = inferRoleFromParts(prompt.prompt);
promptMessage.content.push(...prompt.prompt);
if (typeof options.prompt === 'string') {
promptMessage.content.push({ text: options.prompt });
} else if (Array.isArray(options.prompt)) {
promptMessage.role = inferRoleFromParts(options.prompt);
promptMessage.content.push(...options.prompt);
} else {
promptMessage.role = inferRoleFromParts([prompt.prompt]);
promptMessage.content.push(prompt.prompt);
promptMessage.role = inferRoleFromParts([options.prompt]);
promptMessage.content.push(options.prompt);
}
const messages: MessageData[] = [...(prompt.history || []), promptMessage];
const messages: MessageData[] = [...(options.history || []), promptMessage];
let tools: Action<any, any>[] | undefined;
if (prompt.tools) {
tools = await resolveTools(prompt.tools);
if (options.tools) {
tools = await resolveTools(options.tools);
}

const out = {
messages,
candidates: prompt.candidates,
config: prompt.config,
candidates: options.candidates,
config: options.config,
tools: tools?.map((tool) => toToolDefinition(tool)) || [],
output: {
format:
prompt.output?.format ||
(prompt.output?.schema || prompt.output?.jsonSchema ? 'json' : 'text'),
options.output?.format ||
(options.output?.schema || options.output?.jsonSchema
? 'json'
: 'text'),
schema: toJsonSchema({
schema: prompt.output?.schema,
jsonSchema: prompt.output?.jsonSchema,
schema: options.output?.schema,
jsonSchema: options.output?.jsonSchema,
}),
},
};
Expand All @@ -431,6 +434,8 @@ export interface GenerateOptions<
model: ModelArgument<CustomOptions>;
/** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */
prompt: string | Part | Part[];
/** Retrieved documents to be used as context for this generation. */
context?: DocumentData[];
/** Conversation history for multi-turn prompting when supported by the underlying model. */
history?: MessageData[];
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
Expand Down Expand Up @@ -530,29 +535,33 @@ export async function generate<
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
): Promise<GenerateResponse<z.infer<O>>> {
const prompt: GenerateOptions<O, CustomOptions> =
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const model = await resolveModel(prompt.model);
const model = await resolveModel(resolvedOptions.model);
if (!model) {
throw new Error(`Model ${JSON.stringify(prompt.model)} not found`);
throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
}

let tools: ToolAction[] | undefined;
if (prompt.tools?.length) {
if (resolvedOptions.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
throw new Error(
`Model ${JSON.stringify(prompt.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
`Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
);
}
tools = await resolveTools(prompt.tools);
tools = await resolveTools(resolvedOptions.tools);
}

const request = await toGenerateRequest(prompt);
telemetry.recordGenerateActionInputLogs(model.__action.name, prompt, request);
const request = await toGenerateRequest(resolvedOptions);
telemetry.recordGenerateActionInputLogs(
model.__action.name,
resolvedOptions,
request
);
const response = await runWithStreamingCallback(
prompt.streamingCallback
resolvedOptions.streamingCallback
? (chunk: GenerateResponseChunkData) =>
prompt.streamingCallback!(new GenerateResponseChunk(chunk))
resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk))
: undefined,
async () => new GenerateResponse<z.infer<O>>(await model(request), request)
);
Expand All @@ -569,13 +578,13 @@ export async function generate<
});
}

if (prompt.output?.schema || prompt.output?.jsonSchema) {
if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) {
// find a candidate with valid output schema
const candidateValidations = response.candidates.map((c) => {
try {
return validateSchema(c.output(), {
jsonSchema: prompt.output?.jsonSchema,
schema: prompt.output?.schema,
jsonSchema: resolvedOptions.output?.jsonSchema,
schema: resolvedOptions.output?.schema,
});
} catch (e) {
return {
Expand Down Expand Up @@ -612,10 +621,10 @@ export async function generate<
const toolCalls = selected.message.content.filter(
(part) => !!part.toolRequest
);
if (prompt.returnToolRequests || toolCalls.length === 0) {
if (resolvedOptions.returnToolRequests || toolCalls.length === 0) {
telemetry.recordGenerateActionOutputLogs(
model.__action.name,
prompt,
resolvedOptions,
response
);
return response;
Expand All @@ -642,10 +651,10 @@ export async function generate<
};
})
);
prompt.history = request.messages;
prompt.history.push(selected.message);
prompt.prompt = toolResponses;
return await generate(prompt);
resolvedOptions.history = request.messages;
resolvedOptions.history.push(selected.message);
resolvedOptions.prompt = toolResponses;
return await generate(resolvedOptions);
}

export type GenerateStreamOptions<
Expand Down
15 changes: 12 additions & 3 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ import {
import { toJsonSchema } from '@genkit-ai/core/schema';
import { performance } from 'node:perf_hooks';
import { z } from 'zod';
import { conformOutput, validateSupport } from './model/middleware.js';
import { DocumentDataSchema } from './document.js';
import {
augmentWithContext,
conformOutput,
validateSupport,
} from './model/middleware.js';
import * as telemetry from './telemetry.js';

//
Expand Down Expand Up @@ -127,6 +132,8 @@ export const ModelInfoSchema = z.object({
systemRole: z.boolean().optional(),
/** Model can output this type of data. */
output: z.array(OutputFormatSchema).optional(),
/** Model can natively support document-based context grounding. */
context: z.boolean().optional(),
})
.optional(),
});
Expand Down Expand Up @@ -166,6 +173,7 @@ export const GenerateRequestSchema = z.object({
config: z.any().optional(),
tools: z.array(ToolDefinitionSchema).optional(),
output: OutputConfigSchema.optional(),
context: z.array(DocumentDataSchema).optional(),
candidates: z.number().optional(),
});

Expand Down Expand Up @@ -264,11 +272,12 @@ export function defineModel<
) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema> {
const label = options.label || `${options.name} GenAI model`;
const middleware = [
const middleware: ModelMiddleware[] = [
...(options.use || []),
validateSupport(options),
conformOutput(),
];
if (!options?.supports?.context) middleware.push(augmentWithContext());
middleware.push(conformOutput());
const act = defineAction(
{
actionType: 'model',
Expand Down
52 changes: 52 additions & 0 deletions js/ai/src/model/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

import { Document } from '../document.js';
import { ModelInfo, ModelMiddleware, Part } from '../model.js';

/**
Expand Down Expand Up @@ -171,3 +172,54 @@ export function simulateSystemPrompt(options?: {
return next({ ...req, messages });
};
}

export interface AugmentWithContextOptions {
  /** Preceding text to place before the rendered context documents. */
  preface?: string | null;
  /**
   * A function to render a document into a text part to be included in the message.
   * Receives the document, its zero-based position in the request's context
   * array, and the options this middleware was configured with.
   *
   * Note: the signature previously omitted `index`, but both the default
   * template and the middleware call site pass (document, index, options);
   * the declaration now matches actual usage. Two-parameter templates remain
   * assignable (extra arguments are simply not received).
   */
  itemTemplate?: (
    d: Document,
    index: number,
    options?: AugmentWithContextOptions
  ) => string;
  /** The metadata key to use for citation reference. Pass `null` to provide no citations. */
  citationKey?: string | null;
}

export const CONTEXT_PREFACE =
  '\n\nUse the following information to complete your task:\n\n';
/**
 * Default renderer for a single context document: a bulleted line with a
 * citation marker followed by the document text.
 *
 * Citation resolution:
 * - `citationKey` set: read that metadata key (falling back to the index if
 *   the document has no metadata or the key is absent).
 * - `citationKey === null`: omit the citation marker entirely.
 * - `citationKey` undefined (default): fall back to `ref`, then `id`, then
 *   the document's index in the context array.
 */
const CONTEXT_ITEM_TEMPLATE = (
  d: Document,
  index: number,
  options?: AugmentWithContextOptions
) => {
  let out = '- ';
  if (options?.citationKey) {
    // Optional chaining instead of `d.metadata!…`: a document without
    // metadata previously threw a TypeError here; fall back to the index so
    // a citation is still emitted.
    out += `[${d.metadata?.[options.citationKey] ?? index}]: `;
  } else if (options?.citationKey === undefined) {
    out += `[${d.metadata?.['ref'] || d.metadata?.['id'] || index}]: `;
  }
  out += d.text() + '\n';
  return out;
};
/**
 * Model middleware that injects retrieved documents (`req.context`) into the
 * last message of the request as a plain-text part, for models without
 * native document-grounding support.
 *
 * Passes the request through untouched when it has no context documents,
 * has no messages, or the last message already carries a part tagged
 * `metadata.purpose === 'context'` (prevents double-injection when the same
 * request is re-sent, e.g. on tool-calling round trips).
 *
 * @param options optional preface / item template / citation configuration.
 */
export function augmentWithContext(
  options?: AugmentWithContextOptions
): ModelMiddleware {
  // `preface === null` deliberately suppresses the preface; only an
  // undefined preface falls back to the default.
  const preface =
    typeof options?.preface === 'undefined' ? CONTEXT_PREFACE : options.preface;
  const itemTemplate = options?.itemTemplate || CONTEXT_ITEM_TEMPLATE;
  return (req, next) => {
    // if there is no context in the request, no-op
    if (!req.context?.length) return next(req);
    const userMessage = req.messages.at(-1);
    // if there are no messages, no-op
    if (!userMessage) return next(req);
    // if there is already a context part, no-op
    if (userMessage.content.find((p) => p.metadata?.purpose === 'context'))
      return next(req);
    let out = `${preface || ''}`;
    req.context.forEach((d, i) => {
      out += itemTemplate(new Document(d), i, options);
    });
    out += '\n';
    userMessage.content.push({ text: out, metadata: { purpose: 'context' } });
    return next(req);
  };
}
10 changes: 10 additions & 0 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { DocumentData } from '@google-cloud/firestore';
import z from 'zod';
import { GenerateOptions } from './generate';
import { GenerateRequest, GenerateRequestSchema, ModelArgument } from './model';
Expand All @@ -35,6 +36,13 @@ export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
};
};

/**
 * Runtime check for whether `arg` is a prompt action: a callable that
 * carries Genkit action metadata with `type === 'prompt'`.
 */
export function isPrompt(arg: any): boolean {
  if (typeof arg !== 'function') {
    return false;
  }
  const metadata = (arg as any).__action?.metadata;
  return metadata?.type === 'prompt';
}

export function definePrompt<I extends z.ZodTypeAny>(
{
name,
Expand Down Expand Up @@ -79,6 +87,7 @@ export async function renderPrompt<
>(params: {
prompt: PromptArgument<I>;
input: z.infer<I>;
context?: DocumentData[];
model: ModelArgument<CustomOptions>;
config?: z.infer<CustomOptions>;
}): Promise<GenerateOptions> {
Expand All @@ -94,5 +103,6 @@ export async function renderPrompt<
config: { ...(rendered.config || {}), ...params.config },
history: rendered.messages.slice(0, rendered.messages.length - 1),
prompt: rendered.messages[rendered.messages.length - 1].content,
context: params.context,
};
}
Loading

0 comments on commit 0106996

Please sign in to comment.