Skip to content

Commit

Permalink
Improve detection of Ollama or OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
stwiname committed Dec 18, 2024
1 parent 22605ef commit bae4561
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 41 deletions.
22 changes: 7 additions & 15 deletions src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import type { IPFSClient } from "./ipfs.ts";
import { Loader } from "./loader.ts";
import { getPrompt, getVersion } from "./util.ts";
import { getLogger } from "./logger.ts";
import { OllamaRunnerFactory } from "./runners/ollama.ts";
import { OpenAIRunnerFactory } from "./runners/openai.ts";
import { DEFAULT_LLM_HOST } from "./constants.ts";
import { createRunner } from "./runners/runner.ts";

const logger = await getLogger("app");

Expand All @@ -37,18 +35,12 @@ export async function runApp(config: {

const sandbox = await getDefaultSandbox(loader, config.toolTimeout);

const runnerFactory = sandbox.manifest.model.includes("gpt-")
? await OpenAIRunnerFactory.create(
config.host === DEFAULT_LLM_HOST ? undefined : config.host,
config.openAiApiKey,
sandbox,
loader,
)
: await OllamaRunnerFactory.create(
config.host,
sandbox,
loader,
);
const runnerFactory = await createRunner(
config.host,
sandbox,
loader,
config.openAiApiKey,
);

const runnerHost = new RunnerHost(() => {
const chatStorage = new MemoryChatStorage();
Expand Down
95 changes: 69 additions & 26 deletions src/runners/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ import type { IChatStorage } from "../chatStorage/index.ts";
import type { GenerateEmbedding } from "../embeddings/lance/writer.ts";
import OpenAI from "openai";
import { DEFAULT_LLM_HOST } from "../constants.ts";
import type { ISandbox } from "../sandbox/sandbox.ts";
import type { Loader } from "../loader.ts";
import { OpenAIRunnerFactory } from "./openai.ts";
import { OllamaRunnerFactory } from "./ollama.ts";

export interface IRunner {
prompt(message: string): Promise<string>;
Expand All @@ -13,32 +17,75 @@ export interface IRunnerFactory {
getRunner(chatStorage: IChatStorage): Promise<IRunner>;
}

export async function getGenerateFunction(
/**
 * Tries each named candidate in insertion order and resolves with the first
 * one that succeeds. If every candidate throws, rejects with a single Error
 * whose message aggregates each candidate's failure, one per line.
 */
async function runForModels<T>(
  runners: Record<string, () => Promise<T>>,
): Promise<T> {
  const errors: Record<string, unknown> = {};

  for (const [name, run] of Object.entries(runners)) {
    try {
      return await run();
    } catch (err) {
      // Remember the failure and fall through to the next candidate.
      errors[name] = err;
    }
  }

  // Nothing succeeded; surface every failure in one aggregated message.
  const details = Object.entries(errors)
    .map(([name, error]) => `${name} error: ${error}`)
    .join("\n\t");

  throw new Error(`All options failed to run:\n\t${details}`);
}

/**
 * Resolves a runner factory for the configured endpoint by probing backends
 * in order: Ollama first, then OpenAI. The first factory that constructs
 * successfully wins; if both fail, the returned promise rejects with an
 * aggregated error from `runForModels`.
 *
 * When the endpoint equals DEFAULT_LLM_HOST, OpenAI receives `undefined`
 * instead, so its client falls back to its own default base URL.
 */
export function createRunner(
  endpoint: string,
  sandbox: ISandbox,
  loader: Loader,
  openAiApiKey?: string,
): Promise<IRunnerFactory> {
  const openAiEndpoint = endpoint === DEFAULT_LLM_HOST ? undefined : endpoint;

  const candidates: Record<string, () => Promise<IRunnerFactory>> = {
    Ollama: () => OllamaRunnerFactory.create(endpoint, sandbox, loader),
    OpenAI: () =>
      OpenAIRunnerFactory.create(openAiEndpoint, openAiApiKey, sandbox, loader),
  };

  return runForModels<IRunnerFactory>(candidates);
}

export function getGenerateFunction(
endpoint: string,
model: string,
apiKey?: string,
): Promise<GenerateEmbedding> {
try {
const ollama = new Ollama({ host: endpoint });
return runForModels<GenerateEmbedding>({
Ollama: async () => {
const ollama = new Ollama({ host: endpoint });

// If this throws then try OpenAI
await ollama.show({ model });
// If this throws then try OpenAI
await ollama.show({ model });

return async (input: string | string[], dimensions?: number) => {
const { embeddings } = await ollama.embed({ model, input });
// Ollama doesnt currentl allow specifying dimensions
// https://github.com/ollama/ollama/issues/651
if (dimensions != undefined && embeddings[0].length != dimensions) {
throw new Error(
`Dimensions mismatch, expected:"${dimensions}" received:"${
embeddings[0].length
}"`,
);
}
return embeddings;
};
} catch (ollamaError) {
try {
return async (input: string | string[], dimensions?: number) => {
const { embeddings } = await ollama.embed({ model, input });
// Ollama doesnt currentl allow specifying dimensions
// https://github.com/ollama/ollama/issues/651
if (dimensions != undefined && embeddings[0].length != dimensions) {
throw new Error(
`Dimensions mismatch, expected:"${dimensions}" received:"${
embeddings[0].length
}"`,
);
}
return embeddings;
};
},
OpenAI: async () => {
const openai = new OpenAI({
apiKey,
baseURL: endpoint === DEFAULT_LLM_HOST ? undefined : endpoint,
Expand All @@ -55,10 +102,6 @@ export async function getGenerateFunction(

return data.map((d) => d.embedding);
};
} catch (openAIError) {
throw new Error(`Unable to find model: ${model}.
Ollama error: ${ollamaError}
Openai error: ${openAIError}`);
}
}
},
});
}

0 comments on commit bae4561

Please sign in to comment.