Skip to content

Commit

Permalink
support any completion params
Browse files Browse the repository at this point in the history
  • Loading branch information
nvms committed Aug 19, 2023
1 parent b5ca981 commit c8564b8
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 22 deletions.
9 changes: 4 additions & 5 deletions src/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export class AnthropicProvider implements Provider {
conversationTextHistory: string | undefined;
_abort: AbortController = new AbortController();

async create(provider: PostableViewProvider, template?: Command) {
async create(provider: PostableViewProvider, template: Command) {
const apiKey = await getSecret<string>("anthropic.apiKey", "");

// If the user still uses the now deprecated anthropic.apiKey config, move it to the secrets store
Expand Down Expand Up @@ -85,11 +85,10 @@ export class AnthropicProvider implements Provider {
}

const samplingParameters: SamplingParameters = {
...template?.completionParams,
prompt,
temperature: template?.temperature ?? (getConfig("anthropic.temperature") as number),
max_tokens_to_sample: template?.maxTokens ?? (getConfig("anthropic.maxTokens") as number) ?? 4096,
top_k: template?.numberOfChoices ?? -1,
model: template?.model ?? (getConfig("anthropic.model") as string) ?? "claude-instant-v1",
temperature: template?.completionParams?.temperature ?? (getConfig("anthropic.temperature") as number),
model: template?.completionParams?.model ?? (getConfig("anthropic.model") as string) ?? "claude-instant-v1",
};

try {
Expand Down
2 changes: 1 addition & 1 deletion src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export interface Provider {
* @param options
* @param provider
*/
create(provider: vscode.WebviewViewProvider, template?: Command): Promise<void>;
create(provider: vscode.WebviewViewProvider, template: Command): Promise<void>;
destroy(): void;
send: (message: string, systemMessage?: string, template?: Command) => Promise<any>;
repeatLast: () => Promise<void>;
Expand Down
8 changes: 5 additions & 3 deletions src/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export class OpenAIProvider implements Provider {
conversationState: ConversationState = { conversationId: "", parentMessageId: "" };
_abort: AbortController = new AbortController();

async create(provider: PostableViewProvider, template?: Command) {
async create(provider: PostableViewProvider, template: Command) {
const apiKey = getConfig<string>("openai.apiKey") ?? (await getSecret<string>("openai.apiKey", "llama"));

// If the user still uses the now deprecated openai.apiKey config, move it to the secrets store
Expand All @@ -38,17 +38,19 @@ export class OpenAIProvider implements Provider {
temperature = 0.8,
} = {
apiBaseUrl: getConfig<string>("openai.apiBaseUrl") ?? getConfig<string>("apiBaseUrl"),
model: template?.model ?? getConfig<string>("openai.model") ?? getConfig<string>("model"),
temperature: template?.temperature ?? getConfig<number>("openai.temperature") ?? getConfig<number>("temperature"),
model: template?.completionParams?.model ?? getConfig<string>("openai.model"),
temperature: template?.completionParams?.temperature ?? getConfig<number>("openai.temperature"),
};
this.viewProvider = provider;

this.instance = new ChatGPTAPI({
apiKey,
apiBaseUrl,
debug: false,
// @ts-expect-error this works just fine
fetch,
completionParams: {
...template.completionParams,
model,
temperature,
// max_tokens: template?.maxTokens ?? getConfig<number>("openai.maxTokens") ?? 4096,
Expand Down
1 change: 0 additions & 1 deletion src/providers/sdks/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import type { Response as NodeFetchResponse } from "node-fetch";
export interface SamplingParameters {
prompt: string;
temperature?: number;
max_tokens_to_sample: number;
stop_sequences?: string[];
top_k?: number;
top_p?: number;
Expand Down
57 changes: 45 additions & 12 deletions src/templates/render.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ export enum AIProvider {
}

export interface Command {
model?: string;
maxTokens?: number;
temperature?: number;
numberOfChoices?: number;
command?: string;
label: string;
description?: string;
Expand All @@ -108,13 +104,12 @@ export interface Command {
callbackType?: CallbackType;
category?: string;
provider?: AIProvider;
completionParams?: {
[key: string]: any;
};
}

export const baseCommand: Command = {
maxTokens: 4096,
numberOfChoices: 1,
model: "gpt-3.5-turbo",
temperature: 0.3,
label: "Unnamed command",
systemMessageTemplate: "You are a {{language}} coding assistant.",
userMessageTemplate: "",
Expand All @@ -132,6 +127,21 @@ export const baseCommand: Command = {
provider: AIProvider.OpenAI,
};

// https://platform.openai.com/docs/api-reference/chat/create
const defaultOpenAICompletionParams = () => ({
n: 1,
model: "gpt-3.5-turbo",
temperature: 0.3,
});

// https://docs.anthropic.com/claude/reference/complete_post
const defaultAnthropicCompletionParams = () => ({
// max_tokens_to_sample: 4096,
top_k: 5,
model: "claude-instant-v1",
temperature: 0.3,
});

export const defaultCommands: Command[] = [
// Completion
{
Expand Down Expand Up @@ -416,19 +426,42 @@ export const buildCommandTemplate = (commandName: string): Command => {
// we want to fallback to the value defined in settings.json, NOT the value
// defined in baseCommand. Otherwise, the settings.json value will be ignored.
const provider = template.provider ?? "openai";
const model = template.model ?? getConfig<string>(`${provider}.model`);
const temperature = template.temperature ?? getConfig<number>(`${provider}.temperature`);
const model = template?.completionParams?.model ?? getConfig<string>(`${provider}.model`);
const temperature = template?.completionParams?.temperature ?? getConfig<number>(`${provider}.temperature`);

const languageInstructions = { ...base.languageInstructions, ...template.languageInstructions };
const userMessageTemplate = template.userMessageTemplate.trim();
const systemMessageTemplate = template.systemMessageTemplate?.trim();

let completionParams;

switch (provider) {
case "openai":
completionParams = {
...defaultOpenAICompletionParams(),
...template.completionParams,
model,
temperature,
};
break;
case "anthropic":
completionParams = {
...defaultAnthropicCompletionParams(),
...template.completionParams,
model,
temperature,
};
break;
default:
completionParams = { ...template.completionParams };
break;
}

return {
...base,
model,
temperature,
category: template.category ?? BuiltinCategory.Misc,
...template,
completionParams,
languageInstructions,
userMessageTemplate,
systemMessageTemplate,
Expand Down

0 comments on commit c8564b8

Please sign in to comment.