Skip to content

Commit

Permalink
Accept dimensions & user for embedding calls
Browse files Browse the repository at this point in the history
  • Loading branch information
diksipav committed Dec 10, 2024
1 parent 48d11e8 commit dfd28ca
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 24 deletions.
7 changes: 4 additions & 3 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
type QueryContext,
type StreamingMessage,
type RagRequest,
type EmbeddingRequest,
isPromptRequest,
} from "./types.js";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
Expand Down Expand Up @@ -209,7 +210,7 @@ export class EdgeDBAI {
};
}

async generateEmbeddings(inputs: string[], model: string): Promise<number[]> {
async generateEmbeddings(request: EmbeddingRequest): Promise<number[]> {
const response = await (
await this.authenticatedFetch
)("embeddings", {
Expand All @@ -218,8 +219,8 @@ export class EdgeDBAI {
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
input: inputs,
...request,
input: request.inputs,
}),
});

Expand Down
20 changes: 12 additions & 8 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,15 @@ export interface QueryContext {
max_object_count?: number;
}

interface RagRequestBase {
stream?: boolean;
export interface RagRequestPrompt {
prompt: string;
[key: string]: unknown;
}

export type RagRequestPrompt = RagRequestBase & {
prompt: string;
};

export type RagRequestMessages = RagRequestBase & {
export interface RagRequestMessages {
messages: EdgeDBMessage[];
};
[key: string]: unknown;
}

export type RagRequest = RagRequestPrompt | RagRequestMessages;

Expand Down Expand Up @@ -153,3 +150,10 @@ export type StreamingMessage =
| MessageDelta
| MessageStop
| MessageError;

export interface EmbeddingRequest {
inputs: string[];
model: string;
dimensions?: number;
user?: string;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import type {
} from "@ai-sdk/provider";
import {
type ParseResult,
type FetchFunction,
createEventSourceResponseHandler,
createJsonResponseHandler,
postJsonToApi,
generateId,
combineHeaders,
} from "@ai-sdk/provider-utils";
import {
type EdgeDBChatConfig,
type EdgeDBChatModelId,
type EdgeDBChatSettings,
type EdgeDBMessage,
Expand All @@ -31,6 +31,13 @@ import {
import { convertToEdgeDBMessages } from "./convert-to-edgedb-messages";
import { prepareTools } from "./edgedb-prepare-tools";

export interface EdgeDBChatConfig {
provider: string;
fetch: FetchFunction;
baseURL: string | null;
headers: () => Record<string, string | undefined>;
}

export class EdgeDBChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
Expand Down
9 changes: 0 additions & 9 deletions packages/vercel-ai-provider/src/edgedb-chat-settings.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import type { FetchFunction } from "@ai-sdk/provider-utils";

export type OpenAIModelId =
| "gpt-4o"
| "gpt-4o-mini"
Expand Down Expand Up @@ -67,13 +65,6 @@ export interface QueryContext {
max_object_count?: number;
}

export interface EdgeDBChatConfig {
provider: string;
fetch: FetchFunction;
baseURL: string;
headers: () => Record<string, string | undefined>;
}

export interface EdgeDBChatSettings {
context?: QueryContext;
prompt?: Prompt;
Expand Down
3 changes: 2 additions & 1 deletion packages/vercel-ai-provider/src/edgedb-embedding-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { edgedbFailedResponseHandler } from "./edgedb-error";
interface EdgeDBEmbeddingConfig {
provider: string;
fetch?: FetchFunction;
baseURL: string;
baseURL: string | null;
headers: () => Record<string, string | undefined>;
}

Expand Down Expand Up @@ -82,6 +82,7 @@ export class EdgeDBEmbeddingModel implements EmbeddingModelV1<string> {
model: this.modelId,
input: values,
encoding_format: "float",
// OpenAI props
dimensions: this.settings.dimensions,
user: this.settings.user,
},
Expand Down
4 changes: 2 additions & 2 deletions packages/vercel-ai-provider/src/edgedb-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ export interface EdgeDBProvider extends ProviderV1 {
export interface EdgeDBProviderSettings {
/**
Use a different URL prefix for API calls, e.g. to use proxy servers.
The default prefix is `https://api.mistral.ai/v1`.
*/
baseURL?: string;

Expand All @@ -56,8 +55,9 @@ export async function createEdgeDB(
options: EdgeDBProviderSettings = {},
): Promise<EdgeDBProvider> {
const connectConfig = await client.resolveConnectionParams();
const baseURL = withoutTrailingSlash(options.baseURL) ?? "";
const baseURL = withoutTrailingSlash(options.baseURL) ?? null;

// In case we want to add more things to this in the future
const getHeaders = () => ({
...options.headers,
});
Expand Down

0 comments on commit dfd28ca

Please sign in to comment.