Skip to content

Commit

Permalink
Merge pull request #7 from subquery/polish
Browse files Browse the repository at this point in the history
Polishing and minor changes
  • Loading branch information
stwiname authored Oct 17, 2024
2 parents eb943c2 + cf17ca9 commit 8884909
Show file tree
Hide file tree
Showing 17 changed files with 356 additions and 109 deletions.
5 changes: 3 additions & 2 deletions src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export async function runApp(config: {
port: number;
ipfs: IPFSClient;
forceReload?: boolean;
toolTimeout: number;
}): Promise<void> {
const model = new Ollama({ host: config.host });

Expand All @@ -30,7 +31,7 @@ export async function runApp(config: {
config.forceReload,
);

const sandbox = await getDefaultSandbox(loader);
const sandbox = await getDefaultSandbox(loader, config.toolTimeout);

const ctx = await makeContext(
sandbox,
Expand Down Expand Up @@ -79,7 +80,7 @@ async function makeContext(
if (!loadRes) throw new Error("Failed to load vector db");
const connection = await lancedb.connect(loadRes[0]);

return new Context(model, connection);
return new Context(model, connection, sandbox.manifest.embeddingsModel);
}

async function cli(runnerHost: RunnerHost): Promise<void> {
Expand Down
44 changes: 39 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ import yargs, {

import { IPFSClient } from "./ipfs.ts";
import ora from "ora";
import { setSpinner } from "./util.ts";

import { getPrompt, setSpinner } from "./util.ts";
const DEFAULT_PORT = 7827;

const sharedArgs = {
Expand Down Expand Up @@ -81,6 +80,12 @@ yargs(Deno.args)
type: "boolean",
default: false,
},
toolTimeout: {
description:
"Set a limit for how long a tool can take to run, unit is MS",
type: "number",
default: 10_000, // 10s
},
},
async (argv) => {
try {
Expand All @@ -92,6 +97,7 @@ yargs(Deno.args)
port: argv.port,
ipfs: ipfsFromArgs(argv),
forceReload: argv.forceReload,
toolTimeout: argv.toolTimeout,
});
} catch (e) {
console.log(e);
Expand All @@ -112,7 +118,7 @@ yargs(Deno.args)
},
async (argv) => {
try {
const { projectInfo } = await import("./info.ts");
const { projectInfo } = await import("./subcommands/info.ts");
await projectInfo(argv.project, ipfsFromArgs(argv), argv.json);
Deno.exit(0);
} catch (e) {
Expand Down Expand Up @@ -177,7 +183,7 @@ yargs(Deno.args)
},
},
async (argv) => {
const { httpCli } = await import("./httpCli.ts");
const { httpCli } = await import("./subcommands/httpCli.ts");
await httpCli(argv.host);
},
)
Expand All @@ -193,7 +199,7 @@ yargs(Deno.args)
},
async (argv) => {
try {
const { publishProject } = await import("./bundle.ts");
const { publishProject } = await import("./subcommands/bundle.ts");
if (argv.silent) {
setSpinner(ora({ isSilent: true }));
}
Expand All @@ -211,6 +217,34 @@ yargs(Deno.args)
}
},
)
.command(
"init",
"Create a new project skeleton",
{
name: {
description:
"The name of your project, this will create a directory with that name.",
type: "string",
},
model: {
description: "The LLM model you wish to use",
type: "string",
},
},
async (argv) => {
try {
argv.name ??= getPrompt("Enter a project name: ");
argv.model ??= getPrompt("Enter a LLM model", "llama3.1");

const { initProject } = await import("./subcommands/init.ts");

await initProject({ name: argv.name, model: argv.model });
} catch (e) {
console.log(e);
Deno.exit(1);
}
},
)
// .fail(() => {}) // Disable logging --help if there's an error with a command // TODO: fix so it only logs when the error is with yargs
.help()
.argv;
2 changes: 1 addition & 1 deletion src/loader_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from "@std/expect/expect";
import { getOSTempDir, pullContent } from "./loader.ts";
import { resolve } from "@std/path/resolve";
import { IPFSClient } from "./ipfs.ts";
import { tarDir } from "./bundle.ts";
import { tarDir } from "./subcommands/bundle.ts";

const ipfs = new IPFSClient(
Deno.env.get("IPFS_ENDPOINT") ??
Expand Down
21 changes: 16 additions & 5 deletions src/project/project.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,23 @@ export const VectorConfig = Type.Object({

export const ProjectManifest = Type.Object({
specVersion: Type.Literal("0.0.1"),
model: Type.String(),
entry: Type.String(),
model: Type.String({ description: "The Ollama LLM model to be used" }),
embeddingsModel: Type.Optional(Type.String({
description: "The Ollama LLM model to be used for vector embeddings",
})),
entry: Type.String({
description: "File path to the project entrypoint",
}),
vectorStorage: Type.Optional(Type.Object({
type: Type.String(),
path: Type.String(),
type: Type.String({
description:
"The type of vector storage, currently only lancedb is supported.",
}),
path: Type.String({ description: "The path to the db" }),
})),
endpoints: Type.Optional(Type.Array(Type.String())),
endpoints: Type.Optional(Type.Array(Type.String({
description: "Allowed endpoints the tools are allowed to make requests to",
}))),
config: Type.Optional(Type.Any()), // TODO how can this be a JSON Schema type?
});

Expand All @@ -44,6 +54,7 @@ export const ProjectEntry = Type.Function(
Type.Union([Project, Type.Promise(Project)]),
);

export type FunctionToolType = Static<typeof FunctionToolType>;
export type ProjectManifest = Static<typeof ProjectManifest>;
export type Project = Static<typeof Project>;
export type ProjectEntry = Static<typeof ProjectEntry>;
Expand Down
18 changes: 11 additions & 7 deletions src/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ export class Runner {
// Run tools and use their responses
const toolResponses = await Promise.all(
(res.message.tool_calls ?? []).map(async (toolCall) => {
const res = await this.sandbox.runTool(
toolCall.function.name,
toolCall.function.arguments,
this.#context,
);

return res;
try {
return await this.sandbox.runTool(
toolCall.function.name,
toolCall.function.arguments,
this.#context,
);
} catch (e: unknown) {
console.error(`Tool call failed: ${e}`);
// Don't throw the error, as that would exit the application; instead pass the message back to the LLM
return (e as Error).message;
}
}),
);

Expand Down
3 changes: 2 additions & 1 deletion src/sandbox/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ export * from "./unsafeSandbox.ts";

export function getDefaultSandbox(
loader: Loader,
timeout: number,
): Promise<ISandbox> {
// return UnsafeSandbox.create(loader);

return WebWorkerSandbox.create(loader);
return WebWorkerSandbox.create(loader, timeout);
}

export { WebWorkerSandbox };
124 changes: 85 additions & 39 deletions src/sandbox/webWorker/webWorkerSandbox.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import {
CtxComputeQueryEmbedding,
CtxVectorSearch,
Init,
type IProjectJson,
Load,
} from "./messages.ts";
import {
extractConfigHostNames,
loadRawConfigFromEnv,
type Source,
timeout,
} from "../../util.ts";
import type { IContext } from "../../context/context.ts";
import type { ProjectManifest } from "../../project/project.ts";
Expand Down Expand Up @@ -57,13 +59,54 @@ function getPermisionsForSource(
}
}

export class WebWorkerSandbox implements ISandbox {
#connection: rpc.MessageConnection;
async function workerFactory(
manifest: ProjectManifest,
entryPath: string,
config: Record<string, string>,
permissions: Deno.PermissionOptionsObject,
): Promise<[Worker, rpc.MessageConnection, IProjectJson]> {
const w = new Worker(
import.meta.resolve("./webWorker.ts"),
{
type: "module",
deno: {
permissions: permissions,
},
},
);

// Setup a JSON RPC for interaction to the worker
const conn = rpc.createMessageConnection(
new BrowserMessageReader(w),
new BrowserMessageWriter(w),
);

conn.listen();

await conn.sendRequest(Load, entryPath);

const pJson = await conn.sendRequest(
Init,
manifest,
config,
);

return [w, conn, pJson];
}

export class WebWorkerSandbox implements ISandbox {
#tools: Tool[];
#initWorker: () => ReturnType<typeof workerFactory>;

/**
* Create a new WebWorkerSandbox
* @param loader The loader for loading any project resources
* @param timeout Tool call timeout in MS
* @returns A sandbox instance
*/
public static async create(
loader: Loader,
timeout: number,
): Promise<WebWorkerSandbox> {
const [manifestPath, manifest, source] = await loader.getManifest();
const config = loadRawConfigFromEnv(manifest.config);
Expand All @@ -78,75 +121,78 @@ export class WebWorkerSandbox implements ISandbox {
]),
];

const w = new Worker(
import.meta.resolve("./webWorker.ts"),
{
type: "module",
deno: {
permissions: {
...permissions,
env: false, // Should be passed through in loadRawConfigFromEnv
net: hostnames,
run: false,
write: false,
},
},
},
);

// Setup a JSON RPC for interaction to the worker
const conn = rpc.createMessageConnection(
new BrowserMessageReader(w),
new BrowserMessageWriter(w),
);

conn.listen();

const [entryPath] = await loader.getProject();
await conn.sendRequest(Load, entryPath);

const { tools, systemPrompt } = await conn.sendRequest(
Init,
manifest,
config,
);
const initProjectWorker = () =>
workerFactory(
manifest,
entryPath,
config as Record<string, string>,
{
...permissions,
env: false,
net: hostnames,
run: false,
write: false,
},
);

const [_worker, _conn, { tools, systemPrompt }] = await initProjectWorker();

return new WebWorkerSandbox(
conn,
manifest,
systemPrompt,
tools,
initProjectWorker,
timeout,
);
}

private constructor(
connection: rpc.MessageConnection,
readonly manifest: ProjectManifest,
readonly systemPrompt: string,
tools: Tool[],
initWorker: () => ReturnType<typeof workerFactory>,
readonly timeout: number = 100,
) {
this.#tools = tools;
this.#connection = connection;
this.#initWorker = initWorker;
}

// deno-lint-ignore require-await
async getTools(): Promise<Tool[]> {
return this.#tools;
}

runTool(toolName: string, args: unknown, ctx: IContext): Promise<string> {
async runTool(
toolName: string,
args: unknown,
ctx: IContext,
): Promise<string> {
// Create a worker just for this tool call, so that it can be terminated if it exceeds the timeout.
const [worker, conn] = await this.#initWorker();

// Connect up context so sandbox can call application
this.#connection.onRequest(CtxVectorSearch, async (tableName, vector) => {
conn.onRequest(CtxVectorSearch, async (tableName, vector) => {
const res = await ctx.vectorSearch(tableName, vector);

// lancedb returns classes (Apache Arrow - Struct Row), which need to be made serializable.
// This is done here as it's specific to the web worker sandbox.
return res.map((r) => JSON.parse(JSON.stringify(r)));
});
this.#connection.onRequest(CtxComputeQueryEmbedding, async (query) => {
conn.onRequest(CtxComputeQueryEmbedding, async (query) => {
return await ctx.computeQueryEmbedding(query);
});

return this.#connection.sendRequest(CallTool, toolName, args);
// Add timeout to the tool call, then clean up the worker.
return Promise.race([
timeout(this.timeout).then(() => {
throw new Error(`Timeout calling tool ${toolName}`);
}),
conn.sendRequest(CallTool, toolName, args),
]).finally(() => {
// Dispose of the worker; a new one is created for each tool call
worker.terminate();
});
}
}
Loading

0 comments on commit 8884909

Please sign in to comment.