Merge pull request #22 from synw/main
Add new providers: Koboldcpp and Goinfer
nvms authored Aug 30, 2023
2 parents 9632fb1 + 0cd66c8 commit 24d6232
Showing 9 changed files with 627 additions and 2 deletions.
5 changes: 4 additions & 1 deletion package.json
@@ -266,7 +266,9 @@
"default": "openai",
"enum": [
"anthropic",
"openai"
"openai",
"goinfer",
"koboldcpp"
],
"description": "Which provider should this command use?"
}
@@ -348,6 +350,7 @@
"cheerio": "1.0.0-rc.12",
"fast-glob": "^3.2.12",
"fetch": "^1.1.0",
"llama-tokenizer-js": "^1.1.3",
"node-fetch": "^3.3.1"
}
}
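
The new llama-tokenizer-js dependency presumably backs the llamaMaxTokens helper that goinfer.ts imports from ../utils to budget completion length. A plausible sketch of such a helper, assuming the library's default-export encode API; the real implementation lives in src/utils and is not part of this diff:

// Hypothetical sketch only; not the repo's actual llamaMaxTokens.
import llamaTokenizer from "llama-tokenizer-js";

export function llamaMaxTokens(prompt: string, ctx: number): number {
  // Tokens left for the completion: context window minus the prompt's token count.
  return Math.max(0, ctx - llamaTokenizer.encode(prompt).length);
}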
32 changes: 31 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.

155 changes: 155 additions & 0 deletions src/providers/goinfer.ts
@@ -0,0 +1,155 @@
/* eslint-disable unused-imports/no-unused-vars */
import * as vscode from "vscode";

import { type PostableViewProvider, type ProviderResponse, type Provider } from ".";
import { Client, type InferParams, type InferResult, type StreamedMessage, DEFAULT_CTX, DEFAULT_TEMPLATE } from "./sdks/goinfer";
import { type Command } from "../templates/render";
import { handleResponseCallbackType } from "../templates/runner";
import { displayWarning, getConfig, getSecret, getSelectionInfo, llamaMaxTokens, setSecret, unsetConfig } from "../utils";

let lastMessage: string | undefined;
let lastTemplate: Command | undefined;
let lastSystemMessage: string | undefined;

export class GoinferProvider implements Provider {
  viewProvider: PostableViewProvider | undefined;
  instance: Client | undefined;
  conversationTextHistory: string | undefined;
  _abort: AbortController = new AbortController();

  async create(provider: PostableViewProvider, template: Command) {
    const apiKey = await getSecret<string>("openai.apiKey", "");

    // If the user still uses the now-deprecated openai.apiKey config, move it to
    // the secrets store and unset the config.
    if (getConfig<string>("openai.apiKey")) {
      setSecret("openai.apiKey", getConfig<string>("openai.apiKey"));
      unsetConfig("openai.apiKey");
    }

    const { apiBaseUrl } = {
      apiBaseUrl: getConfig("openai.apiBaseUrl") as string | undefined,
    };

    this.viewProvider = provider;
    this.conversationTextHistory = undefined;
    this.instance = new Client(apiKey, { apiUrl: apiBaseUrl });
  }

  destroy() {
    this.instance = undefined;
    this.conversationTextHistory = undefined;
  }

  abort() {
    this._abort.abort();
    this._abort = new AbortController();
  }

  async send(message: string, systemMessage?: string, template?: Command): Promise<void | ProviderResponse> {
    let isFollowup = false;

    lastMessage = message;

    if (template) {
      lastTemplate = template;
    }

    if (!template && !lastTemplate) {
      return;
    }

    if (!template) {
      template = lastTemplate!;
      isFollowup = true;
    }

    if (systemMessage) {
      lastSystemMessage = systemMessage;
    }
    if (!systemMessage && !lastSystemMessage) {
      return;
    }
    if (!systemMessage) {
      systemMessage = lastSystemMessage!;
    }

    let prompt;
    if (!isFollowup) {
      this.viewProvider?.postMessage({ type: "newChat" });
      // First message: the system message is injected via the {system} placeholder
      // in the model template below, so the prompt is just the message itself.
      prompt = `${message}`;
    } else {
      // Follow-ups prepend the accumulated conversation history.
      prompt = `${this.conversationTextHistory ?? ""}${message}`;
    }

    const modelTemplate = template?.completionParams?.template ?? DEFAULT_TEMPLATE;
    const ctx = template?.completionParams?.ctx ?? DEFAULT_CTX;
    const samplingParameters: InferParams = {
      // Spread first so the explicit fields below win; spreading after `template:`
      // would clobber the {system} substitution with the raw template string.
      ...template?.completionParams,
      prompt,
      template: modelTemplate.replace("{system}", systemMessage),
      temperature: template?.completionParams?.temperature ?? (getConfig("openai.temperature") as number),
      model: {
        name: template?.completionParams?.model ?? (getConfig("openai.model") as string) ?? "llama2",
        ctx,
      },
      // Budget the completion to the tokens left in the context window.
      n_predict: llamaMaxTokens(prompt, ctx),
    };

    try {
      this.viewProvider?.postMessage({ type: "requestMessage", value: message });

      const editor = vscode.window.activeTextEditor!;
      const selection = getSelectionInfo(editor);
      let partialText = "";

      const goinferResponse: InferResult = await this.instance!.completeStream(samplingParameters, {
        onOpen: (response) => {
          console.log("Opened stream, HTTP status code", response.status);
        },
        onUpdate: (partialResponse: StreamedMessage) => {
          // Accumulate streamed chunks and forward the running text to the view.
          partialText += partialResponse.content;
          const msg = this.toProviderResponse(partialText);
          this.viewProvider?.postMessage({
            type: "partialResponse",
            value: msg,
          });
        },
        signal: this._abort.signal,
      });

      // Reformat the API response into a ProviderResponse
      const response = this.toProviderResponse(goinferResponse.text);

      // Append the last response to the conversation history
      this.conversationTextHistory = `${this.conversationTextHistory ?? ""}${prompt} ${response.text}`;
      this.viewProvider?.postMessage({ type: "responseFinished", value: response });

      if (!isFollowup) {
        handleResponseCallbackType(template, editor, selection, response.text);
      }
    } catch (error) {
      displayWarning(String(error));
    }
  }

  async repeatLast() {
    if (!lastMessage || !lastSystemMessage || !lastTemplate) {
      return;
    }

    await this.send(lastMessage, lastSystemMessage, lastTemplate);
  }

  toProviderResponse(text: string) {
    return {
      text,
      parentMessageId: "",
      conversationId: "",
      id: "",
    };
  }
}
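
For orientation, a minimal driver sketch showing how the class above is meant to be used. This is hypothetical and not part of the diff; viewProvider and template stand in for the extension's real view provider and a rendered command.

// Hypothetical usage sketch, not part of the diff.
const goinfer = new GoinferProvider();
await goinfer.create(viewProvider, template); // wires up the view and the Goinfer client

// First message: starts a new chat; the system message is injected via the template.
await goinfer.send("Explain the selected function.", "You are a careful code assistant.", template);

// Follow-up: template and system message are omitted, so the provider reuses the
// last ones and prepends conversationTextHistory to the new prompt.
await goinfer.send("Now add error handling.");

goinfer.abort(); // cancels an in-flight stream via the AbortController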
4 changes: 4 additions & 0 deletions src/providers/index.ts
@@ -1,6 +1,8 @@
 import type * as vscode from "vscode";

 import { AnthropicProvider } from "./anthropic";
+import { GoinferProvider } from "./goinfer";
+import { KoboldcppProvider } from "./koboldcpp";
 import { OpenAIProvider } from "./openai";
 import { AIProvider, type Command } from "../templates/render";

@@ -41,4 +43,6 @@ export interface Provider {
 export const providers = {
   [AIProvider.OpenAI]: OpenAIProvider,
   [AIProvider.Anthropic]: AnthropicProvider,
+  [AIProvider.Goinfer]: GoinferProvider,
+  [AIProvider.KoboldCpp]: KoboldcppProvider,
 };
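
With the map extended, a command's configured provider can be resolved to a class and instantiated. A minimal sketch of such a call site (hypothetical; the real command runner is not shown in this diff):

// Hypothetical resolution sketch: look up the provider class named by a
// command's configuration and instantiate it.
const ProviderClass = providers[AIProvider.Goinfer];
const provider = new ProviderClass();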
(Remaining changed files not shown.)
