From be721f9a69e5512592e16f9ad673a0cd72c5915e Mon Sep 17 00:00:00 2001 From: synw Date: Thu, 24 Aug 2023 15:57:27 +0200 Subject: [PATCH 1/5] Add a Goinfer provider --- src/extension.ts | 2 +- src/providers/goinfer.ts | 156 ++++++++++++++++++++++++++++++++++ src/providers/index.ts | 2 + src/providers/sdks/goinfer.ts | 150 ++++++++++++++++++++++++++++++++ src/templates/render.ts | 14 +++ 5 files changed, 323 insertions(+), 1 deletion(-) create mode 100644 src/providers/goinfer.ts create mode 100644 src/providers/sdks/goinfer.ts diff --git a/src/extension.ts b/src/extension.ts index 969014a..4259274 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -46,7 +46,7 @@ function createCommandMap(templates: Command[]) { const allCommands = templates.map((template) => { return { ...template, - provider: template.provider ?? AIProvider.OpenAI, + provider: template.provider ?? AIProvider.Goinfer, command: template.command ? `wingman.command.${template.command}` : `wingman.command.${generateCommandName(template)}-${randomString()}`, category: template.category ?? BuiltinCategory.Misc, }; diff --git a/src/providers/goinfer.ts b/src/providers/goinfer.ts new file mode 100644 index 0000000..310e73a --- /dev/null +++ b/src/providers/goinfer.ts @@ -0,0 +1,156 @@ +/* eslint-disable unused-imports/no-unused-vars */ +import * as vscode from "vscode"; + +import { type PostableViewProvider, type ProviderResponse, type Provider } from "."; +import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_TEMPLATE } from "./sdks/goinfer"; +import { type Command } from "../templates/render"; +import { handleResponseCallbackType } from "../templates/runner"; +import { displayWarning, getConfig, getSecret, getSelectionInfo, setSecret, unsetConfig } from "../utils"; + +let lastMessage: string | undefined; +let lastTemplate: Command | undefined; +let lastSystemMessage: string | undefined; + +export class GoinferProvider implements Provider { + viewProvider: PostableViewProvider | undefined; + instance: Client | undefined; + conversationTextHistory: string | undefined; + _abort: AbortController = new AbortController(); + + async create(provider: PostableViewProvider, template: Command) { + const apiKey = await getSecret("openai.apiKey", ""); + + // If the user still uses the now deprecated openai.apiKey config, move it to the secrets store + // and unset the config. 
+ if (getConfig("openai.apiKey")) { + setSecret("openai.apiKey", getConfig("openai.apiKey")); + unsetConfig("openai.apiKey"); + } + + const { apiBaseUrl } = { + apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined, + }; + + this.viewProvider = provider; + this.conversationTextHistory = undefined; + this.instance = new Client(apiKey, { apiUrl: apiBaseUrl }); + } + + destroy() { + this.instance = undefined; + this.conversationTextHistory = undefined; + } + + abort() { + this._abort.abort(); + this._abort = new AbortController(); + } + + async send(message: string, systemMessage?: string, template?: Command): Promise { + let isFollowup = false; + + lastMessage = message; + + if (template) { + lastTemplate = template; + } + + if (!template && !lastTemplate) { + return; + } + + if (!template) { + template = lastTemplate!; + isFollowup = true; + } + + if (systemMessage) { + lastSystemMessage = systemMessage; + } + if (!systemMessage && !lastSystemMessage) { + return; + } + if (!systemMessage) { + systemMessage = lastSystemMessage!; + } + + let prompt; + if (!isFollowup) { + this.viewProvider?.postMessage({ type: "newChat" }); + // The first message should have the system message prepended + prompt = `${message}`; + } else { + // followups should have the conversation history prepended + prompt = `${this.conversationTextHistory ?? ""}${message}`; + } + + const samplingParameters: InferParams = { + prompt, + template: DEFAULT_TEMPLATE.replace("{system}", systemMessage), + ...template?.completionParams, + temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number), + model: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2", + }; + + try { + this.viewProvider?.postMessage({ type: "requestMessage", value: message }); + + const editor = vscode.window.activeTextEditor!; + const selection = getSelectionInfo(editor); + + const goinferResponse: InferResult = await this.instance!.completeStream(samplingParameters, { + onOpen: (response) => { + console.log("Opened stream, HTTP status code", response.status); + }, + onUpdate: (partialResponse: StreamedMessage) => { + const msg = this.toProviderStreamedResponse(partialResponse); + console.log("MSG:", msg.text); + this.viewProvider?.postMessage({ + type: "partialResponse", + value: msg, + }); + }, + signal: this._abort.signal, + }); + + // Reformat the API response into a ProvderResponse + const response = this.toProviderResponse(goinferResponse); + + // Append the last response to the conversation history + this.conversationTextHistory = `${this.conversationTextHistory ?? 
""}${prompt} ${response.text}`; + this.viewProvider?.postMessage({ type: "responseFinished", value: response }); + + if (!isFollowup) { + handleResponseCallbackType(template, editor, selection, response.text); + } + } catch (error) { + displayWarning(String(error)); + } + } + + async repeatLast() { + if (!lastMessage || !lastSystemMessage || !lastTemplate) { + return; + } + + await this.send(lastMessage, lastSystemMessage, lastTemplate); + } + + toProviderResponse(response: InferResult) { + return { + text: response.text, + parentMessageId: "", + converastionId: "", + id: "", + }; + } + + toProviderStreamedResponse(response: StreamedMessage) { + return { + text: response.content, + parentMessageId: "", + converastionId: "", + id: "", + }; + } +} diff --git a/src/providers/index.ts b/src/providers/index.ts index 088865a..fbef838 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -1,6 +1,7 @@ import type * as vscode from "vscode"; import { AnthropicProvider } from "./anthropic"; +import { GoinferProvider } from "./goinfer"; import { OpenAIProvider } from "./openai"; import { AIProvider, type Command } from "../templates/render"; @@ -41,4 +42,5 @@ export interface Provider { export const providers = { [AIProvider.OpenAI]: OpenAIProvider, [AIProvider.Anthropic]: AnthropicProvider, + [AIProvider.Goinfer]: GoinferProvider, }; diff --git a/src/providers/sdks/goinfer.ts b/src/providers/sdks/goinfer.ts new file mode 100644 index 0000000..a9261f7 --- /dev/null +++ b/src/providers/sdks/goinfer.ts @@ -0,0 +1,150 @@ +import { fetchEventSource } from "@ai-zen/node-fetch-event-source"; +import fetch from "node-fetch"; +import type { Response as NodeFetchResponse } from "node-fetch"; + +export interface InferParams { + prompt: string; + template?: string; + stream?: boolean; + threads?: number; + model?: string; + n_predict?: number; + top_k?: number; + top_p?: number; + temperature?: number; + frequency_penalty?: number; + presence_penalty?: number; + repeat_penalty?: number; + tfs_z?: number; + stop?: string[]; +} + +export interface InferResult { + text: string; + thinkingTime: number; + thinkingTimeFormat: string; + inferenceTime: number; + emitTime: number; + emitTimeFormat: string; + totalTime: number; + totalTimeFormat: string; + tokensPerSecond: number; + totalTokens: number; +} + +export enum StreamedMsgType { + TokenMsgType = "token", + SystemMsgType = "system", + ErrorMsgType = "error", +} + +export interface StreamedMessage { + content: string; + num: number; + msg_type: StreamedMsgType; + data?: { [key: string]: any }; +} + +export type OnOpen = (response: NodeFetchResponse) => void | Promise; +export type OnUpdate = (completion: StreamedMessage) => void | Promise; + +const DEFAULT_API_URL = "https://localhost:5143"; +export const DEFAULT_TEMPLATE = "{system}\n\n### Instruction: {prompt}\n\n### Response:"; + +export class Client { + private apiUrl: string; + + constructor(private apiKey: string, options?: { apiUrl?: string }) { + this.apiUrl = options?.apiUrl ?? 
DEFAULT_API_URL; + } + + async complete(params: InferParams, options?: { signal?: AbortSignal }): Promise { + const response = await fetch(`${this.apiUrl}/completion`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify({ ...params, stream: false }), + signal: options?.signal, + }); + + if (!response.ok) { + const error = new Error(`Sampling error: ${response.status} ${response.statusText}`); + console.error(error); + throw error; + } + + const completion = (await response.json()) as InferResult; + return completion; + } + + completeStream(params: InferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise { + const abortController = new AbortController(); + + return new Promise((resolve, reject) => { + signal?.addEventListener("abort", (event) => { + abortController.abort(event); + reject(new Error("Caller aborted completeStream")); + }); + + const body = JSON.stringify({ ...params, stream: true }); + fetchEventSource(`${this.apiUrl}/completion`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body, + signal: abortController.signal, + onopen: async (response) => { + if (!response.ok) { + abortController.abort(); + return reject(new Error(`Failed to open stream, HTTP status code ${response.status}: ${response.statusText}`)); + } + if (onOpen) { + await Promise.resolve(onOpen(response)); + } + }, + onmessage: (ev) => { + const completion = JSON.parse(ev.data) as StreamedMessage; + if (onUpdate) { + Promise.resolve(onUpdate(completion)).catch((error) => { + abortController.abort(); + reject(error); + }); + } + + // console.log(completion); + + switch (completion.msg_type) { + case "system": + // console.log("SYSTEM MSG") + // console.log(completion) + switch (completion.content) { + case "result": + return resolve(completion.data as InferResult); + case "error": + abortController.abort(); + return reject(new Error("inference error")); + default: + break; + } + break; + case "token": + if (onUpdate) { + onUpdate(completion); + } + } + }, + onerror: (error) => { + console.error("Inference error:", error); + abortController.abort(); + return reject(error); + }, + }); + }); + } +} diff --git a/src/templates/render.ts b/src/templates/render.ts index 63de198..96bbb50 100644 --- a/src/templates/render.ts +++ b/src/templates/render.ts @@ -90,6 +90,7 @@ export enum CallbackType { export enum AIProvider { OpenAI = "openai", Anthropic = "anthropic", + Goinfer = "goinfer", } export interface Command { @@ -142,6 +143,11 @@ const defaultAnthropicCompletionParams = () => ({ temperature: 0.3, }); +// https://synw.github.io/goinfer/llama_api/inference +const defaultGoinferCompletionParams = () => ({ + temperature: 0.3, +}); + export const defaultCommands: Command[] = [ // Completion { @@ -452,6 +458,14 @@ export const buildCommandTemplate = (commandName: string): Command => { temperature, }; break; + case "goinfer": + completionParams = { + ...defaultGoinferCompletionParams(), + ...template.completionParams, + model, + temperature, + }; + break; default: completionParams = { ...template.completionParams }; break; From a6dd4cdea8190b19c0de540db035d96d5b2f7130 Mon Sep 17 00:00:00 2001 From: synw Date: Thu, 24 Aug 2023 16:24:00 +0200 Subject: [PATCH 2/5] Fix streaming response display in the Goinfer provider --- src/providers/goinfer.ts | 
22 ++++++++-------------- src/providers/sdks/goinfer.ts | 4 ++-- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/providers/goinfer.ts b/src/providers/goinfer.ts index 310e73a..0a1fd2f 100644 --- a/src/providers/goinfer.ts +++ b/src/providers/goinfer.ts @@ -97,14 +97,17 @@ export class GoinferProvider implements Provider { const editor = vscode.window.activeTextEditor!; const selection = getSelectionInfo(editor); + let partialText = ""; const goinferResponse: InferResult = await this.instance!.completeStream(samplingParameters, { onOpen: (response) => { console.log("Opened stream, HTTP status code", response.status); }, onUpdate: (partialResponse: StreamedMessage) => { - const msg = this.toProviderStreamedResponse(partialResponse); - console.log("MSG:", msg.text); + partialText += partialResponse.content; + // console.log("P", partialText); + const msg = this.toProviderResponse(partialText); + // console.log("MSG:", msg.text); this.viewProvider?.postMessage({ type: "partialResponse", value: msg, @@ -114,7 +117,7 @@ export class GoinferProvider implements Provider { }); // Reformat the API response into a ProvderResponse - const response = this.toProviderResponse(goinferResponse); + const response = this.toProviderResponse(goinferResponse.text); // Append the last response to the conversation history this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`; @@ -136,18 +139,9 @@ export class GoinferProvider implements Provider { await this.send(lastMessage, lastSystemMessage, lastTemplate); } - toProviderResponse(response: InferResult) { + toProviderResponse(text: string) { return { - text: response.text, - parentMessageId: "", - converastionId: "", - id: "", - }; - } - - toProviderStreamedResponse(response: StreamedMessage) { - return { - text: response.content, + text, parentMessageId: "", converastionId: "", id: "", diff --git a/src/providers/sdks/goinfer.ts b/src/providers/sdks/goinfer.ts index a9261f7..0c927d6 100644 --- a/src/providers/sdks/goinfer.ts +++ b/src/providers/sdks/goinfer.ts @@ -110,12 +110,12 @@ export class Client { }, onmessage: (ev) => { const completion = JSON.parse(ev.data) as StreamedMessage; - if (onUpdate) { + /* if (onUpdate) { Promise.resolve(onUpdate(completion)).catch((error) => { abortController.abort(); reject(error); }); - } + } */ // console.log(completion); From 7f0e465607887205ba7c1b4f14123bbb36e7880d Mon Sep 17 00:00:00 2001 From: synw Date: Sat, 26 Aug 2023 16:28:41 +0200 Subject: [PATCH 3/5] Add templating and model params support in Goinfer provider --- package.json | 13 ++++++++++++- src/providers/goinfer.ts | 10 +++++++--- src/providers/sdks/goinfer.ts | 10 ++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/package.json b/package.json index 2545cfb..ca5b434 100644 --- a/package.json +++ b/package.json @@ -89,6 +89,16 @@ "default": 60000, "description": "Request timeout in milliseconds" }, + "wingman.defaultCtx": { + "type": "number", + "default": 2048, + "description": "Default context window length for the model (a fallback if the command does not specify)" + }, + "wingman.defaultTemplate": { + "type": "string", + "default": "[INST] <>\n{system}\n<>\n\n{prompt} [/INST]", + "description": "Default template (a fallback if the command does not specify)" + }, "wingman.anthropic.model": { "type": "string", "default": "claude-instant-v1", @@ -266,7 +276,8 @@ "default": "openai", "enum": [ "anthropic", - "openai" + "openai", + "goinfer" ], "description": "Which provider 
should this command use?" } diff --git a/src/providers/goinfer.ts b/src/providers/goinfer.ts index 0a1fd2f..37c2386 100644 --- a/src/providers/goinfer.ts +++ b/src/providers/goinfer.ts @@ -2,7 +2,7 @@ import * as vscode from "vscode"; import { type PostableViewProvider, type ProviderResponse, type Provider } from "."; -import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_TEMPLATE } from "./sdks/goinfer"; +import { Client, type InferParams, type InferResult, type StreamedMessage } from "./sdks/goinfer"; import { type Command } from "../templates/render"; import { handleResponseCallbackType } from "../templates/runner"; import { displayWarning, getConfig, getSecret, getSelectionInfo, setSecret, unsetConfig } from "../utils"; @@ -84,12 +84,16 @@ export class GoinferProvider implements Provider { prompt = `${this.conversationTextHistory ?? ""}${message}`; } + const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string); const samplingParameters: InferParams = { prompt, - template: DEFAULT_TEMPLATE.replace("{system}", systemMessage), + template: modelTemplate.replace("{system}", systemMessage), ...template?.completionParams, temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number), - model: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2", + model: { + name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2", + ctx: template?.completionParams?.ctx ?? (getConfig("defaultCtx") as number) ?? 2048, + }, }; try { diff --git a/src/providers/sdks/goinfer.ts b/src/providers/sdks/goinfer.ts index 0c927d6..4fb956c 100644 --- a/src/providers/sdks/goinfer.ts +++ b/src/providers/sdks/goinfer.ts @@ -2,12 +2,19 @@ import { fetchEventSource } from "@ai-zen/node-fetch-event-source"; import fetch from "node-fetch"; import type { Response as NodeFetchResponse } from "node-fetch"; +interface ModelConf { + name: string; + ctx?: number; + freq_rope_base?: number; + freq_rope_scale?: number; +} + export interface InferParams { prompt: string; template?: string; stream?: boolean; threads?: number; - model?: string; + model?: ModelConf; n_predict?: number; top_k?: number; top_p?: number; @@ -49,7 +56,6 @@ export type OnOpen = (response: NodeFetchResponse) => void | Promise; export type OnUpdate = (completion: StreamedMessage) => void | Promise; const DEFAULT_API_URL = "https://localhost:5143"; -export const DEFAULT_TEMPLATE = "{system}\n\n### Instruction: {prompt}\n\n### Response:"; export class Client { private apiUrl: string; From 17cb1d73a7057c7c73b89bdcf9aa8fea22b765cf Mon Sep 17 00:00:00 2001 From: synw Date: Tue, 29 Aug 2023 13:22:22 +0200 Subject: [PATCH 4/5] Fixes for Koboldcpp provider --- package.json | 5 +- src/extension.ts | 2 +- src/providers/index.ts | 2 + src/providers/koboldcpp.ts | 141 ++++++++++++++++++++++++++++++++ src/providers/sdks/koboldcpp.ts | 93 +++++++++++++++++++++ src/templates/render.ts | 12 +++ 6 files changed, 252 insertions(+), 3 deletions(-) create mode 100644 src/providers/koboldcpp.ts create mode 100644 src/providers/sdks/koboldcpp.ts diff --git a/package.json b/package.json index ca5b434..7d80907 100644 --- a/package.json +++ b/package.json @@ -96,7 +96,7 @@ }, "wingman.defaultTemplate": { "type": "string", - "default": "[INST] <>\n{system}\n<>\n\n{prompt} [/INST]", + "default": "### System:\nYou are an AI assistant that follows instruction extremely well. 
Help as much as you can.\n\n### User:\n{prompt}\n\n### Assistant:", "description": "Default template (a fallback if the command does not specify)" }, "wingman.anthropic.model": { @@ -277,7 +277,8 @@ "enum": [ "anthropic", "openai", - "goinfer" + "goinfer", + "koboldcpp" ], "description": "Which provider should this command use?" } diff --git a/src/extension.ts b/src/extension.ts index 4259274..98f90dd 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -46,7 +46,7 @@ function createCommandMap(templates: Command[]) { const allCommands = templates.map((template) => { return { ...template, - provider: template.provider ?? AIProvider.Goinfer, + provider: template.provider ?? AIProvider.KoboldCpp, command: template.command ? `wingman.command.${template.command}` : `wingman.command.${generateCommandName(template)}-${randomString()}`, category: template.category ?? BuiltinCategory.Misc, }; diff --git a/src/providers/index.ts b/src/providers/index.ts index fbef838..2295931 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -2,6 +2,7 @@ import type * as vscode from "vscode"; import { AnthropicProvider } from "./anthropic"; import { GoinferProvider } from "./goinfer"; +import { KoboldcppProvider } from "./koboldcpp"; import { OpenAIProvider } from "./openai"; import { AIProvider, type Command } from "../templates/render"; @@ -43,4 +44,5 @@ export const providers = { [AIProvider.OpenAI]: OpenAIProvider, [AIProvider.Anthropic]: AnthropicProvider, [AIProvider.Goinfer]: GoinferProvider, + [AIProvider.KoboldCpp]: KoboldcppProvider, }; diff --git a/src/providers/koboldcpp.ts b/src/providers/koboldcpp.ts new file mode 100644 index 0000000..a6d376d --- /dev/null +++ b/src/providers/koboldcpp.ts @@ -0,0 +1,141 @@ +/* eslint-disable unused-imports/no-unused-vars */ +import * as vscode from "vscode"; + +import { type PostableViewProvider, type ProviderResponse, type Provider } from "."; +import { Client, type KoboldInferParams } from "./sdks/koboldcpp"; +import { type Command } from "../templates/render"; +import { handleResponseCallbackType } from "../templates/runner"; +import { displayWarning, getConfig, getSelectionInfo } from "../utils"; + +let lastMessage: string | undefined; +let lastTemplate: Command | undefined; +let lastSystemMessage: string | undefined; + +export class KoboldcppProvider implements Provider { + viewProvider: PostableViewProvider | undefined; + instance: Client | undefined; + conversationTextHistory: string | undefined; + _abort: AbortController = new AbortController(); + + async create(provider: PostableViewProvider, template: Command) { + const { apiBaseUrl } = { + apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined, + }; + + this.viewProvider = provider; + this.conversationTextHistory = undefined; + this.instance = new Client("", { apiUrl: apiBaseUrl }); + } + + destroy() { + this.instance = undefined; + this.conversationTextHistory = undefined; + } + + abort() { + this._abort.abort(); + this._abort = new AbortController(); + } + + async send(message: string, systemMessage?: string, template?: Command): Promise { + let isFollowup = false; + + lastMessage = message; + + if (template) { + lastTemplate = template; + } + + if (!template && !lastTemplate) { + return; + } + + if (!template) { + template = lastTemplate!; + isFollowup = true; + } + + if (systemMessage) { + lastSystemMessage = systemMessage; + } + if (!systemMessage && !lastSystemMessage) { + return; + } + if (!systemMessage) { + systemMessage = lastSystemMessage!; + } + + let prompt; + if 
(!isFollowup) { + this.viewProvider?.postMessage({ type: "newChat" }); + // The first message should have the system message prepended + prompt = `${message}`; + } else { + // followups should have the conversation history prepended + prompt = `${this.conversationTextHistory ?? ""}${message}`; + } + + const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string); + const samplingParameters: KoboldInferParams = { + prompt: modelTemplate.replace("{system}", systemMessage).replace("{prompt}", prompt), + ...template?.completionParams, + temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number), + max_length: 512, + }; + + try { + this.viewProvider?.postMessage({ type: "requestMessage", value: message }); + + const editor = vscode.window.activeTextEditor!; + const selection = getSelectionInfo(editor); + let partialText = ""; + + await this.instance!.completeStream(samplingParameters, { + onOpen: (response) => { + console.log("Opened stream, HTTP status code", response.status); + }, + onUpdate: (partialResponse: string) => { + partialText += partialResponse; + // console.log("P", partialText); + const msg = this.toProviderResponse(partialText); + // console.log("MSG:", msg.text); + this.viewProvider?.postMessage({ + type: "partialResponse", + value: msg, + }); + }, + signal: this._abort.signal, + }); + + // Reformat the API response into a ProvderResponse + const response = this.toProviderResponse(partialText); + + // Append the last response to the conversation history + this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`; + this.viewProvider?.postMessage({ type: "responseFinished", value: response }); + + if (!isFollowup) { + handleResponseCallbackType(template, editor, selection, response.text); + } + } catch (error) { + displayWarning(String(error)); + } + } + + async repeatLast() { + if (!lastMessage || !lastSystemMessage || !lastTemplate) { + return; + } + + await this.send(lastMessage, lastSystemMessage, lastTemplate); + } + + toProviderResponse(text: string) { + return { + text, + parentMessageId: "", + converastionId: "", + id: "", + }; + } +} diff --git a/src/providers/sdks/koboldcpp.ts b/src/providers/sdks/koboldcpp.ts new file mode 100644 index 0000000..6869603 --- /dev/null +++ b/src/providers/sdks/koboldcpp.ts @@ -0,0 +1,93 @@ +import { fetchEventSource } from "@ai-zen/node-fetch-event-source"; +import type { Response as NodeFetchResponse } from "node-fetch"; + +export interface KoboldInferParams { + prompt: string; + max_length?: number; + top_k?: number; + top_p?: number; + temperature?: number; + tfs?: number; + stop_sequence?: string[]; +} + +export type OnOpen = (response: NodeFetchResponse) => void | Promise; +export type OnUpdate = (completion: string) => void | Promise; + +const DEFAULT_API_URL = "https://localhost:5001"; + +export class Client { + private apiUrl: string; + + constructor(private apiKey: string, options?: { apiUrl?: string }) { + this.apiUrl = options?.apiUrl ?? 
DEFAULT_API_URL; + } + + /* async complete(params: KoboldInferParams, options?: { signal?: AbortSignal }): Promise { + const response = await fetch(`${this.apiUrl}/api/extra/generate/stream`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ ...params, stream: false }), + signal: options?.signal, + }); + + if (!response.ok) { + const error = new Error(`Sampling error: ${response.status} ${response.statusText}`); + console.error(error); + throw error; + } + + const completion = (await response.json()); + return "" + } */ + + completeStream(params: KoboldInferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise { + const abortController = new AbortController(); + + return new Promise((resolve, reject) => { + signal?.addEventListener("abort", (event) => { + abortController.abort(event); + reject(new Error("Caller aborted completeStream")); + }); + + const body = JSON.stringify({ ...params }); + fetchEventSource(`${this.apiUrl}/api/extra/generate/stream`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body, + signal: abortController.signal, + onopen: async (response) => { + if (!response.ok) { + abortController.abort(); + return reject(new Error(`Failed to open stream, HTTP status code ${response.status}: ${response.statusText}`)); + } + if (onOpen) { + await Promise.resolve(onOpen(response)); + } + }, + onmessage: (ev) => { + const completion = JSON.parse(ev.data); + if (onUpdate) { + onUpdate(completion.token); + } + }, + onclose: () => { + // console.log("Close stream") + return resolve(); + }, + onerror: (error) => { + console.error("Inference error:", error); + abortController.abort(); + return reject(error); + }, + }); + }); + } +} diff --git a/src/templates/render.ts b/src/templates/render.ts index 96bbb50..dfb44ad 100644 --- a/src/templates/render.ts +++ b/src/templates/render.ts @@ -91,6 +91,7 @@ export enum AIProvider { OpenAI = "openai", Anthropic = "anthropic", Goinfer = "goinfer", + KoboldCpp = "koboldcpp", } export interface Command { @@ -148,6 +149,10 @@ const defaultGoinferCompletionParams = () => ({ temperature: 0.3, }); +const defaultKoboldcppCompletionParams = () => ({ + temperature: 0.3, +}); + export const defaultCommands: Command[] = [ // Completion { @@ -466,6 +471,13 @@ export const buildCommandTemplate = (commandName: string): Command => { temperature, }; break; + case "koboldcpp": + completionParams = { + ...defaultKoboldcppCompletionParams(), + ...template.completionParams, + temperature, + }; + break; default: completionParams = { ...template.completionParams }; break; From 0cd66c86b0ff4a00d8040fddb080ad202758d4e7 Mon Sep 17 00:00:00 2001 From: synw Date: Wed, 30 Aug 2023 13:42:50 +0200 Subject: [PATCH 5/5] Cleanup and prepare PR for the new providers --- package.json | 11 +---------- pnpm-lock.yaml | 32 +++++++++++++++++++++++++++++++- src/extension.ts | 2 +- src/providers/goinfer.ts | 9 +++++---- src/providers/koboldcpp.ts | 11 ++++++----- src/providers/sdks/goinfer.ts | 2 ++ src/providers/sdks/koboldcpp.ts | 4 ++++ src/utils.ts | 10 ++++++++++ 8 files changed, 60 insertions(+), 21 deletions(-) diff --git a/package.json b/package.json index 7d80907..4754ac7 100644 --- a/package.json +++ b/package.json @@ -89,16 +89,6 @@ "default": 60000, "description": "Request timeout in milliseconds" }, - "wingman.defaultCtx": { - "type": 
"number", - "default": 2048, - "description": "Default context window length for the model (a fallback if the command does not specify)" - }, - "wingman.defaultTemplate": { - "type": "string", - "default": "### System:\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n### User:\n{prompt}\n\n### Assistant:", - "description": "Default template (a fallback if the command does not specify)" - }, "wingman.anthropic.model": { "type": "string", "default": "claude-instant-v1", @@ -360,6 +350,7 @@ "cheerio": "1.0.0-rc.12", "fast-glob": "^3.2.12", "fetch": "^1.1.0", + "llama-tokenizer-js": "^1.1.3", "node-fetch": "^3.3.1" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a345a01..dee74af 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.1' +lockfileVersion: '6.0' settings: autoInstallPeers: true @@ -23,6 +23,9 @@ dependencies: fetch: specifier: ^1.1.0 version: 1.1.0 + llama-tokenizer-js: + specifier: ^1.1.3 + version: 1.1.3 node-fetch: specifier: ^3.3.1 version: 3.3.1 @@ -1241,6 +1244,7 @@ packages: /bl@4.1.0: resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} + requiresBuild: true dependencies: buffer: 5.7.1 inherits: 2.0.4 @@ -1302,6 +1306,7 @@ packages: /buffer@5.7.1: resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} + requiresBuild: true dependencies: base64-js: 1.5.1 ieee754: 1.2.1 @@ -1504,6 +1509,7 @@ packages: /chownr@1.1.4: resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} + requiresBuild: true dev: true optional: true @@ -1730,6 +1736,7 @@ packages: /deep-extend@0.6.0: resolution: {integrity: sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==} engines: {node: '>=4.0.0'} + requiresBuild: true dev: true /deep-is@0.1.4: @@ -1788,6 +1795,7 @@ packages: /detect-libc@2.0.1: resolution: {integrity: sha512-463v3ZeIrcWtdgIg6vI6XUncguvr2TnGl4SzDXinkt9mSLpBJKXT3mW6xT3VQdDN11+WVs29pgvivTc4Lp8v+w==} engines: {node: '>=8'} + requiresBuild: true dev: true optional: true @@ -1876,6 +1884,7 @@ packages: /end-of-stream@1.4.4: resolution: {integrity: sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==} + requiresBuild: true dependencies: once: 1.4.0 dev: true @@ -2493,6 +2502,7 @@ packages: /expand-template@2.0.3: resolution: {integrity: sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==} engines: {node: '>=6'} + requiresBuild: true dev: true optional: true @@ -2634,6 +2644,7 @@ packages: /fs-constants@1.0.0: resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} + requiresBuild: true dev: true optional: true @@ -2732,6 +2743,7 @@ packages: /github-from-package@0.0.0: resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} + requiresBuild: true dev: true optional: true @@ -2975,6 +2987,7 @@ packages: /ieee754@1.2.1: resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} + requiresBuild: true dev: true optional: true @@ -3476,6 +3489,10 @@ packages: wrap-ansi: 7.0.0 dev: true + /llama-tokenizer-js@1.1.3: + resolution: {integrity: 
sha512-+BUgsLCXVQJkjiD/t7PdESLn+yXJIRX/BJfwzVVYfKZ9aN3gsP9xoadBZxKnCxGz2Slby+S7x41gUr2TKNaS4Q==} + dev: false + /loader-runner@4.3.0: resolution: {integrity: sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==} engines: {node: '>=6.11.5'} @@ -3718,6 +3735,7 @@ packages: /mkdirp-classic@0.5.3: resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} + requiresBuild: true dev: true optional: true @@ -3789,6 +3807,7 @@ packages: /napi-build-utils@1.0.2: resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==} + requiresBuild: true dev: true optional: true @@ -3807,6 +3826,7 @@ packages: /node-abi@3.40.0: resolution: {integrity: sha512-zNy02qivjjRosswoYmPi8hIKJRr8MpQyeKT6qlcq/OnOgA3Rhoae+IYOqsM9V5+JnHWmxKnWOT2GxvtqdtOCXA==} engines: {node: '>=10'} + requiresBuild: true dependencies: semver: 7.5.1 dev: true @@ -3814,6 +3834,7 @@ packages: /node-addon-api@4.3.0: resolution: {integrity: sha512-73sE9+3UaLYYFmDsFZnqCInzPyh3MqIwZO9cw58yIqAZhONrrabrYyYe3TuIqtIiOuTXVhsGau8hcrhhwSsDIQ==} + requiresBuild: true dev: true optional: true @@ -4200,6 +4221,7 @@ packages: resolution: {integrity: sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==} engines: {node: '>=10'} hasBin: true + requiresBuild: true dependencies: detect-libc: 2.0.1 expand-template: 2.0.3 @@ -4254,6 +4276,7 @@ packages: /pump@3.0.0: resolution: {integrity: sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==} + requiresBuild: true dependencies: end-of-stream: 1.4.4 once: 1.4.0 @@ -4375,6 +4398,7 @@ packages: /readable-stream@3.6.2: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} + requiresBuild: true dependencies: inherits: 2.0.4 string_decoder: 1.3.0 @@ -4641,11 +4665,13 @@ packages: /simple-concat@1.0.1: resolution: {integrity: sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==} + requiresBuild: true dev: true optional: true /simple-get@4.0.1: resolution: {integrity: sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==} + requiresBuild: true dependencies: decompress-response: 6.0.0 once: 1.4.0 @@ -4800,6 +4826,7 @@ packages: /string_decoder@1.3.0: resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} + requiresBuild: true dependencies: safe-buffer: 5.2.1 dev: true @@ -4844,6 +4871,7 @@ packages: /strip-json-comments@2.0.1: resolution: {integrity: sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==} engines: {node: '>=0.10.0'} + requiresBuild: true dev: true /strip-json-comments@3.1.1: @@ -4907,6 +4935,7 @@ packages: /tar-fs@2.1.1: resolution: {integrity: sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==} + requiresBuild: true dependencies: chownr: 1.1.4 mkdirp-classic: 0.5.3 @@ -4918,6 +4947,7 @@ packages: /tar-stream@2.2.0: resolution: {integrity: sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==} engines: {node: '>=6'} + requiresBuild: true dependencies: bl: 4.1.0 end-of-stream: 1.4.4 diff --git a/src/extension.ts b/src/extension.ts index 98f90dd..969014a 100644 --- a/src/extension.ts +++ 
b/src/extension.ts @@ -46,7 +46,7 @@ function createCommandMap(templates: Command[]) { const allCommands = templates.map((template) => { return { ...template, - provider: template.provider ?? AIProvider.KoboldCpp, + provider: template.provider ?? AIProvider.OpenAI, command: template.command ? `wingman.command.${template.command}` : `wingman.command.${generateCommandName(template)}-${randomString()}`, category: template.category ?? BuiltinCategory.Misc, }; diff --git a/src/providers/goinfer.ts b/src/providers/goinfer.ts index 37c2386..8e051bc 100644 --- a/src/providers/goinfer.ts +++ b/src/providers/goinfer.ts @@ -2,10 +2,10 @@ import * as vscode from "vscode"; import { type PostableViewProvider, type ProviderResponse, type Provider } from "."; -import { Client, type InferParams, type InferResult, type StreamedMessage } from "./sdks/goinfer"; +import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer"; import { type Command } from "../templates/render"; import { handleResponseCallbackType } from "../templates/runner"; -import { displayWarning, getConfig, getSecret, getSelectionInfo, setSecret, unsetConfig } from "../utils"; +import { displayWarning, getConfig, getSecret, getSelectionInfo, llamaMaxTokens, setSecret, unsetConfig } from "../utils"; let lastMessage: string | undefined; let lastTemplate: Command | undefined; @@ -84,7 +84,7 @@ export class GoinferProvider implements Provider { prompt = `${this.conversationTextHistory ?? ""}${message}`; } - const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string); + const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE; const samplingParameters: InferParams = { prompt, template: modelTemplate.replace("{system}", systemMessage), @@ -92,8 +92,9 @@ export class GoinferProvider implements Provider { temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number), model: { name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2", - ctx: template?.completionParams?.ctx ?? (getConfig("defaultCtx") as number) ?? 2048, + ctx: template?.completionParams?.ctx ?? DEFAULT_CTX, }, + n_predict: llamaMaxTokens(prompt, DEFAULT_CTX), }; try { diff --git a/src/providers/koboldcpp.ts b/src/providers/koboldcpp.ts index a6d376d..c8a265b 100644 --- a/src/providers/koboldcpp.ts +++ b/src/providers/koboldcpp.ts @@ -2,10 +2,10 @@ import * as vscode from "vscode"; import { type PostableViewProvider, type ProviderResponse, type Provider } from "."; -import { Client, type KoboldInferParams } from "./sdks/koboldcpp"; +import { Client, DEFAULT_TEMPLATE, DEFAULT_CTX, type KoboldInferParams } from "./sdks/koboldcpp"; import { type Command } from "../templates/render"; import { handleResponseCallbackType } from "../templates/runner"; -import { displayWarning, getConfig, getSelectionInfo } from "../utils"; +import { displayWarning, formatPrompt, getConfig, getSelectionInfo, llamaMaxTokens } from "../utils"; let lastMessage: string | undefined; let lastTemplate: Command | undefined; @@ -75,13 +75,14 @@ export class KoboldcppProvider implements Provider { prompt = `${this.conversationTextHistory ?? ""}${message}`; } - const modelTemplate = template?.completionParams?.template ?? (getConfig("defaultTemplate") as string); + const modelTemplate = template?.completionParams?.template ?? 
DEFAULT_TEMPLATE; const samplingParameters: KoboldInferParams = { - prompt: modelTemplate.replace("{system}", systemMessage).replace("{prompt}", prompt), + prompt: formatPrompt(prompt, modelTemplate, systemMessage), ...template?.completionParams, temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number), - max_length: 512, + max_length: llamaMaxTokens(prompt, DEFAULT_CTX), }; + console.log("Params", samplingParameters); try { this.viewProvider?.postMessage({ type: "requestMessage", value: message }); diff --git a/src/providers/sdks/goinfer.ts b/src/providers/sdks/goinfer.ts index 4fb956c..cd691e5 100644 --- a/src/providers/sdks/goinfer.ts +++ b/src/providers/sdks/goinfer.ts @@ -56,6 +56,8 @@ export type OnOpen = (response: NodeFetchResponse) => void | Promise; export type OnUpdate = (completion: StreamedMessage) => void | Promise; const DEFAULT_API_URL = "https://localhost:5143"; +export const DEFAULT_CTX = 2048; +export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}"; export class Client { private apiUrl: string; diff --git a/src/providers/sdks/koboldcpp.ts b/src/providers/sdks/koboldcpp.ts index 6869603..0d008ea 100644 --- a/src/providers/sdks/koboldcpp.ts +++ b/src/providers/sdks/koboldcpp.ts @@ -15,6 +15,8 @@ export type OnOpen = (response: NodeFetchResponse) => void | Promise; export type OnUpdate = (completion: string) => void | Promise; const DEFAULT_API_URL = "https://localhost:5001"; +export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}"; +export const DEFAULT_CTX = 2048; export class Client { private apiUrl: string; @@ -47,6 +49,8 @@ export class Client { completeStream(params: KoboldInferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise { const abortController = new AbortController(); + console.log("Url", this.apiUrl); + return new Promise((resolve, reject) => { signal?.addEventListener("abort", (event) => { abortController.abort(event); diff --git a/src/utils.ts b/src/utils.ts index 4a68b41..e9921f0 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -2,6 +2,7 @@ import fs from "node:fs"; import path from "node:path"; import Glob from "fast-glob"; +import llamaTokenizer from "llama-tokenizer-js"; import * as vscode from "vscode"; import { ExtensionState } from "./extension"; @@ -231,3 +232,12 @@ export const randomString = (): string => { } return result; }; + +export function llamaMaxTokens(prompt: string, ctx: number) { + const n = llamaTokenizer.encode(prompt).length; + return ctx - n; +} + +export function formatPrompt(prompt: string, template: string, systemMessage: string) { + return template.replace("{system}", systemMessage).replace("{prompt}", prompt); +}
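
Usage sketch (illustrative only, not taken from the patch): how the pieces added above fit together for the koboldcpp path. It assumes a koboldcpp server reachable at http://localhost:5001, that the sketch runs from a module where src/utils and src/providers/sdks/koboldcpp are importable, and that the prompt text is made up for the example; the real providers resolve the endpoint, template, and temperature from the VS Code configuration instead of hard-coding them.

import { Client, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./providers/sdks/koboldcpp";
import { formatPrompt, llamaMaxTokens } from "./utils";

async function demo(): Promise<void> {
  // koboldcpp does not require an API key, so the provider passes an empty string.
  const client = new Client("", { apiUrl: "http://localhost:5001" });

  const systemMessage = "You are a helpful coding assistant.";
  // Fill the {system} and {prompt} placeholders of the prompt template.
  const prompt = formatPrompt("Write a function that reverses a string.", DEFAULT_TEMPLATE, systemMessage);

  let text = "";
  await client.completeStream(
    {
      prompt,
      temperature: 0.3,
      // Budget the completion: context window minus the tokens already consumed by the prompt,
      // counted with llama-tokenizer-js via llamaMaxTokens().
      max_length: llamaMaxTokens(prompt, DEFAULT_CTX),
    },
    {
      // Tokens arrive one server-sent event at a time; accumulate them as they stream in.
      onUpdate: (token) => {
        text += token;
      },
    },
  );

  console.log(text);
}

demo().catch(console.error);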