Cleanup and prepare PR for the new providers
synw committed Aug 30, 2023
1 parent 17cb1d7 commit 0cd66c8
Showing 8 changed files with 60 additions and 21 deletions.
11 changes: 1 addition & 10 deletions package.json
@@ -89,16 +89,6 @@
 "default": 60000,
 "description": "Request timeout in milliseconds"
 },
-"wingman.defaultCtx": {
-"type": "number",
-"default": 2048,
-"description": "Default context window length for the model (a fallback if the command does not specify)"
-},
-"wingman.defaultTemplate": {
-"type": "string",
-"default": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n### User:\n{prompt}\n\n### Assistant:",
-"description": "Default template (a fallback if the command does not specify)"
-},
 "wingman.anthropic.model": {
 "type": "string",
 "default": "claude-instant-v1",
@@ -360,6 +350,7 @@
 "cheerio": "1.0.0-rc.12",
 "fast-glob": "^3.2.12",
 "fetch": "^1.1.0",
+"llama-tokenizer-js": "^1.1.3",
 "node-fetch": "^3.3.1"
 }
 }
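
The new `llama-tokenizer-js` dependency is what the `llamaMaxTokens` helper added to `src/utils.ts` (further down in this commit) builds on. A minimal sketch of counting prompt tokens with it, using only the `encode()` call this commit relies on; the prompt string is illustrative:

```ts
import llamaTokenizer from "llama-tokenizer-js";

// encode() returns an array of LLaMA token ids, so its length is the prompt's token count.
const prompt = "### System:\nYou are a helpful assistant.\n\n### User:\nHello";
const tokenCount = llamaTokenizer.encode(prompt).length;
console.log(`prompt occupies ${tokenCount} tokens`);
```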
32 changes: 31 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion src/extension.ts
@@ -46,7 +46,7 @@ function createCommandMap(templates: Command[]) {
 const allCommands = templates.map((template) => {
 return {
 ...template,
-provider: template.provider ?? AIProvider.KoboldCpp,
+provider: template.provider ?? AIProvider.OpenAI,
 command: template.command ? `wingman.command.${template.command}` : `wingman.command.${generateCommandName(template)}-${randomString()}`,
 category: template.category ?? BuiltinCategory.Misc,
 };
9 changes: 5 additions & 4 deletions src/providers/goinfer.ts
@@ -2,10 +2,10 @@
 import * as vscode from "vscode";

 import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
-import { Client, type InferParams, type InferResult, type StreamedMessage } from "./sdks/goinfer";
+import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer";
 import { type Command } from "../templates/render";
 import { handleResponseCallbackType } from "../templates/runner";
-import { displayWarning, getConfig, getSecret, getSelectionInfo, setSecret, unsetConfig } from "../utils";
+import { displayWarning, getConfig, getSecret, getSelectionInfo, llamaMaxTokens, setSecret, unsetConfig } from "../utils";

 let lastMessage: string | undefined;
 let lastTemplate: Command | undefined;
@@ -84,16 +84,17 @@ export class GoinferProvider implements Provider {
 prompt = `${this.conversationTextHistory ?? ""}${message}`;
 }

-const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string);
+const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
 const samplingParameters: InferParams = {
 prompt,
 template: modelTemplate.replace("{system}", systemMessage),
 ...template?.completionParams,
 temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
 model: {
 name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2",
-ctx: template?.completionParams?.ctx ?? (getConfig("defaultCtx") as number) ?? 2048,
+ctx: template?.completionParams?.ctx ?? DEFAULT_CTX,
 },
+n_predict: llamaMaxTokens(prompt, DEFAULT_CTX),
 };

 try {
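
For a command with no `completionParams`, the fallback chain above now resolves to the SDK constants instead of extension settings. A rough sketch of the payload assembled in that case; the prompt, temperature, and token estimate are illustrative, not taken from the extension:

```ts
import { DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer";
import { llamaMaxTokens } from "../utils";

// Suppose the rendered prompt encodes to roughly 300 LLaMA tokens.
const prompt = "Explain what this function does.";

const params = {
  prompt,
  template: DEFAULT_TEMPLATE.replace("{system}", "You are an AI assistant."),
  temperature: 0.2,                               // would come from the openai.temperature setting
  model: { name: "llama2", ctx: DEFAULT_CTX },    // "llama2" is the final model-name fallback
  // n_predict is whatever the prompt leaves of the context window,
  // e.g. a 300-token prompt against ctx 2048 leaves 1748 tokens to generate.
  n_predict: llamaMaxTokens(prompt, DEFAULT_CTX),
};
```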
11 changes: 6 additions & 5 deletions src/providers/koboldcpp.ts
@@ -2,10 +2,10 @@
 import * as vscode from "vscode";

 import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
-import { Client, type KoboldInferParams } from "./sdks/koboldcpp";
+import { Client, DEFAULT_TEMPLATE, DEFAULT_CTX, type KoboldInferParams } from "./sdks/koboldcpp";
 import { type Command } from "../templates/render";
 import { handleResponseCallbackType } from "../templates/runner";
-import { displayWarning, getConfig, getSelectionInfo } from "../utils";
+import { displayWarning, formatPrompt, getConfig, getSelectionInfo, llamaMaxTokens } from "../utils";

 let lastMessage: string | undefined;
 let lastTemplate: Command | undefined;
@@ -75,13 +75,14 @@ export class KoboldcppProvider implements Provider {
 prompt = `${this.conversationTextHistory ?? ""}${message}`;
 }

-const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string);
+const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
 const samplingParameters: KoboldInferParams = {
-prompt: modelTemplate.replace("{system}", systemMessage).replace("{prompt}", prompt),
+prompt: formatPrompt(prompt, modelTemplate, systemMessage),
 ...template?.completionParams,
 temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
-max_length: 512,
+max_length: llamaMaxTokens(prompt, DEFAULT_CTX),
 };
+console.log("Params", samplingParameters);

 try {
 this.viewProvider?.postMessage({ type: "requestMessage", value: message });
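
The inline `.replace()` chain for the prompt is now centralized in the new `formatPrompt` helper from `src/utils.ts`. A small sketch of what it produces with the default `{system}\n\n{prompt}` template; the strings are illustrative:

```ts
import { formatPrompt } from "../utils";

const rendered = formatPrompt(
  "Refactor this function.",     // substituted for {prompt}
  "{system}\n\n{prompt}",        // the DEFAULT_TEMPLATE shape
  "You are a coding assistant.", // substituted for {system}
);
// rendered === "You are a coding assistant.\n\nRefactor this function."
```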
2 changes: 2 additions & 0 deletions src/providers/sdks/goinfer.ts
@@ -56,6 +56,8 @@ export type OnOpen = (response: NodeFetchResponse) => void | Promise<void>;
 export type OnUpdate = (completion: StreamedMessage) => void | Promise<void>;

 const DEFAULT_API_URL = "https://localhost:5143";
+export const DEFAULT_CTX = 2048;
+export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}";

 export class Client {
 private apiUrl: string;
4 changes: 4 additions & 0 deletions src/providers/sdks/koboldcpp.ts
@@ -15,6 +15,8 @@ export type OnOpen = (response: NodeFetchResponse) => void | Promise<void>;
 export type OnUpdate = (completion: string) => void | Promise<void>;

 const DEFAULT_API_URL = "https://localhost:5001";
+export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}";
+export const DEFAULT_CTX = 2048;

 export class Client {
 private apiUrl: string;
@@ -47,6 +49,8 @@ export class Client {
 completeStream(params: KoboldInferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise<void> {
 const abortController = new AbortController();

+console.log("Url", this.apiUrl);
+
 return new Promise((resolve, reject) => {
 signal?.addEventListener("abort", (event) => {
 abortController.abort(event);
10 changes: 10 additions & 0 deletions src/utils.ts
@@ -2,6 +2,7 @@ import fs from "node:fs";
 import path from "node:path";

 import Glob from "fast-glob";
+import llamaTokenizer from "llama-tokenizer-js";
 import * as vscode from "vscode";

 import { ExtensionState } from "./extension";
@@ -231,3 +232,12 @@ export const randomString = (): string => {
 }
 return result;
 };
+
+export function llamaMaxTokens(prompt: string, ctx: number) {
+const n = llamaTokenizer.encode(prompt).length;
+return ctx - n;
+}
+
+export function formatPrompt(prompt: string, template: string, systemMessage: string) {
+return template.replace("{system}", systemMessage).replace("{prompt}", prompt);
+}
