feat(js/ai,plugins/vertexai): add rerankers #803

Merged
merged 8 commits into from
Sep 9, 2024
86 changes: 86 additions & 0 deletions docs/rag.md
@@ -384,3 +384,89 @@ const docs = await retrieve({
options: { preRerankK: 7, k: 3 },
});
```

### Rerankers and Two-Stage Retrieval

A reranking model (also known as a cross-encoder) takes a query and a document and outputs a similarity score. That score is used to reorder documents by their relevance to the query. Reranker APIs take a list of documents (for example, the output of a retriever) and reorder them based on their relevance to the query. In two-stage retrieval, a retriever first fetches a broad set of candidates and a reranker then reorders them. This step can be useful for refining the results and ensuring that the most pertinent information is used in the prompt provided to a generative model.
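
The two stages can be sketched without any framework. The block below is a minimal illustration, not Genkit code: `retrieveCandidates`, `scorePair`, and `twoStageRetrieve` are hypothetical names, and the word-overlap scorer is only a stand-in for a real cross-encoder. The `preRerankK` and `k` parameters mirror the option names used in the retriever example above.

```ts
// Framework-free sketch of two-stage retrieval. All names here are
// hypothetical; none of them are Genkit APIs.
type ScoredText = { text: string; score: number };

// Lowercase and split on whitespace; a deliberately naive tokenizer.
function tokenize(s: string): string[] {
  return s
    .toLowerCase()
    .split(/\s+/)
    .filter((w) => w.length > 0);
}

// Stage 1 (assumed): a cheap, recall-oriented filter. It keeps documents
// that share at least one word with the query, capped at `preRerankK`.
function retrieveCandidates(
  query: string,
  corpus: string[],
  preRerankK: number
): string[] {
  const queryWords = new Set(tokenize(query));
  return corpus
    .filter((doc) => tokenize(doc).some((w) => queryWords.has(w)))
    .slice(0, preRerankK);
}

// Stage 2 (assumed): a toy stand-in for a cross-encoder. It scores a
// (query, document) pair as the fraction of query words found in the document.
function scorePair(query: string, doc: string): number {
  const queryWords = tokenize(query);
  if (queryWords.length === 0) return 0;
  const docWords = new Set(tokenize(doc));
  return queryWords.filter((w) => docWords.has(w)).length / queryWords.length;
}

// Retrieve broadly, score each candidate against the query, keep the top `k`.
function twoStageRetrieve(
  query: string,
  corpus: string[],
  preRerankK: number,
  k: number
): ScoredText[] {
  return retrieveCandidates(query, corpus, preRerankK)
    .map((text) => ({ text, score: scorePair(query, text) }))
    .sort((a, b) => b.score - a.score)
    .slice(0, k);
}
```

The first stage is tuned for recall and the second for precision, which is why `preRerankK` is normally larger than `k`.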


#### Reranker Example

A reranker in Genkit is defined with syntax similar to that of retrievers and indexers. The following flow uses a predefined Vertex AI reranker to rerank a set of documents by their relevance to the provided query.

```ts
import { rerank } from '@genkit-ai/ai/reranker';
import { Document } from '@genkit-ai/ai/retriever';
import { defineFlow } from '@genkit-ai/flow';
import * as z from 'zod';

const FAKE_DOCUMENT_CONTENT = [
  'pythagorean theorem',
  'e=mc^2',
  'pi',
  'dinosaurs',
  'quantum mechanics',
  'pizza',
  'harry potter',
];

export const rerankFlow = defineFlow(
  {
    name: 'rerankFlow',
    inputSchema: z.object({ query: z.string() }),
    outputSchema: z.array(
      z.object({
        text: z.string(),
        score: z.number(),
      })
    ),
  },
  async ({ query }) => {
    const documents = FAKE_DOCUMENT_CONTENT.map((text) =>
      Document.fromText(text)
    );

    const rerankedDocuments = await rerank({
      reranker: 'vertexai/semantic-ranker-512',
      query: Document.fromText(query),
      documents,
    });

    return rerankedDocuments.map((doc) => ({
      text: doc.text(),
      score: doc.metadata.score,
    }));
  }
);
```
This reranker uses the Vertex AI Genkit plugin with `semantic-ranker-512` to score and rank documents. The higher the score, the more relevant the document is to the query.
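
Because higher scores mean higher relevance, a common follow-up step is to keep only the best-scoring documents before building a prompt. The sketch below is one possible way to do that. `selectForPrompt` and `fakeRanked` are hypothetical helpers, and `RankedLike` is a structural stand-in that assumes only the `text()` and `score()` methods used by the ranked documents in this PR, so the block runs without Genkit installed.

```ts
// Structural stand-in for a ranked document (assumption: only text() and
// score() are needed). This is not the Genkit class itself.
interface RankedLike {
  text(): string;
  score(): number;
}

// Hypothetical helper: keep at most `k` documents whose score is at least
// `minScore`, ordered from most to least relevant.
function selectForPrompt(
  docs: RankedLike[],
  k: number,
  minScore: number
): string[] {
  return docs
    .filter((d) => d.score() >= minScore)
    .sort((a, b) => b.score() - a.score())
    .slice(0, k)
    .map((d) => d.text());
}

// Tiny in-memory factory so the sketch can be exercised on its own.
const fakeRanked = (text: string, score: number): RankedLike => ({
  text: () => text,
  score: () => score,
});
```

The `minScore` threshold is an illustrative choice; a suitable value depends on the score range of the reranker in use.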

#### Custom Rerankers

You can also define custom rerankers to suit your specific use case. This is helpful when you need to rerank documents using your own custom logic or a custom model. Here’s a simple example of defining a custom reranker:
```ts
import { defineReranker } from '@genkit-ai/ai/reranker';
import * as z from 'zod';

export const customReranker = defineReranker(
  {
    name: 'custom/reranker',
    configSchema: z.object({
      k: z.number().optional(),
    }),
  },
  async (query, documents, options) => {
    // Your custom reranking logic here
    const rerankedDocs = documents.map((doc) => {
      const score = Math.random(); // Assign random scores for demonstration
      return {
        ...doc,
        metadata: { ...doc.metadata, score },
      };
    });

    // The runner must resolve to an object with a `documents` array
    // (see RerankerResponseSchema in js/ai/src/reranker.ts), not a bare array.
    // `options` may be undefined, so guard the access to `k`.
    return {
      documents: rerankedDocs
        .sort((a, b) => b.metadata.score - a.metadata.score)
        .slice(0, options?.k || 3),
    };
  }
);
```
Once defined, this custom reranker can be used just like any other reranker in your RAG flows, giving you flexibility to implement advanced reranking strategies.
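
As an illustration of non-random scoring logic, the sketch below ranks documents by word overlap with the query. It is a standalone approximation, not Genkit code: `DocLike`, `RankedDocLike`, `overlapScore`, and `overlapRunner` are hypothetical names, and the document shape is simplified to text-only parts. The function follows the runner shape declared in `js/ai/src/reranker.ts`, where the runner receives `(query, documents, options)` and resolves to an object with a `documents` array whose entries carry `metadata.score`.

```ts
// Simplified stand-ins for Genkit's document shapes (assumption: text-only parts).
type TextPart = { text: string };
type DocLike = { content: TextPart[]; metadata?: Record<string, unknown> };
type RankedDocLike = {
  content: TextPart[];
  metadata: { score: number } & Record<string, unknown>;
};

// Join all text parts and split into lowercase words.
function words(parts: TextPart[]): string[] {
  return parts
    .map((p) => p.text)
    .join(' ')
    .toLowerCase()
    .split(/\s+/)
    .filter((w) => w.length > 0);
}

// Hypothetical scorer: fraction of distinct query words present in the document.
function overlapScore(query: DocLike, doc: DocLike): number {
  const queryWords = Array.from(new Set(words(query.content)));
  if (queryWords.length === 0) return 0;
  const docWords = new Set(words(doc.content));
  return queryWords.filter((w) => docWords.has(w)).length / queryWords.length;
}

// Runner-shaped function: scores every document, sorts by descending score,
// and keeps the top `k` (defaulting to 3 when no option is given).
async function overlapRunner(
  query: DocLike,
  documents: DocLike[],
  options?: { k?: number }
): Promise<{ documents: RankedDocLike[] }> {
  const ranked = documents
    .map((doc) => ({
      content: doc.content,
      metadata: { ...doc.metadata, score: overlapScore(query, doc) },
    }))
    .sort((a, b) => b.metadata.score - a.metadata.score)
    .slice(0, options?.k ?? 3);
  return { documents: ranked };
}
```

A function with this shape could, under the assumptions above, be adapted as the second argument to `defineReranker`; a production reranker would more likely call a model than count words.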
9 changes: 9 additions & 0 deletions js/ai/package.json
@@ -90,6 +90,12 @@
      "require": "./lib/tool.js",
      "import": "./lib/tool.mjs",
      "default": "./lib/tool.js"
    },
    "./reranker": {
      "types": "./lib/reranker.d.ts",
      "require": "./lib/reranker.js",
      "import": "./lib/reranker.mjs",
      "default": "./lib/reranker.js"
    }
  },
  "typesVersions": {
@@ -114,6 +120,9 @@
      ],
      "tool": [
        "lib/tool"
      ],
      "reranker": [
        "lib/reranker"
      ]
    }
  }
205 changes: 205 additions & 0 deletions js/ai/src/reranker.ts
@@ -0,0 +1,205 @@
/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { Action, defineAction } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import * as z from 'zod';
import { Part, PartSchema } from './document.js';
import { Document, DocumentData, DocumentDataSchema } from './retriever.js';

type RerankerFn<RerankerOptions extends z.ZodTypeAny> = (
  query: Document,
  documents: Document[],
  queryOpts: z.infer<RerankerOptions>
) => Promise<RerankerResponse>;

export const RankedDocumentDataSchema = z.object({
  content: z.array(PartSchema),
  metadata: z
    .object({
      score: z.number(), // Enforces that 'score' must be a number
    })
    .passthrough(), // Allows other properties in 'metadata' with any type
});

export type RankedDocumentData = z.infer<typeof RankedDocumentDataSchema>;

export class RankedDocument extends Document implements RankedDocumentData {
  content: Part[];
  metadata: { score: number } & Record<string, any>;

  constructor(data: RankedDocumentData) {
    super(data);
    this.content = data.content;
    this.metadata = data.metadata;
  }
  /**
   * Returns the score of the document.
   * @returns The score of the document.
   */
  score(): number {
    return this.metadata.score;
  }
}

const RerankerRequestSchema = z.object({
  query: DocumentDataSchema,
  documents: z.array(DocumentDataSchema),
  options: z.any().optional(),
});

const RerankerResponseSchema = z.object({
  documents: z.array(RankedDocumentDataSchema),
});
type RerankerResponse = z.infer<typeof RerankerResponseSchema>;

export const RerankerInfoSchema = z.object({
  label: z.string().optional(),
  /** Supported model capabilities. */
  supports: z
    .object({
      /** Model can process media as part of the prompt (multimodal input). */
      media: z.boolean().optional(),
    })
    .optional(),
});
export type RerankerInfo = z.infer<typeof RerankerInfoSchema>;

export type RerankerAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
  Action<
    typeof RerankerRequestSchema,
    typeof RerankerResponseSchema,
    { model: RerankerInfo }
  > & {
    __configSchema?: CustomOptions;
  };

function rerankerWithMetadata<
  RerankerOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
  reranker: Action<typeof RerankerRequestSchema, typeof RerankerResponseSchema>,
  configSchema?: RerankerOptions
): RerankerAction<RerankerOptions> {
  const withMeta = reranker as RerankerAction<RerankerOptions>;
  withMeta.__configSchema = configSchema;
  return withMeta;
}

/**
 * Creates a reranker action for the provided {@link RerankerFn} implementation.
 */
export function defineReranker<OptionsType extends z.ZodTypeAny = z.ZodTypeAny>(
  options: {
    name: string;
    configSchema?: OptionsType;
    info?: RerankerInfo;
  },
  runner: RerankerFn<OptionsType>
) {
  const reranker = defineAction(
    {
      actionType: 'reranker',
      name: options.name,
      inputSchema: options.configSchema
        ? RerankerRequestSchema.extend({
            options: options.configSchema.optional(),
          })
        : RerankerRequestSchema,
      outputSchema: RerankerResponseSchema,
      metadata: {
        type: 'reranker',
        info: options.info,
      },
    },
    (i) =>
      runner(
        new Document(i.query),
        i.documents.map((d) => new Document(d)),
        i.options
      )
  );
  const rwm = rerankerWithMetadata(
    reranker as Action<
      typeof RerankerRequestSchema,
      typeof RerankerResponseSchema
    >,
    options.configSchema
  );
  return rwm;
}

export interface RerankerParams<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> {
  reranker: RerankerArgument<CustomOptions>;
  query: string | DocumentData;
  documents: DocumentData[];
  options?: z.infer<CustomOptions>;
}

export type RerankerArgument<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> = RerankerAction<CustomOptions> | RerankerReference<CustomOptions> | string;

/**
 * Reranks documents from a {@link RerankerArgument} based on the provided query.
 */
export async function rerank<CustomOptions extends z.ZodTypeAny>(
  params: RerankerParams<CustomOptions>
): Promise<Array<RankedDocument>> {
  let reranker: RerankerAction<CustomOptions>;
  if (typeof params.reranker === 'string') {
    reranker = await lookupAction(`/reranker/${params.reranker}`);
  } else if (Object.hasOwnProperty.call(params.reranker, 'info')) {
    reranker = await lookupAction(`/reranker/${params.reranker.name}`);
  } else {
    reranker = params.reranker as RerankerAction<CustomOptions>;
  }
  if (!reranker) {
    throw new Error('Unable to resolve the reranker');
  }
  const response = await reranker({
    query:
      typeof params.query === 'string'
        ? Document.fromText(params.query)
        : params.query,
    documents: params.documents,
    options: params.options,
  });

  return response.documents.map((d) => new RankedDocument(d));
}

export const CommonRerankerOptionsSchema = z.object({
  k: z.number().describe('Number of documents to rerank').optional(),
});

export interface RerankerReference<CustomOptions extends z.ZodTypeAny> {
  name: string;
  configSchema?: CustomOptions;
  info?: RerankerInfo;
}

/**
 * Helper method to configure a {@link RerankerReference} to a plugin.
 */
export function rerankerRef<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  options: RerankerReference<CustomOptionsSchema>
): RerankerReference<CustomOptionsSchema> {
  return { ...options };
}