Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ai rag provider #1129

Merged
merged 20 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@ import {
getAuthenticatedFetch,
type AuthenticatedFetch,
} from "edgedb/dist/utils.js";
import type {
AIOptions,
QueryContext,
RAGRequest,
StreamingMessage,
} from "./types.js";
import type { AIOptions, QueryContext, StreamingMessage } from "./types.js";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
import { cryptoUtils } from "edgedb/dist/browserCrypto.js";

Expand Down Expand Up @@ -68,11 +63,17 @@ export class EdgeDBAI {
});
}

private async fetchRag(request: Omit<RAGRequest, "model" | "prompt">) {
private async fetchRag(request: any) {
diksipav marked this conversation as resolved.
Show resolved Hide resolved
const headers = request.stream
? { Accept: "text/event-stream", "Content-Type": "application/json" }
: { Accept: "application/json", "Content-Type": "application/json" };

const { messages } = request;

const providedPrompt =
this.options.prompt &&
("name" in this.options.prompt || "id" in this.options.prompt);

const response = await (
await this.authenticatedFetch
)("rag", {
Expand All @@ -81,7 +82,24 @@ export class EdgeDBAI {
body: JSON.stringify({
...request,
model: this.options.model,
prompt: this.options.prompt,
...((this.options.prompt || messages.length > 1) && {
prompt: {
...this.options.prompt,
...(messages.length > 1 && {
// if user provides prompt.custom without id/name it is his choice
// to not include default prompt msgs, but if user provides messages
// and doesn't provide prompt.custom, since we add messages to the
// prompt.custom we also have to include default prompt messages
...(!this.options.prompt?.custom &&
!providedPrompt && {
name: "builtin::rag-default",
}),
custom: [...(this.options.prompt?.custom || []), ...messages],
}),
},
}),
query: [...messages].reverse().find((msg) => msg.role === "user")!
.content[0].text,
}),
});

Expand All @@ -93,10 +111,10 @@ export class EdgeDBAI {
return response;
}

async queryRag(query: string, context = this.context): Promise<string> {
async queryRag(request: any, context = this.context): Promise<string> {
const res = await this.fetchRag({
context,
query,
...request,
stream: false,
});

Expand All @@ -117,19 +135,18 @@ export class EdgeDBAI {
"Expected response to be object with response key of type string",
);
}

return data.response;
}

streamRag(
query: string,
request: any,
context = this.context,
): AsyncIterable<StreamingMessage> & PromiseLike<Response> {
const fetchRag = this.fetchRag.bind(this);

const ragOptions = {
context,
query,
...request,
stream: true,
};

Expand Down
144 changes: 101 additions & 43 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { z } from "zod";

export type ChatParticipantRole = "system" | "user" | "assistant" | "tool";

export type Prompt =
| { name: string }
| { id: string }
| { custom: { role: ChatParticipantRole; content: string }[] };
| { name: string; custom?: EdgeDBMessage[] }
| { id: string; custom?: EdgeDBMessage[] }
| { custom: EdgeDBMessage[] };

export interface AIOptions {
model: string;
Expand All @@ -25,53 +27,109 @@ export interface RAGRequest {
stream?: boolean;
}

export interface MessageStart {
type: "message_start";
message: {
role: "assistant" | "system" | "user";
id: string;
model: string;
};
}
export type EdgeDBMessage =
| EdgeDBSystemMessage
| EdgeDBUserMessage
| EdgeDBAssistantMessage
| EdgeDBToolMessage;

export interface ContentBlockStart {
type: "content_block_start";
index: number;
content_block: {
text: string;
type: "text";
};
export interface EdgeDBSystemMessage {
role: "system";
content: string;
}

export interface ContentBlockDelta {
type: "content_block_delta";
delta: {
type: "text_delta";
text: string;
};
index: number;
export interface EdgeDBUserMessage {
role: "user";
content: { type: "text"; text: string }[];
}

export interface ContentBlockStop {
type: "content_block_stop";
index: number;
export interface EdgeDBAssistantMessage {
role: "assistant";
content: string;
tool_calls?: {
id: string;
type: "function";
function: { name: string; arguments: string };
}[];
}

export interface MessageDelta {
type: "message_delta";
delta: {
stop_reason: "stop";
};
export interface EdgeDBToolMessage {
role: "tool";
content: string;
tool_call_id: string;
}

export interface MessageStop {
type: "message_stop";
}
export type StreamingMessage = z.infer<typeof _edgedbRagChunkSchema>;

export type StreamingMessage =
| MessageStart
| ContentBlockStart
| ContentBlockDelta
| ContentBlockStop
| MessageDelta
| MessageStop;
const _edgedbRagChunkSchema = z.discriminatedUnion("type", [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not exporting this Zod schema or using it in this file: why are we using a Zod schema instead of writing this as a TypeScript interface?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this in the vercel provider so it was easier to copy paste it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we export it then maybe? Not a blocker.

Copy link
Contributor Author

@diksipav diksipav Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can export it, I just dont understand/see when would someone need it?
we export:
export type StreamingMessage = z.infer<typeof _edgedbRagChunkSchema>;
which will user most probably use?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just dont understand/see when would someone need it?

Oh, I thought you mean that you are using a copy-pasted version of this schema in the Vercel provider, no? I'm saying export it here and import it in the Vercel provider instead of copy-pasting it. Or export it from the Vercel provider and import the type here. Whichever makes sense.

This is not a blocker for merging, it just feels a little strange here to have an essentially "private" Zod schema that isn't used for anything, but has to be kept in sync across two packages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm okay I can do this, even tho it makes sense and is better to have one source of truth it feels a bit weird to me to install the whole lib just because one type, but maybe that's usually the way to go in this kind of situations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I hadn't noticed: Does the Vercel AI SDK Provider not use the @edgedb/ai package for anything already?

z.object({
type: z.literal("message_start"),
message: z.object({
id: z.string(),
model: z.string(),
role: z.enum(["assistant", "system", "user"]),
usage: z
.object({
prompt_tokens: z.number(),
completion_tokens: z.number(),
})
.nullish(),
}),
}),
z.object({
type: z.literal("content_block_start"),
index: z.number(),
content_block: z.discriminatedUnion("type", [
z.object({
type: z.literal("text"),
text: z.string(),
}),
z.object({
type: z.literal("tool_use"),
id: z.string().nullish(),
name: z.string(),
args: z.string().nullish(),
}),
]),
}),
z.object({
type: z.literal("content_block_delta"),
index: z.number(),
delta: z.discriminatedUnion("type", [
z.object({
type: z.literal("text_delta"),
text: z.string(),
}),
z.object({
type: z.literal("tool_call_delta"),
args: z.string(), // partial json
}),
]),
logprobs: z
.object({
tokens: z.array(z.string()),
token_logprobs: z.array(z.number()),
top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
})
.nullish(),
}),
z.object({
type: z.literal("content_block_stop"),
index: z.number(),
}),
z.object({
type: z.literal("message_delta"),
delta: z.object({ stop_reason: z.string() }),
usage: z.object({ completion_tokens: z.number() }).nullish(),
}),
z.object({
type: z.literal("message_stop"),
}),
z.object({
type: z.literal("error"),
error: z.object({
type: z.string(),
message: z.string(),
}),
}),
]);
1 change: 1 addition & 0 deletions packages/driver/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"@js-temporal/polyfill": "0.4.3",
"@repo/tsconfig": "*",
"@types/jest": "^29.5.12",
"@types/node": "^22.7.5",
"@types/semver": "^7.5.8",
"@types/shell-quote": "^1.7.5",
"@types/which": "^3.0.3",
Expand Down
6 changes: 6 additions & 0 deletions packages/driver/src/cli.mts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ async function main(args: string[]) {
async function whichEdgeDbCli() {
debug("Checking if CLI is in PATH...");
const locations = await which("edgedb", { nothrow: true, all: true });

if (locations == null) {
debug(" - No CLI found in PATH.");
return null;
}

for (const location of locations) {
const actualLocation = await fs.realpath(location);
debug(
Expand Down
16 changes: 12 additions & 4 deletions packages/driver/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ export interface CryptoUtils {
const _tokens = new WeakMap<ResolvedConnectConfigReadonly, string>();

export type AuthenticatedFetch = (
path: string,
init: RequestInit,
path: RequestInfo | URL,
init?: RequestInit,
) => Promise<Response>;

export async function getAuthenticatedFetch(
Expand All @@ -94,10 +94,18 @@ export async function getAuthenticatedFetch(
_tokens.set(config, token);
}

return (path: string, init: RequestInit) => {
return (input: RequestInfo | URL, init?: RequestInit) => {
let path: string;

if (typeof input === "string") {
path = input;
} else if (input instanceof Request) {
path = input.url;
} else path = input.toString();

const url = new URL(path, databaseUrl);

const headers = new Headers(init.headers);
const headers = new Headers(init?.headers);

if (config.user !== undefined) {
headers.append("X-EdgeDB-User", config.user);
Expand Down
43 changes: 43 additions & 0 deletions packages/vercel-ai-provider/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Provider for the Vercel AI SDK

The provider for the [Vercel AI SDK](https://sdk.vercel.ai/docs) contains language model support for
the OpenAi, Mistral and Anthropic chat and completion APIs that implements EdgeDB RAG, and embedding model support for the OpenAI and Mistral embeddings API.

## Setup

Provider is available in the `@edgedb/vercel-ai-provider` module. You can install it with:

```bash
npm i @edgedb/vercel-ai-provider
```

## Provider Instance

You can import the default provider instance `edgedbRag` from `@edgedb/vercel-ai-provider`:

```ts
import { edgedbRag } from "@edgedb/vercel-ai-provider";
```

## Example

```ts
import { generateText } from "ai";
import { createClient } from "edgedb";
import { edgedbRag } from "@edgedb/vercel-ai-provider";

const textModel = (await edgedbRag).languageModel("gpt-4-turbo-preview");

const { text } = await generateText({
model: textModel.withSettings({
context: { query: "your context" },
}),
prompt: "your prompt",
});

console.log(text);
```

## Documentation

Please check out the **[EdgeDB provider documentation](https://docs.edgedb.com)** for more information.
Loading