Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ai rag provider #1129

Merged
merged 20 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,139 @@
export type ChatParticipantRole = "system" | "user" | "assistant" | "tool";
import { z } from "zod";

export type ChatParticipantRole = "system" | "user" | "assistant" | "tool";

export type Prompt =
| { name: string; custom?: EdgeDBRagMessage }
| { id: string; custom?: EdgeDBRagMessage[] }
| { custom: EdgeDBRagMessage[] };

export interface AIOptions {
model: string;
prompt?: Prompt;
}

export interface QueryContext {
query: string;
variables?: Record<string, unknown>;
globals?: Record<string, unknown>;
max_object_count?: number;
}

export interface RAGRequest {
model: string;
prompt?: Prompt;
context: QueryContext;
query: string;
stream?: boolean;
}

export type EdgeDBRagMessage =
| EdgeDBSystemMessage
| EdgeDBUserMessage
| EdgeDBAssistantMessage
| EdgeDBToolMessage;

export interface EdgeDBSystemMessage {
role: "system";
content: string;
}

export interface EdgeDBUserMessage {
role: "user";
content: { type: "text"; text: string }[];
}

export interface EdgeDBAssistantMessage {
role: "assistant";
content: string;
tool_calls?: {
id: string;
type: "function";
function: { name: string; arguments: string };
}[];
}

export interface EdgeDBToolMessage {
role: "tool";
content: string;
tool_call_id: string;
}

export type StreamingMessage = z.infer<typeof _edgedbRagChunkSchema>;

const _edgedbRagChunkSchema = z.discriminatedUnion("type", [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not exporting this Zod schema or using it in this file: why are we using a Zod schema instead of writing this as a TypeScript interface?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this in the vercel provider so it was easier to copy paste it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we export it then maybe? Not a blocker.

Copy link
Contributor Author

@diksipav diksipav Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can export it, I just dont understand/see when would someone need it?
we export:
export type StreamingMessage = z.infer<typeof _edgedbRagChunkSchema>;
which will user most probably use?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just dont understand/see when would someone need it?

Oh, I thought you mean that you are using a copy-pasted version of this schema in the Vercel provider, no? I'm saying export it here and import it in the Vercel provider instead of copy-pasting it. Or export it from the Vercel provider and import the type here. Whichever makes sense.

This is not a blocker for merging, it just feels a little strange here to have an essentially "private" Zod schema that isn't used for anything, but has to be kept in sync across two packages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm okay I can do this, even tho it makes sense and is better to have one source of truth it feels a bit weird to me to install the whole lib just because one type, but maybe that's usually the way to go in this kind of situations?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I hadn't noticed: Does the Vercel AI SDK Provider not use the @edgedb/ai package for anything already?

z.object({
type: z.literal("message_start"),
message: z.object({
id: z.string(),
model: z.string(),
role: z.enum(["assistant", "system", "user"]),
usage: z
.object({
prompt_tokens: z.number(),
completion_tokens: z.number(),
})
.nullish(),
}),
}),
z.object({
type: z.literal("content_block_start"),
index: z.number(),
content_block: z.discriminatedUnion("type", [
z.object({
type: z.literal("text"),
text: z.string(),
}),
z.object({
type: z.literal("tool_use"),
id: z.string().nullish(),
name: z.string(),
args: z.string().nullish(),
}),
]),
}),
z.object({
type: z.literal("content_block_delta"),
index: z.number(),
delta: z.discriminatedUnion("type", [
z.object({
type: z.literal("text_delta"),
text: z.string(),
}),
z.object({
type: z.literal("tool_call_delta"),
args: z.string(), // partial json
}),
]),
logprobs: z
.object({
tokens: z.array(z.string()),
token_logprobs: z.array(z.number()),
top_logprobs: z.array(z.record(z.string(), z.number())).nullable(),
})
.nullish(),
}),
z.object({
type: z.literal("content_block_stop"),
index: z.number(),
}),
z.object({
type: z.literal("message_delta"),
delta: z.object({ stop_reason: z.string() }),
usage: z.object({ completion_tokens: z.number() }).nullish(),
}),
z.object({
type: z.literal("message_stop"),
}),
z.object({
type: z.literal("error"),
error: z.object({
type: z.string(),
message: z.string(),
}),
}),
]);
export type Prompt =
| { name: string }
| { id: string }
Expand Down
1 change: 1 addition & 0 deletions packages/driver/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"@js-temporal/polyfill": "0.4.3",
"@repo/tsconfig": "*",
"@types/jest": "^29.5.12",
"@types/node": "^22.7.5",
"@types/semver": "^7.5.8",
"@types/shell-quote": "^1.7.5",
"@types/which": "^3.0.3",
Expand Down
6 changes: 6 additions & 0 deletions packages/driver/src/cli.mts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ async function main(args: string[]) {
async function whichEdgeDbCli() {
debug("Checking if CLI is in PATH...");
const locations = await which("edgedb", { nothrow: true, all: true });

if (locations == null) {
debug(" - No CLI found in PATH.");
return null;
}

for (const location of locations) {
const actualLocation = await fs.realpath(location);
debug(
Expand Down
16 changes: 12 additions & 4 deletions packages/driver/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ export interface CryptoUtils {
const _tokens = new WeakMap<ResolvedConnectConfigReadonly, string>();

export type AuthenticatedFetch = (
path: string,
init: RequestInit,
path: RequestInfo | URL,
init?: RequestInit,
) => Promise<Response>;

export async function getAuthenticatedFetch(
Expand All @@ -94,10 +94,18 @@ export async function getAuthenticatedFetch(
_tokens.set(config, token);
}

return (path: string, init: RequestInit) => {
return (input: RequestInfo | URL, init?: RequestInit) => {
let path: string;

if (typeof input === "string") {
path = input;
} else if (input instanceof Request) {
path = input.url;
} else path = input.toString();

const url = new URL(path, databaseUrl);

const headers = new Headers(init.headers);
const headers = new Headers(init?.headers);

if (config.user !== undefined) {
headers.append("X-EdgeDB-User", config.user);
Expand Down
43 changes: 43 additions & 0 deletions packages/edgedb-rag-sdk/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# AI SDK - EdgeDB Provider

The provider for the [Vercel AI SDK](https://sdk.vercel.ai/docs) contains language model support for
the OpenAi, Mistral and Anthropic chat and completion APIs that implements EdgeDB RAG and embedding model support for the OpenAI and Mistral embeddings API.

## Setup

The EdgeDB provider is available in the `@edgedb/ai-sdk` module. You can install it with:

```bash
npm i @edgedb/ai-sdk
```

## Provider Instance

You can import the default provider instance `edgedbRag` from `@edgedb/ai-sdk`:

```ts
import { edgedbRag } from "@edgedb/ai-sdk";
```

## Example

```ts
import { generateText } from "ai";
import { createClient } from "edgedb";
import { edgedbRag } from "@edgedb/ai-sdk";

const textModel = (await edgedbRag).languageModel("gpt-4-turbo-preview");

const { text } = await generateText({
model: textModel.withSettings({
context: { query: "your context" },
}),
prompt: "your prompt",
});

console.log(text);
```

## Documentation

Please check out the **[EdgeDB provider documentation](https://docs.edgedb.com)** for more information.
60 changes: 60 additions & 0 deletions packages/edgedb-rag-sdk/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"name": "@edgedb/rag-sdk",
diksipav marked this conversation as resolved.
Show resolved Hide resolved
"version": "0.0.1",
"license": "Apache-2.0",
"sideEffects": false,
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
"files": [
"/dist"
],
"scripts": {
"build": "tsc --project tsconfig.json",
"clean": "rm -rf dist",
"lint": "eslint \"./**/*.ts*\"",
"type-check": "tsc --noEmit",
"prettier-check": "prettier --check \"./**/*.ts*\""
},
"exports": {
"./package.json": "./package.json",
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.mjs",
"require": "./dist/index.js"
}
},
"dependencies": {
"@ai-sdk/provider": "0.0.24",
"@ai-sdk/provider-utils": "^1.0.20"
},
"devDependencies": {
"@repo/tsconfig": "*",
"@types/node": "^18",
"tsup": "^8",
"typescript": "5.5.4",
"zod": "3.23.8",
"edgedb": "*"
},
"peerDependencies": {
"zod": "^3.0.0",
"edgedb": "^1.5.0"
},
"engines": {
"node": ">=18"
},
"publishConfig": {
"access": "public"
},
"homepage": "https://sdk.vercel.ai/docs",
"repository": {
"type": "git",
"url": "git+https://github.com/vercel/ai.git"
},
"bugs": {
"url": "https://github.com/vercel/ai/issues"
},
"keywords": [
"ai"
]
}
93 changes: 93 additions & 0 deletions packages/edgedb-rag-sdk/src/convert-to-edgedb-rag-messages.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import { type LanguageModelV1Prompt } from "@ai-sdk/provider";
import type { EdgeDBRagMessage } from "./edgedb-rag-settings";

export function convertToEdgeDBRagMessages(
prompt: LanguageModelV1Prompt,
): EdgeDBRagMessage[] {
const messages: EdgeDBRagMessage[] = [];

for (const { role, content } of prompt) {
switch (role) {
case "system": {
messages.push({ role: "system", content });
break;
}

case "user": {
messages.push({
role: "user",
content: content.map((part) => {
switch (part.type) {
case "text": {
return { type: "text", text: part.text };
}
default: {
throw new Error(`Unsupported type: ${part.type}`);
}
}
}),
});
break;
}

case "assistant": {
let text = "";
const toolCalls: {
id: string;
type: "function";
function: { name: string; arguments: string };
}[] = [];

for (const part of content) {
switch (part.type) {
case "text": {
text += part.text;
break;
}
case "tool-call": {
toolCalls.push({
id: part.toolCallId,
type: "function",
function: {
name: part.toolName,
arguments: JSON.stringify(part.args),
},
});
break;
}

default: {
const _exhaustiveCheck: never = part;
throw new Error(`Unsupported part: ${_exhaustiveCheck}`);
}
}
}

messages.push({
role: "assistant",
content: text,
tool_calls: toolCalls.length > 0 ? toolCalls : undefined,
});
break;
}

case "tool": {
for (const toolResponse of content) {
messages.push({
role: "tool",
content: JSON.stringify(toolResponse.result),
tool_call_id: toolResponse.toolCallId,
});
}
break;
}

default: {
const _exhaustiveCheck: never = role;
throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
}
}
}

return messages;
}
Loading
Loading