diff --git a/package.json b/package.json
index 2545cfb..4754ac7 100644
--- a/package.json
+++ b/package.json
@@ -266,7 +266,9 @@
           "default": "openai",
           "enum": [
             "anthropic",
-            "openai"
+            "openai",
+            "goinfer",
+            "koboldcpp"
           ],
           "description": "Which provider should this command use?"
         }
@@ -348,6 +350,7 @@
     "cheerio": "1.0.0-rc.12",
     "fast-glob": "^3.2.12",
     "fetch": "^1.1.0",
+    "llama-tokenizer-js": "^1.1.3",
     "node-fetch": "^3.3.1"
   }
 }
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index a345a01..dee74af 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1,4 +1,4 @@
-lockfileVersion: '6.1'
+lockfileVersion: '6.0'
 
 settings:
   autoInstallPeers: true
@@ -23,6 +23,9 @@ dependencies:
   fetch:
     specifier: ^1.1.0
     version: 1.1.0
+  llama-tokenizer-js:
+    specifier: ^1.1.3
+    version: 1.1.3
   node-fetch:
     specifier: ^3.3.1
     version: 3.3.1
@@ -1241,6 +1244,7 @@ packages:
 
   /bl@4.1.0:
     resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==}
+    requiresBuild: true
     dependencies:
       buffer: 5.7.1
      inherits: 2.0.4
@@ -1302,6 +1306,7 @@ packages:
 
   /buffer@5.7.1:
     resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==}
+    requiresBuild: true
     dependencies:
       base64-js: 1.5.1
       ieee754: 1.2.1
@@ -1504,6 +1509,7 @@ packages:
 
   /chownr@1.1.4:
     resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -1730,6 +1736,7 @@ packages:
   /deep-extend@0.6.0:
     resolution: {integrity: sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==}
     engines: {node: '>=4.0.0'}
+    requiresBuild: true
     dev: true
 
   /deep-is@0.1.4:
@@ -1788,6 +1795,7 @@ packages:
   /detect-libc@2.0.1:
     resolution: {integrity: sha512-463v3ZeIrcWtdgIg6vI6XUncguvr2TnGl4SzDXinkt9mSLpBJKXT3mW6xT3VQdDN11+WVs29pgvivTc4Lp8v+w==}
     engines: {node: '>=8'}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -1876,6 +1884,7 @@ packages:
 
   /end-of-stream@1.4.4:
     resolution: {integrity: sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==}
+    requiresBuild: true
     dependencies:
       once: 1.4.0
     dev: true
@@ -2493,6 +2502,7 @@ packages:
   /expand-template@2.0.3:
     resolution: {integrity: sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==}
     engines: {node: '>=6'}
+    requiresBuild: true
     dev: true
     optional: true
@@ -2634,6 +2644,7 @@ packages:
 
   /fs-constants@1.0.0:
     resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -2732,6 +2743,7 @@ packages:
 
   /github-from-package@0.0.0:
     resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -2975,6 +2987,7 @@ packages:
 
   /ieee754@1.2.1:
     resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -3476,6 +3489,10 @@ packages:
       wrap-ansi: 7.0.0
     dev: true
 
+  /llama-tokenizer-js@1.1.3:
+    resolution: {integrity: sha512-+BUgsLCXVQJkjiD/t7PdESLn+yXJIRX/BJfwzVVYfKZ9aN3gsP9xoadBZxKnCxGz2Slby+S7x41gUr2TKNaS4Q==}
+    dev: false
+
   /loader-runner@4.3.0:
     resolution: {integrity: sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==}
     engines: {node: '>=6.11.5'}
@@ -3718,6 +3735,7 @@ packages:
 
   /mkdirp-classic@0.5.3:
     resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -3789,6 +3807,7 @@ packages:
 
   /napi-build-utils@1.0.2:
     resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -3807,6 +3826,7 @@ packages:
   /node-abi@3.40.0:
     resolution: {integrity: sha512-zNy02qivjjRosswoYmPi8hIKJRr8MpQyeKT6qlcq/OnOgA3Rhoae+IYOqsM9V5+JnHWmxKnWOT2GxvtqdtOCXA==}
     engines: {node: '>=10'}
+    requiresBuild: true
     dependencies:
       semver: 7.5.1
     dev: true
@@ -3814,6 +3834,7 @@ packages:
 
   /node-addon-api@4.3.0:
     resolution: {integrity: sha512-73sE9+3UaLYYFmDsFZnqCInzPyh3MqIwZO9cw58yIqAZhONrrabrYyYe3TuIqtIiOuTXVhsGau8hcrhhwSsDIQ==}
+    requiresBuild: true
     dev: true
     optional: true
 
@@ -4200,6 +4221,7 @@ packages:
     resolution: {integrity: sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==}
     engines: {node: '>=10'}
     hasBin: true
+    requiresBuild: true
     dependencies:
       detect-libc: 2.0.1
       expand-template: 2.0.3
@@ -4254,6 +4276,7 @@ packages:
 
   /pump@3.0.0:
     resolution: {integrity: sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==}
+    requiresBuild: true
     dependencies:
       end-of-stream: 1.4.4
       once: 1.4.0
@@ -4375,6 +4398,7 @@ packages:
   /readable-stream@3.6.2:
     resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==}
     engines: {node: '>= 6'}
+    requiresBuild: true
     dependencies:
       inherits: 2.0.4
       string_decoder: 1.3.0
@@ -4641,11 +4665,13 @@ packages:
 
   /simple-concat@1.0.1:
     resolution: {integrity: sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==}
+    requiresBuild: true
     dev: true
     optional: true
 
   /simple-get@4.0.1:
     resolution: {integrity: sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==}
+    requiresBuild: true
     dependencies:
       decompress-response: 6.0.0
       once: 1.4.0
@@ -4800,6 +4826,7 @@ packages:
 
   /string_decoder@1.3.0:
     resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==}
+    requiresBuild: true
     dependencies:
       safe-buffer: 5.2.1
     dev: true
@@ -4844,6 +4871,7 @@ packages:
   /strip-json-comments@2.0.1:
     resolution: {integrity: sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==}
     engines: {node: '>=0.10.0'}
+    requiresBuild: true
     dev: true
 
   /strip-json-comments@3.1.1:
@@ -4907,6 +4935,7 @@ packages:
 
   /tar-fs@2.1.1:
     resolution: {integrity: sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==}
+    requiresBuild: true
     dependencies:
       chownr: 1.1.4
       mkdirp-classic: 0.5.3
@@ -4918,6 +4947,7 @@ packages:
   /tar-stream@2.2.0:
     resolution: {integrity: sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==}
     engines: {node: '>=6'}
+    requiresBuild: true
     dependencies:
       bl: 4.1.0
       end-of-stream: 1.4.4
diff --git a/src/providers/goinfer.ts b/src/providers/goinfer.ts
new file mode 100644
index 0000000..8e051bc
--- /dev/null
+++ b/src/providers/goinfer.ts
@@ -0,0 +1,155 @@
+/* eslint-disable unused-imports/no-unused-vars */
+import * as vscode from "vscode";
+
+import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
+import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer";
+import { type Command } from "../templates/render";
+import { handleResponseCallbackType } from "../templates/runner";
+import { displayWarning, getConfig, getSecret, getSelectionInfo, llamaMaxTokens, setSecret, unsetConfig } from "../utils";
+
+let lastMessage: string | undefined;
+let lastTemplate: Command | undefined;
+let lastSystemMessage: string | undefined;
+
+export class GoinferProvider implements Provider {
+  viewProvider: PostableViewProvider | undefined;
+  instance: Client | undefined;
+  conversationTextHistory: string | undefined;
+  _abort: AbortController = new AbortController();
+
+  async create(provider: PostableViewProvider, template: Command) {
+    const apiKey = await getSecret("openai.apiKey", "");
+
+    // If the user still uses the now deprecated openai.apiKey config, move it to the secrets store
+    // and unset the config.
+    if (getConfig("openai.apiKey")) {
+      setSecret("openai.apiKey", getConfig("openai.apiKey"));
+      unsetConfig("openai.apiKey");
+    }
+
+    const { apiBaseUrl } = {
+      apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined,
+    };
+
+    this.viewProvider = provider;
+    this.conversationTextHistory = undefined;
+    this.instance = new Client(apiKey, { apiUrl: apiBaseUrl });
+  }
+
+  destroy() {
+    this.instance = undefined;
+    this.conversationTextHistory = undefined;
+  }
+
+  abort() {
+    this._abort.abort();
+    this._abort = new AbortController();
+  }
+
+  async send(message: string, systemMessage?: string, template?: Command): Promise<void> {
+    let isFollowup = false;
+
+    lastMessage = message;
+
+    if (template) {
+      lastTemplate = template;
+    }
+
+    if (!template && !lastTemplate) {
+      return;
+    }
+
+    if (!template) {
+      template = lastTemplate!;
+      isFollowup = true;
+    }
+
+    if (systemMessage) {
+      lastSystemMessage = systemMessage;
+    }
+    if (!systemMessage && !lastSystemMessage) {
+      return;
+    }
+    if (!systemMessage) {
+      systemMessage = lastSystemMessage!;
+    }
+
+    let prompt;
+    if (!isFollowup) {
+      this.viewProvider?.postMessage({ type: "newChat" });
+      // The first message should have the system message prepended
+      prompt = `${message}`;
+    } else {
+      // followups should have the conversation history prepended
+      prompt = `${this.conversationTextHistory ?? ""}${message}`;
+    }
+
+    const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
+    const samplingParameters: InferParams = {
+      prompt,
+      template: modelTemplate.replace("{system}", systemMessage),
+      ...template?.completionParams,
+      temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
+      model: {
+        name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2",
+        ctx: template?.completionParams?.ctx ?? DEFAULT_CTX,
+      },
+      n_predict: llamaMaxTokens(prompt, DEFAULT_CTX),
+    };
+
+    try {
+      this.viewProvider?.postMessage({ type: "requestMessage", value: message });
+
+      const editor = vscode.window.activeTextEditor!;
+      const selection = getSelectionInfo(editor);
+      let partialText = "";
+
+      const goinferResponse: InferResult = await this.instance!.completeStream(samplingParameters, {
+        onOpen: (response) => {
+          console.log("Opened stream, HTTP status code", response.status);
+        },
+        onUpdate: (partialResponse: StreamedMessage) => {
+          partialText += partialResponse.content;
+          // console.log("P", partialText);
+          const msg = this.toProviderResponse(partialText);
+          // console.log("MSG:", msg.text);
+          this.viewProvider?.postMessage({
+            type: "partialResponse",
+            value: msg,
+          });
+        },
+        signal: this._abort.signal,
+      });
+
+      // Reformat the API response into a ProviderResponse
+      const response = this.toProviderResponse(goinferResponse.text);
+
+      // Append the last response to the conversation history
+      this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`;
+      this.viewProvider?.postMessage({ type: "responseFinished", value: response });
+
+      if (!isFollowup) {
+        handleResponseCallbackType(template, editor, selection, response.text);
+      }
+    } catch (error) {
+      displayWarning(String(error));
+    }
+  }
+
+  async repeatLast() {
+    if (!lastMessage || !lastSystemMessage || !lastTemplate) {
+      return;
+    }
+
+    await this.send(lastMessage, lastSystemMessage, lastTemplate);
+  }
+
+  toProviderResponse(text: string) {
+    return {
+      text,
+      parentMessageId: "",
+      converastionId: "",
+      id: "",
+    };
+  }
+}
diff --git a/src/providers/index.ts b/src/providers/index.ts
index 088865a..2295931 100644
--- a/src/providers/index.ts
+++ b/src/providers/index.ts
@@ -1,6 +1,8 @@
 import type * as vscode from "vscode";
 
 import { AnthropicProvider } from "./anthropic";
+import { GoinferProvider } from "./goinfer";
+import { KoboldcppProvider } from "./koboldcpp";
 import { OpenAIProvider } from "./openai";
 import { AIProvider, type Command } from "../templates/render";
 
@@ -41,4 +43,6 @@ export interface Provider {
 export const providers = {
   [AIProvider.OpenAI]: OpenAIProvider,
   [AIProvider.Anthropic]: AnthropicProvider,
+  [AIProvider.Goinfer]: GoinferProvider,
+  [AIProvider.KoboldCpp]: KoboldcppProvider,
 };
diff --git a/src/providers/koboldcpp.ts b/src/providers/koboldcpp.ts
new file mode 100644
index 0000000..c8a265b
--- /dev/null
+++ b/src/providers/koboldcpp.ts
@@ -0,0 +1,142 @@
+/* eslint-disable unused-imports/no-unused-vars */
+import * as vscode from "vscode";
+
+import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
+import { Client, DEFAULT_TEMPLATE, DEFAULT_CTX, type KoboldInferParams } from "./sdks/koboldcpp";
+import { type Command } from "../templates/render";
+import { handleResponseCallbackType } from "../templates/runner";
+import { displayWarning, formatPrompt, getConfig, getSelectionInfo, llamaMaxTokens } from "../utils";
+
+let lastMessage: string | undefined;
+let lastTemplate: Command | undefined;
+let lastSystemMessage: string | undefined;
+
+export class KoboldcppProvider implements Provider {
+  viewProvider: PostableViewProvider | undefined;
+  instance: Client | undefined;
+  conversationTextHistory: string | undefined;
+  _abort: AbortController = new AbortController();
+
+  async create(provider: PostableViewProvider, template: Command) {
+    const { apiBaseUrl } = {
+      apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined,
+    };
+
+    this.viewProvider = provider;
+    this.conversationTextHistory = undefined;
+    this.instance = new Client("", { apiUrl: apiBaseUrl });
+  }
+
+  destroy() {
+    this.instance = undefined;
+    this.conversationTextHistory = undefined;
+  }
+
+  abort() {
+    this._abort.abort();
+    this._abort = new AbortController();
+  }
+
+  async send(message: string, systemMessage?: string, template?: Command): Promise<void> {
+    let isFollowup = false;
+
+    lastMessage = message;
+
+    if (template) {
+      lastTemplate = template;
+    }
+
+    if (!template && !lastTemplate) {
+      return;
+    }
+
+    if (!template) {
+      template = lastTemplate!;
+      isFollowup = true;
+    }
+
+    if (systemMessage) {
+      lastSystemMessage = systemMessage;
+    }
+    if (!systemMessage && !lastSystemMessage) {
+      return;
+    }
+    if (!systemMessage) {
+      systemMessage = lastSystemMessage!;
+    }
+
+    let prompt;
+    if (!isFollowup) {
+      this.viewProvider?.postMessage({ type: "newChat" });
+      // The first message should have the system message prepended
+      prompt = `${message}`;
+    } else {
+      // followups should have the conversation history prepended
+      prompt = `${this.conversationTextHistory ?? ""}${message}`;
+    }
+
+    const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
+    const samplingParameters: KoboldInferParams = {
+      prompt: formatPrompt(prompt, modelTemplate, systemMessage),
+      ...template?.completionParams,
+      temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
+      max_length: llamaMaxTokens(prompt, DEFAULT_CTX),
+    };
+    console.log("Params", samplingParameters);
+
+    try {
+      this.viewProvider?.postMessage({ type: "requestMessage", value: message });
+
+      const editor = vscode.window.activeTextEditor!;
+      const selection = getSelectionInfo(editor);
+      let partialText = "";
+
+      await this.instance!.completeStream(samplingParameters, {
+        onOpen: (response) => {
+          console.log("Opened stream, HTTP status code", response.status);
+        },
+        onUpdate: (partialResponse: string) => {
+          partialText += partialResponse;
+          // console.log("P", partialText);
+          const msg = this.toProviderResponse(partialText);
+          // console.log("MSG:", msg.text);
+          this.viewProvider?.postMessage({
+            type: "partialResponse",
+            value: msg,
+          });
+        },
+        signal: this._abort.signal,
+      });
+
+      // Reformat the API response into a ProviderResponse
+      const response = this.toProviderResponse(partialText);
+
+      // Append the last response to the conversation history
+      this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`;
+      this.viewProvider?.postMessage({ type: "responseFinished", value: response });
+
+      if (!isFollowup) {
+        handleResponseCallbackType(template, editor, selection, response.text);
+      }
+    } catch (error) {
+      displayWarning(String(error));
+    }
+  }
+
+  async repeatLast() {
+    if (!lastMessage || !lastSystemMessage || !lastTemplate) {
+      return;
+    }
+
+    await this.send(lastMessage, lastSystemMessage, lastTemplate);
+  }
+
+  toProviderResponse(text: string) {
+    return {
+      text,
+      parentMessageId: "",
+      converastionId: "",
+      id: "",
+    };
+  }
+}
diff --git a/src/providers/sdks/goinfer.ts b/src/providers/sdks/goinfer.ts
new file mode 100644
index 0000000..cd691e5
--- /dev/null
+++ b/src/providers/sdks/goinfer.ts
@@ -0,0 +1,158 @@
+import { fetchEventSource } from "@ai-zen/node-fetch-event-source";
+import fetch from "node-fetch";
+import type { Response as NodeFetchResponse } from "node-fetch";
+
+interface ModelConf {
+  name: string;
+  ctx?: number;
+  freq_rope_base?: number;
+  freq_rope_scale?: number;
+}
+
+export interface InferParams {
+  prompt: string;
+  template?: string;
+  stream?: boolean;
+  threads?: number;
+  model?: ModelConf;
+  n_predict?: number;
+  top_k?: number;
+  top_p?: number;
+  temperature?: number;
+  frequency_penalty?: number;
+  presence_penalty?: number;
+  repeat_penalty?: number;
+  tfs_z?: number;
+  stop?: string[];
+}
+
+export interface InferResult {
+  text: string;
+  thinkingTime: number;
+  thinkingTimeFormat: string;
+  inferenceTime: number;
+  emitTime: number;
+  emitTimeFormat: string;
+  totalTime: number;
+  totalTimeFormat: string;
+  tokensPerSecond: number;
+  totalTokens: number;
+}
+
+export enum StreamedMsgType {
+  TokenMsgType = "token",
+  SystemMsgType = "system",
+  ErrorMsgType = "error",
+}
+
+export interface StreamedMessage {
+  content: string;
+  num: number;
+  msg_type: StreamedMsgType;
+  data?: { [key: string]: any };
+}
+
+export type OnOpen = (response: NodeFetchResponse) => void | Promise<void>;
+export type OnUpdate = (completion: StreamedMessage) => void | Promise<void>;
+
+const DEFAULT_API_URL = "https://localhost:5143";
+export const DEFAULT_CTX = 2048;
+export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}";
+
+export class Client {
+  private apiUrl: string;
+
+  constructor(private apiKey: string, options?: { apiUrl?: string }) {
DEFAULT_API_URL; + } + + async complete(params: InferParams, options?: { signal?: AbortSignal }): Promise { + const response = await fetch(`${this.apiUrl}/completion`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify({ ...params, stream: false }), + signal: options?.signal, + }); + + if (!response.ok) { + const error = new Error(`Sampling error: ${response.status} ${response.statusText}`); + console.error(error); + throw error; + } + + const completion = (await response.json()) as InferResult; + return completion; + } + + completeStream(params: InferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise { + const abortController = new AbortController(); + + return new Promise((resolve, reject) => { + signal?.addEventListener("abort", (event) => { + abortController.abort(event); + reject(new Error("Caller aborted completeStream")); + }); + + const body = JSON.stringify({ ...params, stream: true }); + fetchEventSource(`${this.apiUrl}/completion`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body, + signal: abortController.signal, + onopen: async (response) => { + if (!response.ok) { + abortController.abort(); + return reject(new Error(`Failed to open stream, HTTP status code ${response.status}: ${response.statusText}`)); + } + if (onOpen) { + await Promise.resolve(onOpen(response)); + } + }, + onmessage: (ev) => { + const completion = JSON.parse(ev.data) as StreamedMessage; + /* if (onUpdate) { + Promise.resolve(onUpdate(completion)).catch((error) => { + abortController.abort(); + reject(error); + }); + } */ + + // console.log(completion); + + switch (completion.msg_type) { + case "system": + // console.log("SYSTEM MSG") + // console.log(completion) + switch (completion.content) { + case "result": + return resolve(completion.data as InferResult); + case "error": + abortController.abort(); + return reject(new Error("inference error")); + default: + break; + } + break; + case "token": + if (onUpdate) { + onUpdate(completion); + } + } + }, + onerror: (error) => { + console.error("Inference error:", error); + abortController.abort(); + return reject(error); + }, + }); + }); + } +} diff --git a/src/providers/sdks/koboldcpp.ts b/src/providers/sdks/koboldcpp.ts new file mode 100644 index 0000000..0d008ea --- /dev/null +++ b/src/providers/sdks/koboldcpp.ts @@ -0,0 +1,97 @@ +import { fetchEventSource } from "@ai-zen/node-fetch-event-source"; +import type { Response as NodeFetchResponse } from "node-fetch"; + +export interface KoboldInferParams { + prompt: string; + max_length?: number; + top_k?: number; + top_p?: number; + temperature?: number; + tfs?: number; + stop_sequence?: string[]; +} + +export type OnOpen = (response: NodeFetchResponse) => void | Promise; +export type OnUpdate = (completion: string) => void | Promise; + +const DEFAULT_API_URL = "https://localhost:5001"; +export const DEFAULT_TEMPLATE = "{system}\n\n{prompt}"; +export const DEFAULT_CTX = 2048; + +export class Client { + private apiUrl: string; + + constructor(private apiKey: string, options?: { apiUrl?: string }) { + this.apiUrl = options?.apiUrl ?? 
DEFAULT_API_URL; + } + + /* async complete(params: KoboldInferParams, options?: { signal?: AbortSignal }): Promise { + const response = await fetch(`${this.apiUrl}/api/extra/generate/stream`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ ...params, stream: false }), + signal: options?.signal, + }); + + if (!response.ok) { + const error = new Error(`Sampling error: ${response.status} ${response.statusText}`); + console.error(error); + throw error; + } + + const completion = (await response.json()); + return "" + } */ + + completeStream(params: KoboldInferParams, { onOpen, onUpdate, signal }: { onOpen?: OnOpen; onUpdate?: OnUpdate; signal?: AbortSignal }): Promise { + const abortController = new AbortController(); + + console.log("Url", this.apiUrl); + + return new Promise((resolve, reject) => { + signal?.addEventListener("abort", (event) => { + abortController.abort(event); + reject(new Error("Caller aborted completeStream")); + }); + + const body = JSON.stringify({ ...params }); + fetchEventSource(`${this.apiUrl}/api/extra/generate/stream`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body, + signal: abortController.signal, + onopen: async (response) => { + if (!response.ok) { + abortController.abort(); + return reject(new Error(`Failed to open stream, HTTP status code ${response.status}: ${response.statusText}`)); + } + if (onOpen) { + await Promise.resolve(onOpen(response)); + } + }, + onmessage: (ev) => { + const completion = JSON.parse(ev.data); + if (onUpdate) { + onUpdate(completion.token); + } + }, + onclose: () => { + // console.log("Close stream") + return resolve(); + }, + onerror: (error) => { + console.error("Inference error:", error); + abortController.abort(); + return reject(error); + }, + }); + }); + } +} diff --git a/src/templates/render.ts b/src/templates/render.ts index 63de198..dfb44ad 100644 --- a/src/templates/render.ts +++ b/src/templates/render.ts @@ -90,6 +90,8 @@ export enum CallbackType { export enum AIProvider { OpenAI = "openai", Anthropic = "anthropic", + Goinfer = "goinfer", + KoboldCpp = "koboldcpp", } export interface Command { @@ -142,6 +144,15 @@ const defaultAnthropicCompletionParams = () => ({ temperature: 0.3, }); +// https://synw.github.io/goinfer/llama_api/inference +const defaultGoinferCompletionParams = () => ({ + temperature: 0.3, +}); + +const defaultKoboldcppCompletionParams = () => ({ + temperature: 0.3, +}); + export const defaultCommands: Command[] = [ // Completion { @@ -452,6 +463,21 @@ export const buildCommandTemplate = (commandName: string): Command => { temperature, }; break; + case "goinfer": + completionParams = { + ...defaultGoinferCompletionParams(), + ...template.completionParams, + model, + temperature, + }; + break; + case "koboldcpp": + completionParams = { + ...defaultKoboldcppCompletionParams(), + ...template.completionParams, + temperature, + }; + break; default: completionParams = { ...template.completionParams }; break; diff --git a/src/utils.ts b/src/utils.ts index 4a68b41..e9921f0 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -2,6 +2,7 @@ import fs from "node:fs"; import path from "node:path"; import Glob from "fast-glob"; +import llamaTokenizer from "llama-tokenizer-js"; import * as vscode from "vscode"; import { ExtensionState } from "./extension"; @@ -231,3 +232,12 @@ export const randomString = (): string => { } return result; }; 
+
+export function llamaMaxTokens(prompt: string, ctx: number) {
+  const n = llamaTokenizer.encode(prompt).length;
+  return ctx - n;
+}
+
+export function formatPrompt(prompt: string, template: string, systemMessage: string) {
+  return template.replace("{system}", systemMessage).replace("{prompt}", prompt);
+}
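
For reviewers, a minimal sketch of how the new koboldcpp SDK client and helpers introduced in this patch fit together (illustrative only, not part of the patch; the import paths and the local server URL below are assumptions):

import { Client, DEFAULT_CTX, DEFAULT_TEMPLATE, type KoboldInferParams } from "./providers/sdks/koboldcpp";
import { formatPrompt, llamaMaxTokens } from "./utils";

async function demo(): Promise<void> {
  // koboldcpp needs no API key; the base URL here is an assumed local server.
  const client = new Client("", { apiUrl: "http://localhost:5001" });

  // Fill the {system}/{prompt} template, then hand the remaining context budget to generation.
  const prompt = formatPrompt("Write a haiku about TypeScript.", DEFAULT_TEMPLATE, "You are a helpful assistant.");
  const params: KoboldInferParams = {
    prompt,
    temperature: 0.3,
    max_length: llamaMaxTokens(prompt, DEFAULT_CTX),
  };

  // Accumulate streamed tokens as they arrive.
  let text = "";
  await client.completeStream(params, {
    onUpdate: (token) => {
      text += token;
    },
  });
  console.log(text);
}

void demo();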