diff --git a/src/providers/anthropic.ts b/src/providers/anthropic.ts index 7dcbf72..e7781f5 100644 --- a/src/providers/anthropic.ts +++ b/src/providers/anthropic.ts @@ -17,7 +17,7 @@ export class AnthropicProvider implements Provider { conversationTextHistory: string | undefined; _abort: AbortController = new AbortController(); - async create(provider: PostableViewProvider, template?: Command) { + async create(provider: PostableViewProvider, template: Command) { const apiKey = await getSecret("anthropic.apiKey", ""); // If the user still uses the now deprecated anthropic.apiKey config, move it to the secrets store @@ -85,11 +85,10 @@ export class AnthropicProvider implements Provider { } const samplingParameters: SamplingParameters = { + ...template?.completionParams, prompt, - temperature: template?.temperature ?? (getConfig("anthropic.temperature") as number), - max_tokens_to_sample: template?.maxTokens ?? (getConfig("anthropic.maxTokens") as number) ?? 4096, - top_k: template?.numberOfChoices ?? -1, - model: template?.model ?? (getConfig("anthropic.model") as string) ?? "claude-instant-v1", + temperature: template?.completionParams?.temperature ?? (getConfig("anthropic.temperature") as number), + model: template?.completionParams?.model ?? (getConfig("anthropic.model") as string) ?? "claude-instant-v1", }; try { diff --git a/src/providers/index.ts b/src/providers/index.ts index 68d80a9..088865a 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -31,7 +31,7 @@ export interface Provider { * @param options * @param provider */ - create(provider: vscode.WebviewViewProvider, template?: Command): Promise; + create(provider: vscode.WebviewViewProvider, template: Command): Promise; destroy(): void; send: (message: string, systemMessage?: string, template?: Command) => Promise; repeatLast: () => Promise; diff --git a/src/providers/openai.ts b/src/providers/openai.ts index 73c0a8b..24c43a4 100644 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -22,7 +22,7 @@ export class OpenAIProvider implements Provider { conversationState: ConversationState = { conversationId: "", parentMessageId: "" }; _abort: AbortController = new AbortController(); - async create(provider: PostableViewProvider, template?: Command) { + async create(provider: PostableViewProvider, template: Command) { const apiKey = getConfig("openai.apiKey") ?? (await getSecret("openai.apiKey", "llama")); // If the user still uses the now deprecated openai.apiKey config, move it to the secrets store @@ -38,10 +38,11 @@ export class OpenAIProvider implements Provider { temperature = 0.8, } = { apiBaseUrl: getConfig("openai.apiBaseUrl") ?? getConfig("apiBaseUrl"), - model: template?.model ?? getConfig("openai.model") ?? getConfig("model"), - temperature: template?.temperature ?? getConfig("openai.temperature") ?? getConfig("temperature"), + model: template?.completionParams?.model ?? getConfig("openai.model"), + temperature: template?.completionParams?.temperature ?? getConfig("openai.temperature"), }; this.viewProvider = provider; + this.instance = new ChatGPTAPI({ apiKey, apiBaseUrl, @@ -49,6 +50,7 @@ export class OpenAIProvider implements Provider { // @ts-expect-error this works just fine fetch, completionParams: { + ...template.completionParams, model, temperature, // max_tokens: template?.maxTokens ?? getConfig("openai.maxTokens") ?? 4096, diff --git a/src/providers/sdks/anthropic.ts b/src/providers/sdks/anthropic.ts index efb3d4f..cc63cc0 100644 --- a/src/providers/sdks/anthropic.ts +++ b/src/providers/sdks/anthropic.ts @@ -5,7 +5,6 @@ import type { Response as NodeFetchResponse } from "node-fetch"; export interface SamplingParameters { prompt: string; temperature?: number; - max_tokens_to_sample: number; stop_sequences?: string[]; top_k?: number; top_p?: number; diff --git a/src/templates/render.ts b/src/templates/render.ts index 6925c78..8b8fa47 100644 --- a/src/templates/render.ts +++ b/src/templates/render.ts @@ -93,10 +93,6 @@ export enum AIProvider { } export interface Command { - model?: string; - maxTokens?: number; - temperature?: number; - numberOfChoices?: number; command?: string; label: string; description?: string; @@ -108,13 +104,12 @@ export interface Command { callbackType?: CallbackType; category?: string; provider?: AIProvider; + completionParams?: { + [key: string]: any; + }; } export const baseCommand: Command = { - maxTokens: 4096, - numberOfChoices: 1, - model: "gpt-3.5-turbo", - temperature: 0.3, label: "Unnamed command", systemMessageTemplate: "You are a {{language}} coding assistant.", userMessageTemplate: "", @@ -132,6 +127,21 @@ export const baseCommand: Command = { provider: AIProvider.OpenAI, }; +// https://platform.openai.com/docs/api-reference/chat/create +const defaultOpenAICompletionParams = () => ({ + n: 1, + model: "gpt-3.5-turbo", + temperature: 0.3, +}); + +// https://docs.anthropic.com/claude/reference/complete_post +const defaultAnthropicCompletionParams = () => ({ + // max_tokens_to_sample: 4096, + top_k: 5, + model: "claude-instant-v1", + temperature: 0.3, +}); + export const defaultCommands: Command[] = [ // Completion { @@ -416,19 +426,42 @@ export const buildCommandTemplate = (commandName: string): Command => { // we want to fallback to the value defined in settings.json, NOT the value // defined in baseCommand. Otherwise, the settings.json value will be ignored. const provider = template.provider ?? "openai"; - const model = template.model ?? getConfig(`${provider}.model`); - const temperature = template.temperature ?? getConfig(`${provider}.temperature`); + const model = template?.completionParams?.model ?? getConfig(`${provider}.model`); + const temperature = template?.completionParams?.temperature ?? getConfig(`${provider}.temperature`); const languageInstructions = { ...base.languageInstructions, ...template.languageInstructions }; const userMessageTemplate = template.userMessageTemplate.trim(); const systemMessageTemplate = template.systemMessageTemplate?.trim(); + let completionParams; + + switch (provider) { + case "openai": + completionParams = { + ...defaultOpenAICompletionParams(), + ...template.completionParams, + model, + temperature, + }; + break; + case "anthropic": + completionParams = { + ...defaultAnthropicCompletionParams(), + ...template.completionParams, + model, + temperature, + }; + break; + default: + completionParams = { ...template.completionParams }; + break; + } + return { ...base, - model, - temperature, category: template.category ?? BuiltinCategory.Misc, ...template, + completionParams, languageInstructions, userMessageTemplate, systemMessageTemplate,