Skip to content

Commit

Permalink
Update provider API with provider settings
Browse files Browse the repository at this point in the history
  • Loading branch information
diksipav committed Dec 3, 2024
1 parent 807e866 commit 48d11e8
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 28 deletions.
23 changes: 6 additions & 17 deletions packages/vercel-ai-provider/src/edgedb-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
createJsonResponseHandler,
postJsonToApi,
generateId,
combineHeaders,
} from "@ai-sdk/provider-utils";
import {
type EdgeDBChatConfig,
Expand All @@ -30,11 +31,7 @@ import {
import { convertToEdgeDBMessages } from "./convert-to-edgedb-messages";
import { prepareTools } from "./edgedb-prepare-tools";

export interface EdgeDBLanguageModel extends LanguageModelV1 {
withSettings(settings: Partial<EdgeDBChatSettings>): EdgeDBChatLanguageModel;
}

export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
export class EdgeDBChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
Expand All @@ -58,14 +55,6 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
return this.config.provider;
}

withSettings(settings: Partial<EdgeDBChatSettings>) {
return new EdgeDBChatLanguageModel(
this.modelId,
{ ...this.settings, ...settings },
this.config,
);
}

private getArgs({
// it's not really deprecated since the v2 is not out yet that accepts toolChoice, and tools at the top level
mode,
Expand Down Expand Up @@ -217,8 +206,8 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
const { messages } = args;

const { responseHeaders, value: response } = await postJsonToApi({
url: `rag`,
headers: options.headers,
url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag",
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
context: this.settings.context,
Expand Down Expand Up @@ -266,8 +255,8 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
const { messages } = args;

const { responseHeaders, value: response } = await postJsonToApi({
url: `rag`,
headers: options.headers,
url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag",
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
context: this.settings.context,
Expand Down
2 changes: 2 additions & 0 deletions packages/vercel-ai-provider/src/edgedb-chat-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ export interface QueryContext {
export interface EdgeDBChatConfig {
provider: string;
fetch: FetchFunction;
baseURL: string;
headers: () => Record<string, string | undefined>;
}

export interface EdgeDBChatSettings {
Expand Down
9 changes: 7 additions & 2 deletions packages/vercel-ai-provider/src/edgedb-embedding-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
createJsonResponseHandler,
type FetchFunction,
postJsonToApi,
combineHeaders,
} from "@ai-sdk/provider-utils";
import { z } from "zod";
import {
Expand All @@ -18,6 +19,8 @@ import { edgedbFailedResponseHandler } from "./edgedb-error";
interface EdgeDBEmbeddingConfig {
provider: string;
fetch?: FetchFunction;
baseURL: string;
headers: () => Record<string, string | undefined>;
}

export class EdgeDBEmbeddingModel implements EmbeddingModelV1<string> {
Expand Down Expand Up @@ -71,8 +74,10 @@ export class EdgeDBEmbeddingModel implements EmbeddingModelV1<string> {
}

const { responseHeaders, value: response } = await postJsonToApi({
url: `embeddings`,
headers,
url: this.config.baseURL
? `${this.config.baseURL}/embeddings`
: "embeddings",
headers: combineHeaders(this.config.headers(), headers),
body: {
model: this.modelId,
input: values,
Expand Down
47 changes: 38 additions & 9 deletions packages/vercel-ai-provider/src/edgedb-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import type {
LanguageModelV1,
ProviderV1,
} from "@ai-sdk/provider";
import {
EdgeDBChatLanguageModel,
type EdgeDBLanguageModel,
} from "./edgedb-chat-language-model";
import { withoutTrailingSlash } from "@ai-sdk/provider-utils";
import { EdgeDBChatLanguageModel } from "./edgedb-chat-language-model";
import type {
EdgeDBChatModelId,
EdgeDBChatSettings,
Expand All @@ -24,21 +22,45 @@ import type {
const httpSCRAMAuth = getHTTPSCRAMAuth(cryptoUtils);

export interface EdgeDBProvider extends ProviderV1 {
(modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId): LanguageModelV1;
(
modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId,
settings?: EdgeDBChatSettings,
): LanguageModelV1;

languageModel(
modelId: EdgeDBChatModelId,
settings?: EdgeDBChatSettings,
): EdgeDBLanguageModel;
): EdgeDBChatLanguageModel;

textEmbeddingModel: (
modelId: EdgeDBEmbeddingModelId,
settings?: EdgeDBEmbeddingSettings,
) => EmbeddingModelV1<string>;
}

export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
export interface EdgeDBProviderSettings {
/**
Use a different URL prefix for API calls, e.g. to use proxy servers.
The default prefix is `https://api.mistral.ai/v1`.
*/
baseURL?: string;

/**
Custom headers to include in the requests.
*/
headers?: Record<string, string>;
}

export async function createEdgeDB(
client: Client,
options: EdgeDBProviderSettings = {},
): Promise<EdgeDBProvider> {
const connectConfig = await client.resolveConnectionParams();
const baseURL = withoutTrailingSlash(options.baseURL) ?? "";

const getHeaders = () => ({
...options.headers,
});

const fetch = await getAuthenticatedFetch(
connectConfig,
Expand All @@ -53,6 +75,8 @@ export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
new EdgeDBChatLanguageModel(modelId, settings, {
provider: "edgedb.chat",
fetch,
baseURL,
headers: getHeaders,
});

const createEmbeddingModel = (
Expand All @@ -62,17 +86,22 @@ export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
return new EdgeDBEmbeddingModel(modelId, settings, {
provider: "edgedb.embedding",
fetch,
baseURL,
headers: getHeaders,
});
};

const provider = function (modelId: EdgeDBChatModelId) {
const provider = function (
modelId: EdgeDBChatModelId,
settings?: EdgeDBChatSettings,
) {
if (new.target) {
throw new Error(
"The EdgeDB model function cannot be called with the new keyword.",
);
}

return createChatModel(modelId);
return createChatModel(modelId, settings);
};

provider.languageModel = createChatModel;
Expand Down

0 comments on commit 48d11e8

Please sign in to comment.