Skip to content

Commit

Permalink
✨ switch between autocomplete models
Browse files Browse the repository at this point in the history
  • Loading branch information
sestinj committed Jun 11, 2024
1 parent d6a3b48 commit 894e1a3
Show file tree
Hide file tree
Showing 15 changed files with 200 additions and 55 deletions.
50 changes: 30 additions & 20 deletions core/config/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,26 +293,36 @@ async function intermediateToFinalConfig(
}

// Tab autocomplete model
let autocompleteLlm: BaseLLM | undefined = undefined;
let tabAutocompleteModels: BaseLLM[] = [];
if (config.tabAutocompleteModel) {
if (isModelDescription(config.tabAutocompleteModel)) {
autocompleteLlm = await llmFromDescription(
config.tabAutocompleteModel,
ide.readFile.bind(ide),
uniqueId,
ideSettings,
writeLog,
config.completionOptions,
config.systemMessage,
);

if (autocompleteLlm?.providerName === "free-trial") {
const ghAuthToken = await ide.getGitHubAuthToken();
(autocompleteLlm as FreeTrial).setupGhAuthToken(ghAuthToken);
}
} else {
autocompleteLlm = new CustomLLMClass(config.tabAutocompleteModel);
}
tabAutocompleteModels = (
await Promise.all(
(Array.isArray(config.tabAutocompleteModel)
? config.tabAutocompleteModel
: [config.tabAutocompleteModel]
).map(async (desc) => {
if (isModelDescription(desc)) {
const llm = await llmFromDescription(
desc,
ide.readFile.bind(ide),
uniqueId,
ideSettings,
writeLog,
config.completionOptions,
config.systemMessage,
);

if (llm?.providerName === "free-trial") {
const ghAuthToken = await ide.getGitHubAuthToken();
(llm as FreeTrial).setupGhAuthToken(ghAuthToken);
}
return llm;
} else {
return new CustomLLMClass(desc);
}
}),
)
).filter((x) => x !== undefined) as BaseLLM[];
}

// Context providers
Expand Down Expand Up @@ -381,7 +391,7 @@ async function intermediateToFinalConfig(
contextProviders,
models,
embeddingsProvider: config.embeddingsProvider as any,
tabAutocompleteModel: autocompleteLlm,
tabAutocompleteModels,
reranker: config.reranker as any,
};
}
Expand Down
34 changes: 23 additions & 11 deletions core/core.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import { v4 as uuidv4 } from "uuid";
import type {
ContextItemId,
IDE,
IndexingProgressUpdate,
SiteIndexingConfig,
} from ".";
import type { FromCoreProtocol, ToCoreProtocol } from "./protocol";
import type { IMessenger, Message } from "./util/messenger";
import { v4 as uuidv4 } from "uuid";
import { CompletionProvider } from "./autocomplete/completionProvider.js";
import { ConfigHandler } from "./config/handler.js";
import {
Expand All @@ -22,11 +20,13 @@ import { indexDocs } from "./indexing/docs/index.js";
import TransformersJsEmbeddingsProvider from "./indexing/embeddings/TransformersJsEmbeddingsProvider.js";
import { CodebaseIndexer, PauseToken } from "./indexing/indexCodebase.js";
import Ollama from "./llm/llms/Ollama.js";
import type { FromCoreProtocol, ToCoreProtocol } from "./protocol";
import { GlobalContext } from "./util/GlobalContext.js";
import { logDevData } from "./util/devdata.js";
import { DevDataSqliteDb } from "./util/devdataSqlite.js";
import { fetchwithRequestOptions } from "./util/fetchWithOptions.js";
import historyManager from "./util/history.js";
import type { IMessenger, Message } from "./util/messenger";
import { editConfigJson, getConfigJsonPath } from "./util/paths.js";
import { Telemetry } from "./util/posthog.js";
import { streamDiffLines } from "./util/verticalEdit.js";
Expand All @@ -38,6 +38,7 @@ export class Core {
completionProvider: CompletionProvider;
continueServerClientPromise: Promise<ContinueServerClient>;
indexingState: IndexingProgressUpdate;
private globalContext = new GlobalContext();

private abortedMessageIds: Set<string> = new Set();

Expand Down Expand Up @@ -78,7 +79,7 @@ export class Core {

// Codebase Indexer and ContinueServerClient depend on IdeSettings
const indexingPauseToken = new PauseToken(
new GlobalContext().get("indexingPaused") === true,
this.globalContext.get("indexingPaused") === true,
);
let codebaseIndexerResolve: (_: any) => void | undefined;
this.codebaseIndexerPromise = new Promise(
Expand Down Expand Up @@ -112,7 +113,12 @@ export class Core {

const getLlm = async () => {
const config = await this.configHandler.loadConfig();
return config.tabAutocompleteModel;
const selected = this.globalContext.get("selectedTabAutocompleteModel");
return (
config.tabAutocompleteModels?.find(
(model) => model.title === selected,
) ?? config.tabAutocompleteModels?.[0]
);
};
this.completionProvider = new CompletionProvider(
this.configHandler,
Expand All @@ -134,6 +140,11 @@ export class Core {
this.selectedModelTitle = msg.data;
});

on("update/selectTabAutocompleteModel", async (msg) => {
this.globalContext.update("selectedTabAutocompleteModel", msg.data);
this.configHandler.reloadConfig();
});

// Special
on("abort", (msg) => {
this.abortedMessageIds.add(msg.messageId);
Expand Down Expand Up @@ -223,6 +234,7 @@ export class Core {
rootUrl: msg.data.rootUrl,
title: msg.data.title,
maxDepth: msg.data.maxDepth,
faviconUrl: new URL("/favicon.ico", msg.data.rootUrl).toString(),
};

for await (const _ of indexDocs(
Expand Down Expand Up @@ -498,12 +510,12 @@ export class Core {
mode === "local"
? setupLocalMode
: mode === "freeTrial"
? setupFreeTrialMode
: mode === "localAfterFreeTrial"
? setupLocalAfterFreeTrial
: mode === "apiKeys"
? setupApiKeysMode
: setupOptimizedExistingUserMode,
? setupFreeTrialMode
: mode === "localAfterFreeTrial"
? setupLocalAfterFreeTrial
: mode === "apiKeys"
? setupApiKeysMode
: setupOptimizedExistingUserMode,
);
this.configHandler.reloadConfig();
});
Expand Down
11 changes: 8 additions & 3 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,15 @@ export interface ContextSubmenuItem {
id: string;
title: string;
description: string;
iconUrl?: string;
}

export interface SiteIndexingConfig {
startUrl: string;
rootUrl: string;
title: string;
maxDepth?: number;
faviconUrl: string;
}

export interface IContextProvider {
Expand Down Expand Up @@ -782,7 +784,7 @@ export interface SerializedContinueConfig {
disableSessionTitles?: boolean;
userToken?: string;
embeddingsProvider?: EmbeddingsProviderDescription;
tabAutocompleteModel?: ModelDescription;
tabAutocompleteModel?: ModelDescription | ModelDescription[];
tabAutocompleteOptions?: Partial<TabAutocompleteOptions>;
ui?: ContinueUIConfig;
reranker?: RerankerDescription;
Expand Down Expand Up @@ -824,7 +826,10 @@ export interface Config {
/** The provider used to calculate embeddings. If left empty, Continue will use transformers.js to calculate the embeddings with all-MiniLM-L6-v2 */
embeddingsProvider?: EmbeddingsProviderDescription | EmbeddingsProvider;
/** The model that Continue will use for tab autocompletions. */
tabAutocompleteModel?: CustomLLM | ModelDescription;
tabAutocompleteModel?:
| CustomLLM
| ModelDescription
| (CustomLLM | ModelDescription)[];
/** Options for tab autocomplete */
tabAutocompleteOptions?: Partial<TabAutocompleteOptions>;
/** UI styles customization */
Expand All @@ -847,7 +852,7 @@ export interface ContinueConfig {
disableIndexing?: boolean;
userToken?: string;
embeddingsProvider: EmbeddingsProvider;
tabAutocompleteModel?: ILLM;
tabAutocompleteModels?: ILLM[];
tabAutocompleteOptions?: Partial<TabAutocompleteOptions>;
ui?: ContinueUIConfig;
reranker?: Reranker;
Expand Down
2 changes: 1 addition & 1 deletion core/protocol/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ export interface ListHistoryOptions {
}

export type ToCoreFromIdeOrWebviewProtocol = {
// New
"update/modelChange": [string, void];
"update/selectTabAutocompleteModel": [string, void];

// Special
ping: [string, string];
Expand Down
1 change: 1 addition & 0 deletions core/util/GlobalContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { getGlobalContextFilePath } from "./paths.js";

export type GlobalContextType = {
indexingPaused: boolean;
selectedTabAutocompleteModel: string;
};

/**
Expand Down
12 changes: 11 additions & 1 deletion docs/static/schemas/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,17 @@
"provider": "ollama",
"model": "deepseek-coder:1.3b-base"
},
"$ref": "#/definitions/ModelDescription"
"oneOf": [
{
"$ref": "#/definitions/ModelDescription"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/ModelDescription"
}
}
]
},
"tabAutocompleteOptions": {
"title": "TabAutocompleteOptions",
Expand Down
12 changes: 11 additions & 1 deletion extensions/intellij/src/main/resources/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,17 @@
"provider": "ollama",
"model": "deepseek-coder:1.3b-base"
},
"$ref": "#/definitions/ModelDescription"
"oneOf": [
{
"$ref": "#/definitions/ModelDescription"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/ModelDescription"
}
}
]
},
"tabAutocompleteOptions": {
"title": "TabAutocompleteOptions",
Expand Down
12 changes: 11 additions & 1 deletion extensions/vscode/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,17 @@
"provider": "ollama",
"model": "deepseek-coder:1.3b-base"
},
"$ref": "#/definitions/ModelDescription"
"oneOf": [
{
"$ref": "#/definitions/ModelDescription"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/ModelDescription"
}
}
]
},
"tabAutocompleteOptions": {
"title": "TabAutocompleteOptions",
Expand Down
12 changes: 11 additions & 1 deletion extensions/vscode/continue_rc_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2248,7 +2248,17 @@
"provider": "ollama",
"model": "deepseek-coder:1.3b-base"
},
"$ref": "#/definitions/ModelDescription"
"oneOf": [
{
"$ref": "#/definitions/ModelDescription"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/ModelDescription"
}
}
]
},
"tabAutocompleteOptions": {
"title": "TabAutocompleteOptions",
Expand Down
2 changes: 1 addition & 1 deletion extensions/vscode/src/autocomplete/statusBar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function setupStatusBar(
? "$(loading~spin) Continue"
: statusBarItemText(enabled);
statusBarItem.tooltip = statusBarItemTooltip(enabled);
statusBarItem.command = "continue.toggleTabAutocompleteEnabled";
statusBarItem.command = "continue.openTabAutocompleteConfigMenu";

statusBarItem.show();

Expand Down
58 changes: 58 additions & 0 deletions extensions/vscode/src/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ContextMenuConfig, IDE } from "core";
import { CompletionProvider } from "core/autocomplete/completionProvider";
import { ConfigHandler } from "core/config/handler";
import { fetchwithRequestOptions } from "core/util/fetchWithOptions";
import { GlobalContext } from "core/util/GlobalContext";
import { getConfigJsonPath } from "core/util/paths";
import { ContinueGUIWebviewViewProvider } from "./debugPanel";
import { DiffManager } from "./diff/horizontal";
Expand Down Expand Up @@ -532,6 +533,63 @@ const commandsMap: (
vscode.ConfigurationTarget.Global,
);
},
"continue.openTabAutocompleteConfigMenu": async () => {
const config = vscode.workspace.getConfiguration("continue");
const enabled = config.get("enableTabAutocomplete");
const quickPick = vscode.window.createQuickPick();
const selected = new GlobalContext().get("selectedTabAutocompleteModel");
const autocompleteModelTitles = ((
await configHandler.loadConfig()
).tabAutocompleteModels
?.map((model) => model.title)
.filter((t) => t !== undefined) || []) as string[];
quickPick.items = [
{
label: enabled
? "$(check) Disable autocomplete"
: "$(circle-slash) Enable autocomplete",
},
{
label: "$(gear) Configure autocomplete options",
},
{
kind: vscode.QuickPickItemKind.Separator,
label: "Switch model",
},
...autocompleteModelTitles.map((title) => ({
label: title === selected ? `$(check) ${title}` : title,
description: title === selected ? "Currently selected" : undefined,
})),
];
quickPick.onDidAccept(() => {
const selectedOption = quickPick.selectedItems[0].label;
if (selectedOption === "$(circle-slash) Enable autocomplete") {
config.update(
"enableTabAutocomplete",
true,
vscode.ConfigurationTarget.Global,
);
} else if (selectedOption === "$(check) Disable autocomplete") {
config.update(
"enableTabAutocomplete",
false,
vscode.ConfigurationTarget.Global,
);
} else if (
selectedOption === "$(gear) Configure autocomplete options"
) {
ide.openFile(getConfigJsonPath());
} else if (autocompleteModelTitles.includes(selectedOption)) {
new GlobalContext().update(
"selectedTabAutocompleteModel",
selectedOption,
);
configHandler.reloadConfig();
}
quickPick.dispose();
});
quickPick.show();
},
};
};

Expand Down
Loading

0 comments on commit 894e1a3

Please sign in to comment.