Commit 4012d77: Update rag generation to support multiple models

stwiname committed Dec 17, 2024
1 parent: a79548f

Showing 7 changed files with 105 additions and 40 deletions.
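At a glance: the commit removes the hard-wired Ollama client from embedding generation. LanceWriter and the embeddings generator now receive an injected GenerateEmbedding callback, and a new getGenerateFunction helper in src/runners/runner.ts resolves that callback against either an Ollama host or an OpenAI-compatible endpoint, so the embeddings CLI command can target multiple model providers. Both runner factories also now load vector storage eagerly on creation via getContext().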
src/context/context.ts (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@ import * as lancedb from "@lancedb/lancedb";

const logger = await getLogger("ToolContext");

-type GetEmbedding = (input: string | string[]) => Promise<number[]>;
+type GetEmbedding = (input: string /* | string[] */) => Promise<number[]>;

export class Context implements IContext {
#getEmbedding: GetEmbedding;
src/embeddings/generator/generator.ts (8 changes: 3 additions & 5 deletions)
@@ -2,9 +2,8 @@ import {
type BaseEmbeddingSource,
MarkdownEmbeddingSource,
} from "./mdSource.ts";
-import ollama from "ollama";
import { glob } from "glob";
-import { LanceWriter } from "../lance/index.ts";
+import { type GenerateEmbedding, LanceWriter } from "../lance/index.ts";
import { getLogger } from "../../logger.ts";
import { getSpinner } from "../../util.ts";

@@ -20,8 +19,8 @@ export async function generate(
path: string,
lanceDbPath: string,
tableName: string,
+generateEmbedding: GenerateEmbedding,
ignoredPaths = DEFAULT_IGNORED_PATHS,
model = "nomic-embed-text",
overwrite = false,
) {
const embeddingSources: BaseEmbeddingSource[] =
@@ -37,8 +36,7 @@
const lanceWriter = await LanceWriter.createNewTable(
lanceDbPath,
tableName,
-ollama,
-model,
+generateEmbedding,
overwrite,
);

src/embeddings/lance/writer.ts (24 changes: 10 additions & 14 deletions)
@@ -1,12 +1,14 @@
import * as lancedb from "@lancedb/lancedb";
import { Field, FixedSizeList, Float64, Schema, Utf8 } from "apache-arrow";
import type { IEmbeddingWriter } from "../embeddings.ts";
-import ollama, { type Ollama } from "ollama";

+export type GenerateEmbedding = (
+input: string | string[],
+) => Promise<number[][]>;
+
export class LanceWriter implements IEmbeddingWriter {
#table: lancedb.Table;
-#model: Ollama;
-#embedModel: string;
+#generateEmbedding: GenerateEmbedding;

static #dim = 768;

@@ -21,19 +23,16 @@ export class LanceWriter implements IEmbeddingWriter {

constructor(
table: lancedb.Table,
-model: Ollama,
-embedModel = "nomic-embed-text",
+generateEmbedding: GenerateEmbedding,
) {
this.#table = table;
-this.#model = model;
-this.#embedModel = embedModel;
+this.#generateEmbedding = generateEmbedding;
}

static async createNewTable(
dbPath: string,
tableName: string,
-model: Ollama = ollama,
-embedModel = "nomic-embed-text",
+generateEmbedding: GenerateEmbedding,
overwrite = false,
): Promise<LanceWriter> {
const db = await lancedb.connect(dbPath);
@@ -44,14 +43,11 @@ export class LanceWriter implements IEmbeddingWriter {
{ mode: overwrite ? "overwrite" : "create" },
);

-return new LanceWriter(table, model, embedModel);
+return new LanceWriter(table, generateEmbedding);
}

async write(input: string | string[]): Promise<void> {
-const { embeddings } = await this.#model.embed({
-model: this.#embedModel,
-input,
-});
+const embeddings = await this.#generateEmbedding(input);

const inputArr = Array.isArray(input) ? input : [input];
const data = inputArr.map((input, idx) => {
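For orientation, a minimal sketch of how the refactored writer might be driven. The import path mirrors the diff's ../lance/index.ts re-export; the db path, table name, and the zero-vector stub are assumptions for illustration only:

import { type GenerateEmbedding, LanceWriter } from "./src/embeddings/lance/index.ts";

// Hypothetical stand-in: any async function matching GenerateEmbedding works,
// so the writer no longer depends on an Ollama client. This stub returns
// zero vectors matching the writer's fixed 768-dimension schema.
const generateEmbedding: GenerateEmbedding = async (input) => {
  const texts = Array.isArray(input) ? input : [input];
  return texts.map(() => new Array<number>(768).fill(0));
};

const writer = await LanceWriter.createNewTable(
  "./.lancedb", // dbPath (assumed)
  "embeddings", // tableName (assumed)
  generateEmbedding,
  true, // overwrite an existing table
);
await writer.write(["first chunk", "second chunk"]);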
src/index.ts (40 changes: 26 additions & 14 deletions)
@@ -21,6 +21,7 @@ import ora from "ora";
import { getPrompt, getVersion, setSpinner } from "./util.ts";
import { initLogger } from "./logger.ts";
import { DEFAULT_LLM_HOST, DEFAULT_PORT } from "./constants.ts";
+import { getGenerateFunction } from "./runners/runner.ts";

const sharedArgs = {
project: {
@@ -46,6 +47,21 @@ const sharedArgs = {
},
} satisfies Record<string, Options>;

+const llmHostArgs = {
+host: {
+alias: "h",
+description:
+"The LLM RPC host. If the project model uses ChatGPT then the default value is not used.",
+default: DEFAULT_LLM_HOST,
+type: "string",
+},
+openAiApiKey: {
+description:
+"If the project models use OpenAI models, then this api key will be parsed on to the OpenAI client",
+type: "string",
+},
+} satisfies Record<string, Options>;

const debugArgs = {
debug: {
description: "Enable debug logging",
@@ -78,13 +94,7 @@ yargs(Deno.args)
{
...sharedArgs,
...debugArgs,
-host: {
-alias: "h",
-description:
-"The LLM RPC host. If the project model uses ChatGPT then the default value is not used.",
-default: DEFAULT_LLM_HOST,
-type: "string",
-},
+...llmHostArgs,
interface: {
alias: "i",
description: "The interface to interact with the app",
@@ -116,11 +126,6 @@
type: "number",
default: 5_000, // 5s
},
-openAiApiKey: {
-description:
-"If the project models use OpenAI models, then this api key will be parsed on to the OpenAI client",
-type: "string",
-},
},
async (argv) => {
try {
@@ -181,6 +186,7 @@
"Creates a Lance db table with embeddings from MDX files",
{
...debugArgs,
+...llmHostArgs,
input: {
alias: "i",
description: "Path to a directory containing MD or MDX files",
@@ -208,7 +214,7 @@
model: {
description:
"The embedding LLM model to use, this should be the same as embeddingsModel in your app manifest",
default: "nomic-embed-text",
required: true,
type: "string",
},
overwrite: {
Expand All @@ -226,12 +232,18 @@ yargs(Deno.args)
const { generate } = await import(
"./embeddings/generator/generator.ts"
);

+const generateFunction = await getGenerateFunction(
+argv.host,
+argv.model,
+argv.openAiApiKey,
+);
return await generate(
resolve(argv.input),
resolve(argv.output),
argv.table,
+generateFunction,
argv.ignoredFiles?.map((f) => resolve(f)),
-argv.model,
argv.overwrite,
);
} catch (e) {
src/runners/ollama.ts (20 changes: 16 additions & 4 deletions)
@@ -49,16 +49,28 @@ export class OllamaRunnerFactory implements IRunnerFactory {
await ollama.show({ model: sandbox.manifest.embeddingsModel });
}

-return new OllamaRunnerFactory(ollama, sandbox, loader);
+const factory = new OllamaRunnerFactory(ollama, sandbox, loader);
+
+// Makes sure vectorStorage is loaded
+await factory.getContext();
+
+return factory;
}

-async runEmbedding(input: string | string[]): Promise<number[]> {
-const { embeddings: [embedding] } = await this.#ollama.embed({
+async runEmbedding(input: string): Promise<number[]> {
+const { embeddings: [embed] } = await this.#ollama.embed({
model: this.#sandbox.manifest.embeddingsModel ?? "nomic-embed-text",
input,
});

-return embedding;
+return embed;
+
+// const { embedding } = await this.#ollama.embeddings({
+// model: this.#sandbox.manifest.embeddingsModel ?? "nomic-embed-text",
+// prompt: input,
+// });
+
+// return embedding;
}

@Memoize()
src/runners/openai.ts (7 changes: 6 additions & 1 deletion)
@@ -52,11 +52,16 @@ export class OpenAIRunnerFactory implements IRunnerFactory {
await openai.models.retrieve(sandbox.manifest.embeddingsModel);
}

-return new OpenAIRunnerFactory(
+const factory = new OpenAIRunnerFactory(
openai,
sandbox,
loader,
);

+// Makes sure vector storage is loaded
+await factory.getContext();
+
+return factory;
}

async runEmbedding(input: string | string[]): Promise<number[]> {
src/runners/runner.ts (44 changes: 43 additions & 1 deletion)
@@ -1,5 +1,7 @@
-import type { ChatResponse, Message } from "ollama";
+import { type ChatResponse, type Message, Ollama } from "ollama";
import type { IChatStorage } from "../chatStorage/index.ts";
+import type { GenerateEmbedding } from "../embeddings/lance/writer.ts";
+import OpenAI from "openai";

export interface IRunner {
prompt(message: string): Promise<string>;
@@ -9,3 +11,43 @@ export interface IRunner {
export interface IRunnerFactory {
getRunner(chatStorage: IChatStorage): Promise<IRunner>;
}

+export async function getGenerateFunction(
+endpoint: string,
+model: string,
+apiKey?: string,
+): Promise<GenerateEmbedding> {
+try {
+const ollama = new Ollama({ host: endpoint });
+
+// If this throws then try OpenAI
+await ollama.show({ model });
+
+return async (input: string | string[]) => {
+const { embeddings } = await ollama.embed({ model, input });
+return embeddings;
+};
+} catch (ollamaError) {
+try {
+const openai = new OpenAI({
+apiKey,
+baseURL: endpoint,
+});
+
+await openai.models.retrieve(model);
+
+return async (input: string | string[]) => {
+const { data } = await openai.embeddings.create({
+model,
+input,
+});
+
+return data.map((d) => d.embedding);
+};
+} catch (openAIError) {
+throw new Error(`Unable to find model: ${model}.
+Ollama error: ${ollamaError}
+Openai error: ${openAIError}`);
+}
+}
+}
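A minimal sketch of calling the new helper; the endpoint and model name below are assumptions for illustration, not values from the commit:

import { getGenerateFunction } from "./src/runners/runner.ts";

// Probes the Ollama host first (via ollama.show); if that rejects, it retries
// the same endpoint as an OpenAI-compatible API before throwing.
const generateEmbedding = await getGenerateFunction(
  "http://localhost:11434", // assumed local Ollama endpoint
  "nomic-embed-text", // assumed embedding model
);

const embeddings = await generateEmbedding(["hello", "world"]);
console.log(embeddings.length, embeddings[0].length); // one vector per input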
