Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ai rag provider #1129

Merged
merged 20 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions packages/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Creates an instance of `EdgeDBAI` with the specified client and options.

- `client`: An EdgeDB client instance.
- `options`: Configuration options for the AI model.
- `model`: Required. Specifies the AI model to use. This could be a version of GPT or any other model supported by EdgeDB AI.
- `prompt`: Optional. Defines the input prompt for the AI model. The prompt can be a simple string, an ID referencing a stored prompt, or a custom prompt structure that includes roles and content for more complex interactions. The default is the built-in system prompt.
- `model`: Required. Specifies the AI model to use. This could be some of the OpenAI, Mistral or Anthropic models supported by EdgeDB AI.
- `prompt`: Optional. Defines the input messages for the AI model. The prompt can have an `id` or a `name` referencing a stored prompt; the referenced prompt supplies predefined messages. Optionally, include a custom list of messages using the `custom` field; these custom messages are concatenated with the messages from the stored prompt referenced by `id` or `name`. If no `id` or `name` is specified, only the `custom` messages are used. If no `id`, `name`, or `custom` messages are provided, the built-in system prompt is used by default.

### `EdgeDBAI`

Expand Down Expand Up @@ -52,6 +52,10 @@ Creates an instance of `EdgeDBAI` with the specified client and options.

Generates embeddings for the array of strings.

## Tool Calls

Tool calls are supported by the AI extension. They must be executed on the client side, and the tool call results should then be provided back to the EdgeDB AI.

## Example

The following example demonstrates how to use the `@edgedb/ai` package to query an AI model about astronomy and chemistry.
Expand Down
105 changes: 66 additions & 39 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
import type { Client } from "edgedb";
import {
EventSourceParserStream,
type ParsedEvent,
} from "eventsource-parser/stream";
import { EventSourceParserStream } from "eventsource-parser/stream";

import type { ResolvedConnectConfig } from "edgedb/dist/conUtils.js";
import {
getAuthenticatedFetch,
type AuthenticatedFetch,
} from "edgedb/dist/utils.js";
import type {
AIOptions,
QueryContext,
RAGRequest,
StreamingMessage,
import {
type AIOptions,
type QueryContext,
type StreamingMessage,
type RagRequest,
isPromptRequest,
} from "./types.js";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
import { cryptoUtils } from "edgedb/dist/browserCrypto.js";
import { extractMessageFromParsedEvent, handleResponseError } from "./utils.js";

export function createAI(client: Client, options: AIOptions) {
return new EdgeDBAI(client, options);
Expand Down Expand Up @@ -68,37 +67,70 @@ export class EdgeDBAI {
});
}

private async fetchRag(request: Omit<RAGRequest, "model" | "prompt">) {
private async fetchRag(request: RagRequest, context: QueryContext) {
const headers = request.stream
? { Accept: "text/event-stream", "Content-Type": "application/json" }
: { Accept: "application/json", "Content-Type": "application/json" };

if (request.prompt && request.initialMessages)
throw new Error(
"You can provide either a prompt or a messages array, not both.",
);

const messages = isPromptRequest(request)
? [
{
role: "user" as const,
content: [{ type: "text", text: request.prompt }],
},
]
: request.messages ?? [];

const providedPrompt =
this.options.prompt &&
("name" in this.options.prompt || "id" in this.options.prompt);

const response = await (
await this.authenticatedFetch
)("rag", {
method: "POST",
headers,
body: JSON.stringify({
...request,
context,
model: this.options.model,
prompt: this.options.prompt,
prompt: {
...this.options.prompt,
// if user provides prompt.custom without id/name it is his choice
// to not include default prompt msgs, but if user provides messages
// and doesn't provide prompt.custom, since we add messages to the
// prompt.custom we also have to include default prompt messages
...(!this.options.prompt?.custom &&
!providedPrompt && {
name: "builtin::rag-default",
}),
custom: [...(this.options.prompt?.custom || []), ...messages],
},
query: [...messages].reverse().find((msg) => msg.role === "user")!
.content[0].text,
}),
});

if (!response.ok) {
const bodyText = await response.text();
throw new Error(bodyText);
handleResponseError(response);
}

return response;
}

async queryRag(query: string, context = this.context): Promise<string> {
const res = await this.fetchRag({
async queryRag(request: RagRequest, context = this.context): Promise<string> {
const res = await this.fetchRag(
{
...request,
stream: false,
},
context,
query,
stream: false,
});
);

if (!res.headers.get("content-type")?.includes("application/json")) {
throw new Error(
Expand All @@ -114,28 +146,27 @@ export class EdgeDBAI {
typeof data.response !== "string"
) {
throw new Error(
"Expected response to be object with response key of type string",
"Expected response to be an object with response key of type string",
);
}

return data.response;
}

streamRag(
query: string,
request: RagRequest,
context = this.context,
): AsyncIterable<StreamingMessage> & PromiseLike<Response> {
const fetchRag = this.fetchRag.bind(this);

const ragOptions = {
context,
query,
stream: true,
};

return {
async *[Symbol.asyncIterator]() {
const res = await fetchRag(ragOptions);
const res = await fetchRag(
{
...request,
stream: true,
},
context,
);

if (!res.body) {
throw new Error("Expected response to include a body");
Expand Down Expand Up @@ -167,7 +198,13 @@ export class EdgeDBAI {
| undefined
| null,
): Promise<TResult1 | TResult2> {
return fetchRag(ragOptions).then(onfulfilled, onrejected);
return fetchRag(
{
...request,
stream: true,
},
context,
).then(onfulfilled, onrejected);
},
};
}
Expand Down Expand Up @@ -195,13 +232,3 @@ export class EdgeDBAI {
return data.data[0].embedding;
}
}

function extractMessageFromParsedEvent(
parsedEvent: ParsedEvent,
): StreamingMessage {
const { data } = parsedEvent;
if (!data) {
throw new Error("Expected SSE message to include a data payload");
}
return JSON.parse(data) as StreamingMessage;
}
116 changes: 97 additions & 19 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
/** The role attached to a chat message in a RAG conversation. */
export type ChatParticipantRole = "system" | "user" | "assistant" | "tool";

/** A system-level instruction message; content is a plain string. */
export interface EdgeDBSystemMessage {
  role: "system";
  content: string;
}

/** A user message; content is a list of typed text parts. */
export interface EdgeDBUserMessage {
  role: "user";
  content: { type: "text"; text: string }[];
}

/**
 * An assistant (model) message. May carry `tool_calls` asking the client to
 * execute functions; `arguments` is a string (presumably JSON-encoded — see
 * the provider's tool-call format).
 */
export interface EdgeDBAssistantMessage {
  role: "assistant";
  content: string;
  tool_calls?: {
    id: string;
    type: "function";
    function: { name: string; arguments: string };
  }[];
}

/**
 * The result of executing a tool call on the client. `tool_call_id`
 * presumably matches the `id` of an assistant `tool_calls` entry — confirm
 * against the server's RAG protocol.
 */
export interface EdgeDBToolMessage {
  role: "tool";
  content: string;
  tool_call_id: string;
}

/** Any message that can appear in a RAG conversation. */
export type EdgeDBMessage =
  | EdgeDBSystemMessage
  | EdgeDBUserMessage
  | EdgeDBAssistantMessage
  | EdgeDBToolMessage;

export type Prompt =
| { name: string }
| { id: string }
| { custom: { role: ChatParticipantRole; content: string }[] };
| { name: string; custom?: EdgeDBMessage[] }
| { id: string; custom?: EdgeDBMessage[] }
| { custom: EdgeDBMessage[] };

export interface AIOptions {
model: string;
Expand All @@ -17,39 +49,73 @@ export interface QueryContext {
max_object_count?: number;
}

export interface RAGRequest {
model: string;
prompt?: Prompt;
context: QueryContext;
query: string;
interface RagRequestBase {
stream?: boolean;
[key: string]: unknown;
}

/** A RAG request whose input is a single plain-text prompt string. */
export type RagRequestPrompt = RagRequestBase & {
  prompt: string;
};

/** A RAG request whose input is an explicit list of conversation messages. */
export type RagRequestMessages = RagRequestBase & {
  messages: EdgeDBMessage[];
};

/** Either form of RAG request; discriminate with `isPromptRequest`. */
export type RagRequest = RagRequestPrompt | RagRequestMessages;

/**
 * Type guard narrowing a `RagRequest` to `RagRequestPrompt`.
 *
 * A request counts as a prompt request whenever it carries a `prompt`
 * property (the two union members are distinguished solely by that key).
 */
export function isPromptRequest(
  request: RagRequest,
): request is RagRequestPrompt {
  // Reflect.has is semantically identical to the `in` operator.
  return Reflect.has(request, "prompt");
}

export interface MessageStart {
type: "message_start";
message: {
role: "assistant" | "system" | "user";
id: string;
model: string;
role: "assistant" | "system" | "user"; //todo check this;
usage?: {
prompt_tokens: number;
completion_tokens: number;
} | null;
};
}

export interface ContentBlockStart {
type: "content_block_start";
index: number;
content_block: {
text: string;
type: "text";
};
content_block:
| {
type: "text";
text: string;
}
| {
type: "tool_use";
id?: string | null;
name: string;
args?: string | null;
};
}

export interface ContentBlockDelta {
type: "content_block_delta";
delta: {
type: "text_delta";
text: string;
};
index: number;
delta:
| {
type: "text_delta";
text: string;
}
| {
type: "tool_call_delta";
args: string;
};
logprobs?: {
tokens: string[];
token_logprobs: number[];
top_logprobs: Record<string, number>[] | null;
} | null;
}

export interface ContentBlockStop {
Expand All @@ -60,18 +126,30 @@ export interface ContentBlockStop {
export interface MessageDelta {
type: "message_delta";
delta: {
stop_reason: "stop";
stop_reason: string;
};
usage?: {
completion_tokens: number;
};
}

/** Terminal event of a streamed response: no further events will follow. */
export interface MessageStop {
  type: "message_stop";
}

/** Error event delivered inside the stream instead of a normal message. */
export interface MessageError {
  type: "error";
  error: {
    // Provider-specific error category string.
    type: string;
    // Human-readable description of the failure.
    message: string;
  };
}

export type StreamingMessage =
| MessageStart
| ContentBlockStart
| ContentBlockDelta
| ContentBlockStop
| MessageDelta
| MessageStop;
| MessageStop
| MessageError;
30 changes: 30 additions & 0 deletions packages/ai/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import type { ParsedEvent } from "eventsource-parser";
import type { StreamingMessage } from "./types.js";

/**
 * Converts one parsed SSE event into a `StreamingMessage`.
 *
 * @param parsedEvent - An event produced by `eventsource-parser`.
 * @returns The event's JSON `data` payload, assumed to be a `StreamingMessage`.
 * @throws Error when the event carries no data payload.
 */
export function extractMessageFromParsedEvent(
  parsedEvent: ParsedEvent,
): StreamingMessage {
  const payload = parsedEvent.data;
  if (!payload) {
    throw new Error("Expected SSE message to include a data payload");
  }
  // NOTE(review): parsed JSON is not validated against the union — assumed
  // to match the server's streaming protocol.
  return JSON.parse(payload) as StreamingMessage;
}

/**
 * Reads the error payload from a failed HTTP response and throws it as an
 * `Error`. JSON bodies are expected to carry a `message` field; any other
 * body is used verbatim. This function always throws — it never returns
 * normally, which the `Promise<never>` return type documents for callers.
 *
 * @param response - A non-OK `fetch` response whose body has not been read.
 * @throws Error carrying the best description extractable from the body.
 */
export async function handleResponseError(response: Response): Promise<never> {
  const contentType = response.headers.get("content-type");
  let errorMessage: string;

  if (contentType && contentType.includes("application/json")) {
    const json: unknown = await response.json();

    errorMessage =
      typeof json === "object" && json != null && "message" in json
        ? // `message` is untyped in the payload; coerce explicitly so a
          // non-string value still produces a readable error.
          String((json as { message: unknown }).message)
        : // JSON.stringify avoids the useless "[object Object]" that template
          // interpolation produces for objects without a `message` field.
          `An error occurred: ${JSON.stringify(json)}`;
  } else {
    const bodyText = await response.text();
    errorMessage = bodyText || "An unknown error occurred";
  }
  throw new Error(errorMessage);
}
Loading
Loading