From 48d11e84916ba636c13e7207dac4750d401ed284 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Tue, 3 Dec 2024 10:49:20 +0100 Subject: [PATCH] Update provider API with provider settings --- .../src/edgedb-chat-language-model.ts | 23 +++------ .../src/edgedb-chat-settings.ts | 2 + .../src/edgedb-embedding-model.ts | 9 +++- .../vercel-ai-provider/src/edgedb-provider.ts | 47 +++++++++++++++---- 4 files changed, 53 insertions(+), 28 deletions(-) diff --git a/packages/vercel-ai-provider/src/edgedb-chat-language-model.ts b/packages/vercel-ai-provider/src/edgedb-chat-language-model.ts index 6a8a50f2a..f37ebd4e4 100644 --- a/packages/vercel-ai-provider/src/edgedb-chat-language-model.ts +++ b/packages/vercel-ai-provider/src/edgedb-chat-language-model.ts @@ -13,6 +13,7 @@ import { createJsonResponseHandler, postJsonToApi, generateId, + combineHeaders, } from "@ai-sdk/provider-utils"; import { type EdgeDBChatConfig, @@ -30,11 +31,7 @@ import { import { convertToEdgeDBMessages } from "./convert-to-edgedb-messages"; import { prepareTools } from "./edgedb-prepare-tools"; -export interface EdgeDBLanguageModel extends LanguageModelV1 { - withSettings(settings: Partial): EdgeDBChatLanguageModel; -} - -export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel { +export class EdgeDBChatLanguageModel implements LanguageModelV1 { readonly specificationVersion = "v1"; readonly defaultObjectGenerationMode = "json"; readonly supportsImageUrls = false; @@ -58,14 +55,6 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel { return this.config.provider; } - withSettings(settings: Partial) { - return new EdgeDBChatLanguageModel( - this.modelId, - { ...this.settings, ...settings }, - this.config, - ); - } - private getArgs({ // it's not really deprecated since the v2 is not out yet that accepts toolChoice, and tools at the top level mode, @@ -217,8 +206,8 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel { const { messages } = args; const { responseHeaders, value: response } = await postJsonToApi({ - url: `rag`, - headers: options.headers, + url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag", + headers: combineHeaders(this.config.headers(), options.headers), body: { ...args, context: this.settings.context, @@ -266,8 +255,8 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel { const { messages } = args; const { responseHeaders, value: response } = await postJsonToApi({ - url: `rag`, - headers: options.headers, + url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag", + headers: combineHeaders(this.config.headers(), options.headers), body: { ...args, context: this.settings.context, diff --git a/packages/vercel-ai-provider/src/edgedb-chat-settings.ts b/packages/vercel-ai-provider/src/edgedb-chat-settings.ts index 98acf6078..a36e15b6b 100644 --- a/packages/vercel-ai-provider/src/edgedb-chat-settings.ts +++ b/packages/vercel-ai-provider/src/edgedb-chat-settings.ts @@ -70,6 +70,8 @@ export interface QueryContext { export interface EdgeDBChatConfig { provider: string; fetch: FetchFunction; + baseURL: string; + headers: () => Record; } export interface EdgeDBChatSettings { diff --git a/packages/vercel-ai-provider/src/edgedb-embedding-model.ts b/packages/vercel-ai-provider/src/edgedb-embedding-model.ts index 87f48208d..0a63f33d8 100644 --- a/packages/vercel-ai-provider/src/edgedb-embedding-model.ts +++ b/packages/vercel-ai-provider/src/edgedb-embedding-model.ts @@ -6,6 +6,7 @@ import { createJsonResponseHandler, type FetchFunction, postJsonToApi, + combineHeaders, } from "@ai-sdk/provider-utils"; import { z } from "zod"; import { @@ -18,6 +19,8 @@ import { edgedbFailedResponseHandler } from "./edgedb-error"; interface EdgeDBEmbeddingConfig { provider: string; fetch?: FetchFunction; + baseURL: string; + headers: () => Record; } export class EdgeDBEmbeddingModel implements EmbeddingModelV1 { @@ -71,8 +74,10 @@ export class EdgeDBEmbeddingModel implements EmbeddingModelV1 { } const { responseHeaders, value: response } = await postJsonToApi({ - url: `embeddings`, - headers, + url: this.config.baseURL + ? `${this.config.baseURL}/embeddings` + : "embeddings", + headers: combineHeaders(this.config.headers(), headers), body: { model: this.modelId, input: values, diff --git a/packages/vercel-ai-provider/src/edgedb-provider.ts b/packages/vercel-ai-provider/src/edgedb-provider.ts index d1b8ec427..e3321516f 100644 --- a/packages/vercel-ai-provider/src/edgedb-provider.ts +++ b/packages/vercel-ai-provider/src/edgedb-provider.ts @@ -7,10 +7,8 @@ import type { LanguageModelV1, ProviderV1, } from "@ai-sdk/provider"; -import { - EdgeDBChatLanguageModel, - type EdgeDBLanguageModel, -} from "./edgedb-chat-language-model"; +import { withoutTrailingSlash } from "@ai-sdk/provider-utils"; +import { EdgeDBChatLanguageModel } from "./edgedb-chat-language-model"; import type { EdgeDBChatModelId, EdgeDBChatSettings, @@ -24,12 +22,15 @@ import type { const httpSCRAMAuth = getHTTPSCRAMAuth(cryptoUtils); export interface EdgeDBProvider extends ProviderV1 { - (modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId): LanguageModelV1; + ( + modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId, + settings?: EdgeDBChatSettings, + ): LanguageModelV1; languageModel( modelId: EdgeDBChatModelId, settings?: EdgeDBChatSettings, - ): EdgeDBLanguageModel; + ): EdgeDBChatLanguageModel; textEmbeddingModel: ( modelId: EdgeDBEmbeddingModelId, @@ -37,8 +38,29 @@ export interface EdgeDBProvider extends ProviderV1 { ) => EmbeddingModelV1; } -export async function createEdgeDB(client: Client): Promise { +export interface EdgeDBProviderSettings { + /** +Use a different URL prefix for API calls, e.g. to use proxy servers. +The default prefix is `https://api.mistral.ai/v1`. + */ + baseURL?: string; + + /** +Custom headers to include in the requests. + */ + headers?: Record; +} + +export async function createEdgeDB( + client: Client, + options: EdgeDBProviderSettings = {}, +): Promise { const connectConfig = await client.resolveConnectionParams(); + const baseURL = withoutTrailingSlash(options.baseURL) ?? ""; + + const getHeaders = () => ({ + ...options.headers, + }); const fetch = await getAuthenticatedFetch( connectConfig, @@ -53,6 +75,8 @@ export async function createEdgeDB(client: Client): Promise { new EdgeDBChatLanguageModel(modelId, settings, { provider: "edgedb.chat", fetch, + baseURL, + headers: getHeaders, }); const createEmbeddingModel = ( @@ -62,17 +86,22 @@ export async function createEdgeDB(client: Client): Promise { return new EdgeDBEmbeddingModel(modelId, settings, { provider: "edgedb.embedding", fetch, + baseURL, + headers: getHeaders, }); }; - const provider = function (modelId: EdgeDBChatModelId) { + const provider = function ( + modelId: EdgeDBChatModelId, + settings?: EdgeDBChatSettings, + ) { if (new.target) { throw new Error( "The EdgeDB model function cannot be called with the new keyword.", ); } - return createChatModel(modelId); + return createChatModel(modelId, settings); }; provider.languageModel = createChatModel;