diff --git a/README.md b/README.md index c9198ce..fc7ba3b 100644 --- a/README.md +++ b/README.md @@ -30,13 +30,15 @@ Explore more options by using `cli-gpt --help`. # or (if you have ~/.local/bin created and added to your PATH) mv cli-gpt ~/.local/bin ``` -5. Configure `cli-gpt` by creating a `~/.cli-gpt` file: +5. Configure `cli-gpt` by creating a `~/.cli-gpt` file with your desired parameters: ``` OPENAI_API_KEY=<your_api_key> MODEL=gpt-4 ``` Replace `<your_api_key>` with your API key, and `gpt-4` with `gpt-3.5-turbo` if you don't have access to GPT-4 yet. + You can also provide additional parameters in the `~/.cli-gpt` file, such as `TEMPERATURE`, `TOP_P`, `N`, `MAX_TOKENS`, `PRESENCE_PENALTY`, `FREQUENCY_PENALTY`, `STOP`, and `LOGIT_BIAS`. To learn more about these parameters, refer to the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create). + ### Tips - You can also clone this repo and run `deno install` (follow the instructions on changing your PATH): diff --git a/src/ChatCompletion.ts b/src/ChatCompletion.ts index bdcfd6d..06f50de 100644 --- a/src/ChatCompletion.ts +++ b/src/ChatCompletion.ts @@ -1,4 +1,5 @@ import { Message } from './ConversationPersistance.ts'; +import { Config } from './loadConfig.ts'; function normalizeMessages(messages: Message[]): Message[] { return messages.reduce((result, message) => { @@ -13,13 +14,11 @@ function normalizeMessages(messages: Message[]): Message[] { } export class ChatCompletion { - private openaiApiKey: string; - private model: string; + private config: Config; private messages: Message[] = []; - constructor(openaiApiKey: string, model: string = 'gpt-4') { - this.openaiApiKey = openaiApiKey; - this.model = model; + constructor(config: Config) { + this.config = config; } setMessages(messages: Message[]) { @@ -27,15 +26,15 @@ export class ChatCompletion { async *complete(): AsyncGenerator<string> { - const { model, openaiApiKey } = this; + const { config: { api_key, ...config } } = this; const response = await 
fetch('https://api.openai.com/v1/chat/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', - 'Authorization': `Bearer ${openaiApiKey}`, + 'Authorization': `Bearer ${api_key}`, }, body: JSON.stringify({ - model, + ...config, messages: normalizeMessages(this.messages), stream: true, }), diff --git a/src/index.ts b/src/index.ts index 1e96075..c43289a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,23 +1,12 @@ -import { load } from 'std/dotenv/mod.ts'; -import { join } from 'std/path/mod.ts'; import { writeText as copyToClipboard } from 'copy_paste/mod.ts'; import { prompt } from './prompt.ts'; import { ChatCompletion } from './ChatCompletion.ts'; import { printHelp } from './printHelp.ts'; import { ConversationPersistance } from './ConversationPersistance.ts'; import { parseArgs } from './parseArgs.ts'; +import { loadConfig } from './loadConfig.ts'; -// TODO: add ability to set the other model params with env vars - -const env = await load({ - envPath: join(Deno.env.get('HOME')!, '.cli-gpt'), -}); - -if (env.OPENAI_API_KEY === undefined) { - console.error('OPENAI_API_KEY environment variable is not set'); - Deno.exit(1); -} - +const config = await loadConfig(); const conversationPersistance = new ConversationPersistance(); const { flags, role, readFiles, prompt: promptFromArgs } = parseArgs(); const { affectInitialMessages } = flags; @@ -55,7 +44,7 @@ if (flags.help) { } if (role === 'user') { - const chatCompletion = new ChatCompletion(env.OPENAI_API_KEY, env.MODEL); + const chatCompletion = new ChatCompletion(config); const encoder = new TextEncoder(); const write = (chunk: string) => Deno.stdout.write(encoder.encode(chunk)); const responseContent = []; diff --git a/src/loadConfig.ts b/src/loadConfig.ts new file mode 100644 index 0000000..1c2c936 --- /dev/null +++ b/src/loadConfig.ts @@ -0,0 +1,108 @@ +import { load } from 'std/dotenv/mod.ts'; +import { join } from 'std/path/mod.ts'; + +export type Config = { + api_key: string; + model: 
string; + temperature?: number; + top_p?: number; + n?: number; + max_tokens?: number; + presence_penalty?: number; + frequency_penalty?: number; + stop?: string | string[]; + logit_bias?: Record<string, number>; +}; + +function validateParam<T>( + paramName: string, + parseValue: () => T | undefined, + isValid: (value: T) => boolean, +): T | undefined { + try { + const value = parseValue(); + + if (value === undefined) { + return undefined; + } + + if (!isValid(value)) { + throw new Error(`Invalid value for parameter ${paramName}: ${value}`); + } + + return value; + } catch (error) { + console.error(`Error parsing parameter ${paramName}: ${error.message}`); + Deno.exit(1); + } +} + +export async function loadConfig(): Promise<Config> { + const env = await load({ + envPath: join(Deno.env.get('HOME')!, '.cli-gpt'), + }); + + if (env.OPENAI_API_KEY === undefined) { + console.error('OPENAI_API_KEY environment variable is not set'); + Deno.exit(1); + } + + return { + api_key: env.OPENAI_API_KEY, + model: env.MODEL ?? 'gpt-4', + temperature: validateParam( + 'temperature', + () => + env.TEMPERATURE !== undefined ? parseFloat(env.TEMPERATURE) : undefined, + (value) => value >= 0 && value <= 2, + ), + top_p: validateParam( + 'top_p', + () => (env.TOP_P !== undefined ? parseFloat(env.TOP_P) : undefined), + (value) => value >= 0 && value <= 1, + ), + n: validateParam( + 'n', + () => (env.N !== undefined ? parseInt(env.N) : undefined), + (value) => value > 0, + ), + max_tokens: validateParam( + 'max_tokens', + () => (env.MAX_TOKENS !== undefined ? parseInt(env.MAX_TOKENS) : undefined), + (value) => value >= 0, + ), + presence_penalty: validateParam( + 'presence_penalty', + () => + env.PRESENCE_PENALTY !== undefined + ? parseFloat(env.PRESENCE_PENALTY) + : undefined, + (value) => value >= -2 && value <= 2, + ), + frequency_penalty: validateParam( + 'frequency_penalty', + () => + env.FREQUENCY_PENALTY !== undefined + ? 
parseFloat(env.FREQUENCY_PENALTY) + : undefined, + (value) => value >= -2 && value <= 2, + ), + stop: validateParam<string | string[]>( + 'stop', + () => (env.STOP !== undefined ? JSON.parse(env.STOP) : undefined), + (value) => + (typeof value === 'string' && value.trim() !== '') || + (Array.isArray(value) && value.length > 0 && value.length <= 4), + ), + logit_bias: validateParam<Record<string, number>>( + 'logit_bias', + () => (env.LOGIT_BIAS !== undefined ? JSON.parse(env.LOGIT_BIAS) : undefined), + (value) => { + for (const bias of Object.values(value)) { + if (bias < -100 || bias > 100) return false; + } + return true; + }, + ), + } +}