Skip to content

Commit

Permalink
Add additional params for the model
Browse files Browse the repository at this point in the history
  • Loading branch information
synaptiko committed Apr 9, 2023
1 parent cbe7426 commit 5dd3f9b
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 23 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ Explore more options by using `cli-gpt --help`.
# or (if you have ~/.local/bin created and added to your PATH)
mv cli-gpt ~/.local/bin
```
5. Configure `cli-gpt` by creating a `~/.cli-gpt` file:
5. Configure `cli-gpt` by creating a `~/.cli-gpt` file with your desired parameters:
```
OPENAI_API_KEY=<Your API Key>
MODEL=gpt-4
```
Replace `<Your API Key>` with your API key, and `gpt-4` with `gpt-3.5-turbo` if you don't have access to GPT-4 yet.

You can also provide additional parameters in the `~/.cli-gpt` file, such as `TEMPERATURE`, `TOP_P`, `N`, `MAX_TOKENS`, `PRESENCE_PENALTY`, `FREQUENCY_PENALTY`, `STOP`, and `LOGIT_BIAS`. To learn more about these parameters, refer to the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create).

### Tips

- You can also clone this repo and run `deno install` (follow the instructions on changing your PATH):
Expand Down
15 changes: 7 additions & 8 deletions src/ChatCompletion.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Message } from './ConversationPersistance.ts';
import { Config } from './loadConfig.ts';

function normalizeMessages(messages: Message[]): Message[] {
return messages.reduce<Message[]>((result, message) => {
Expand All @@ -13,29 +14,27 @@ function normalizeMessages(messages: Message[]): Message[] {
}

export class ChatCompletion {
private openaiApiKey: string;
private model: string;
private config: Config;
private messages: Message[] = [];

constructor(openaiApiKey: string, model: string = 'gpt-4') {
this.openaiApiKey = openaiApiKey;
this.model = model;
constructor(config: Config) {
this.config = config;
}

setMessages(messages: Message[]) {
this.messages = messages;
}

async *complete(): AsyncGenerator<string> {
const { model, openaiApiKey } = this;
const { config: { api_key, ...config } } = this;
const response = await fetch('https://api.openai.com/v1/chat/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${openaiApiKey}`,
'Authorization': `Bearer ${api_key}`,
},
body: JSON.stringify({
model,
...config,
messages: normalizeMessages(this.messages),
stream: true,
}),
Expand Down
17 changes: 3 additions & 14 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
import { load } from 'std/dotenv/mod.ts';
import { join } from 'std/path/mod.ts';
import { writeText as copyToClipboard } from 'copy_paste/mod.ts';
import { prompt } from './prompt.ts';
import { ChatCompletion } from './ChatCompletion.ts';
import { printHelp } from './printHelp.ts';
import { ConversationPersistance } from './ConversationPersistance.ts';
import { parseArgs } from './parseArgs.ts';
import { loadConfig } from './loadConfig.ts';

// TODO: add ability to set the other model params with env vars

const env = await load({
envPath: join(Deno.env.get('HOME')!, '.cli-gpt'),
});

if (env.OPENAI_API_KEY === undefined) {
console.error('OPENAI_API_KEY environment variable is not set');
Deno.exit(1);
}

const config = await loadConfig();
const conversationPersistance = new ConversationPersistance();
const { flags, role, readFiles, prompt: promptFromArgs } = parseArgs();
const { affectInitialMessages } = flags;
Expand Down Expand Up @@ -55,7 +44,7 @@ if (flags.help) {
}

if (role === 'user') {
const chatCompletion = new ChatCompletion(env.OPENAI_API_KEY, env.MODEL);
const chatCompletion = new ChatCompletion(config);
const encoder = new TextEncoder();
const write = (chunk: string) => Deno.stdout.write(encoder.encode(chunk));
const responseContent = [];
Expand Down
108 changes: 108 additions & 0 deletions src/loadConfig.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { load } from 'std/dotenv/mod.ts';
import { join } from 'std/path/mod.ts';

export type Config = {
api_key: string;
model: string;
temperature?: number;
top_p?: number;
n?: number;
max_tokens?: number;
presence_penalty?: number;
frequency_penalty?: number;
stop?: string | string[];
logit_bias?: Record<string, number>;
};

function validateParam<T>(
paramName: string,
parseValue: () => T | undefined,
isValid: (value: T) => boolean,
): T | undefined {
try {
const value = parseValue();

if (value === undefined) {
return undefined;
}

if (!isValid(value)) {
throw new Error(`Invalid value for parameter ${paramName}: ${value}`);
}

return value;
} catch (error) {
console.error(`Error parsing parameter ${paramName}: ${error.message}`);
Deno.exit(1);
}
}

export async function loadConfig(): Promise<Config> {
const env = await load({
envPath: join(Deno.env.get('HOME')!, '.cli-gpt'),
});

if (env.OPENAI_API_KEY === undefined) {
console.error('OPENAI_API_KEY environment variable is not set');
Deno.exit(1);
}

return {
api_key: env.OPENAI_API_KEY,
model: env.MODEL ?? 'gpt-4',
temperature: validateParam<number>(
'temperature',
() =>
env.TEMPERATURE !== undefined ? parseFloat(env.TEMPERATURE) : undefined,
(value) => value >= 0 && value <= 2,
),
top_p: validateParam<number>(
'top_p',
() => (env.TOP_P !== undefined ? parseFloat(env.TOP_P) : undefined),
(value) => value >= 0 && value <= 1,
),
n: validateParam<number>(
'n',
() => (env.N !== undefined ? parseInt(env.N) : undefined),
(value) => value > 0,
),
max_tokens: validateParam<number>(
'max_tokens',
() => (env.MAX_TOKENS !== undefined ? parseInt(env.MAX_TOKENS) : undefined),
(value) => value >= 0,
),
presence_penalty: validateParam<number>(
'presence_penalty',
() =>
env.PRESENCE_PENALTY !== undefined
? parseFloat(env.PRESENCE_PENALTY)
: undefined,
(value) => value >= -2 && value <= 2,
),
frequency_penalty: validateParam<number>(
'frequency_penalty',
() =>
env.FREQUENCY_PENALTY !== undefined
? parseFloat(env.FREQUENCY_PENALTY)
: undefined,
(value) => value >= -2 && value <= 2,
),
stop: validateParam<string | string[]>(
'stop',
() => (env.STOP !== undefined ? JSON.parse(env.STOP) : undefined),
(value) =>
(typeof value === 'string' && value.trim() !== '') ||
(Array.isArray(value) && value.length > 0 && value.length <= 4),
),
logit_bias: validateParam<Record<string, number>>(
'logit_bias',
() => (env.LOGIT_BIAS !== undefined ? JSON.parse(env.LOGIT_BIAS) : undefined),
(value) => {
for (const bias of Object.values(value)) {
if (bias < -100 || bias > 100) return false;
}
return true;
},
),
}
}

0 comments on commit 5dd3f9b

Please sign in to comment.