Skip to content

Commit

Permalink
use new openai package and models
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbjames committed Jan 9, 2024
1 parent 8129e98 commit e99fb2d
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 100 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"native-keymap": "3.3.0",
"native-watchdog": "1.4.0",
"node-pty": "0.11.0-beta11",
"openai": "^3.1.0",
"openai": "^4.24.1",
"spdlog": "^0.13.0",
"tas-client-umd": "0.1.6",
"v8-inspect-profiler": "^0.1.0",
Expand Down
47 changes: 2 additions & 45 deletions src/vs/editor/contrib/leap/browser/Leap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -337,55 +337,12 @@ class Leap implements IEditorContribution {
// TODO (lisa) bad hack below, should remove when the server logic is set up for the web version
const modelRequest = await this._utils.buildRequest(prefix, suffix);
this._logger.modelRequest(modelRequest);
const events: string[] = await this._utils.getCompletions(
const codes: string[] = await this._utils.getCompletions(
modelRequest,
signal,
(_e) => progressBar.worked(1));

console.debug('Got the following events from the server:\n', events);

// We're streaming the data, so we need to reconstruct it here.
// TODO (kas) we can technically be constructing it in the `onDownloadProgress` function,
// but this is a bit easier :/
// TODO (kas) Is there a library function for handling this, so we don't need to parse it
// manually?
const codes = [];
for (const eventStr of events) {
if (!eventStr.startsWith('data:')) {
console.error('Event line does NOT start with `data:`. This should not happen.');
continue;
}

if (eventStr === 'data: [DONE]') {
// We should be done.
continue;
}

const event = JSON.parse(eventStr.slice(5).trim());

if (!('choices' in event)) {
if ('error' in event && 'message' in event.error) {
// They sent us an error message.
rs.push(new ErrorMessage(event.error.message));
} else {
console.error('Event format not recognized, skipping.');
console.error(event);
rs.push(new ErrorMessage('Message not recognized. Please see logs for details.'));
}
continue;
}

for (const choice of event.choices) {
const i = choice.index;
const text = choice.text;

while (codes.length <= i) {
codes.push('');
}

codes[i] += text;
}
}
console.debug('Got the following completions from the server:\n', codes);

// Remove empty or repeated completions.
const set = new Set();
Expand Down
10 changes: 7 additions & 3 deletions src/vs/editor/contrib/leap/browser/LeapInterfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ALogger, StudyGroup } from "../../rtv/browser/RTVInterfaces";
// remote and local versions.
export interface OpenAIRequest {
'model': string;
'prompt'?: string | null;
'messages': OpenAIMessage[];
'suffix'?: string | null;
'max_tokens'?: number | null;
'temperature'?: number | null;
Expand All @@ -19,10 +19,14 @@ export interface OpenAIRequest {
'presence_penalty'?: number | null;
'frequency_penalty'?: number | null;
'best_of'?: number | null;
'logit_bias'?: object | null;
'user'?: string;
}

export interface OpenAIMessage {
role: "assistant" | "system" | "user";
content: string;
}

export { StudyGroup } from "../../rtv/browser/RTVInterfaces";

export interface LeapUtils {
Expand Down Expand Up @@ -109,4 +113,4 @@ export abstract class ALeapLogger extends ALogger implements ILeapLogger {
async panelUnfocus() {
await this.log('leap.panel.unfocus');
}
}
}
82 changes: 66 additions & 16 deletions src/vs/editor/contrib/leap/browser/LeapUtils.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import * as openai from 'openai';
import OpenAI from 'openai';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { ALeapLogger, ILeapLogger, LeapConfig, LeapUtils, OpenAIRequest } from 'vs/editor/contrib/leap/browser/LeapInterfaces';
import { ALeapLogger, ILeapLogger, LeapConfig, LeapUtils, OpenAIMessage, OpenAIRequest } from 'vs/editor/contrib/leap/browser/LeapInterfaces';
import { StudyGroup } from '../../rtv/browser/RTVInterfaces';
import { ICodeEditor } from 'vs/editor/browser/editorBrowser';

class LocalUtils implements LeapUtils {
public readonly EOL: string = os.EOL;
private _openAi: openai.OpenAIApi;
private _openAi: OpenAI;
private _requestTemplate = {
model: "code-davinci-002",
model: "gpt-3.5-turbo",
temperature: 0.5,
n: 5,
max_tokens: 512,
Expand All @@ -20,39 +20,89 @@ class LocalUtils implements LeapUtils {

constructor() {
// Configure OpenAI api
const openAiConfig = new openai.Configuration({
this._openAi = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});
this._openAi = new openai.OpenAIApi(openAiConfig);
}

async getConfig(): Promise<LeapConfig> {
return new LeapConfig('local', StudyGroup.Treatment);
}

async getCompletions(request: OpenAIRequest, signal: AbortSignal, progressCallback: (e: any) => void): Promise<string[]> {
const completions = await this._openAi.createCompletion(
request,
{
onDownloadProgress: progressCallback,
signal: signal
});
return (completions.data as unknown as string).split('\n\n');
// @ts-ignore
const completionArgs: OpenAI.ChatCompletionCreateParamsStreaming = request;
const completions = await this._openAi.chat.completions.create(completionArgs);

signal.onabort = ((_) => { completions.controller.abort(); });
const codes = Array.from({ length: (request.n || 1) }, () => "");
for await (const part of completions) {
const i = part.choices[0].index;
const delta = part.choices[0].delta;
codes[i] += delta;
progressCallback(part);
}

return codes;
}

getLogger(editor: ICodeEditor): ILeapLogger {
return new LeapLogger(editor);
}

async buildRequest(prefix: string, suffix: string): Promise<OpenAIRequest> {
const messages = parsePromptFile("implement_it", { prefix, suffix });
return {
prompt: prefix,
suffix: suffix,
...this._requestTemplate
...this._requestTemplate,
messages,
};
}
}

export function parsePromptFile(filename: string, substitutions: { [key: string]: string }): OpenAIMessage[] {
/**
* 1. Parse the file at prompt/{{filename}}.txt
* 2. Subsitute all words in the text file of {{key}} to the value at substitutions[key].
* 3. Generate a message list for use with the OpenAI API.
*/
const currentDir = path.dirname(path.resolve(__filename));
const filePath = `${currentDir}/prompts/${filename}.txt`;

if (!fs.existsSync(filePath)) {
throw new Error(`Could not find prompt file ${filePath}`);
}

const text = fs.readFileSync(filePath, 'utf-8');
const sections = text.split('---\n');
const chatText: OpenAIMessage[] = [];

for (const section of sections) {
const sectionLines = section.split('\n');
const role = sectionLines[0];
let content = sectionLines.slice(1).join('\n');

for (const [key, value] of Object.entries(substitutions)) {
content = content.replace(new RegExp(`{${key}}`, 'g'), value);
}

const hasMustaches =
content.includes('{{') && content.includes('}}');

if (hasMustaches) {
throw new Error(`Mustache brackets were not replaced in ${filename}.txt`);
}
if (role !== 'system' && role !== 'user' && role !== 'assistant') {
throw new Error(`Role is not 'system', 'user', or 'assistant'. It is '${role}'.`);
}

chatText.push({
role,
content,
});
}
return chatText;
}

export function getUtils(): LeapUtils {
return new LocalUtils();
}
Expand Down
48 changes: 48 additions & 0 deletions src/vs/editor/contrib/leap/browser/prompts/implement_it.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
system
You are an expert Python programmer.
You will fill in the missing piece of Python code. Do not change any of the prefix. Do not change any of the suffix.
Do not repeat the prompt, prefix, or suffix in your answer. The prefix, suffix, and completion when put together, must be parsable as valid Python code.

You will receive a [[prefix]] and a [[suffix]] of Python code. You must fill in the middle.
---
user
[[prefix]]
def fib(n: int) -> int:
[[suffix]]

assert fib(0) == 1
assert fib(1) == 1
---
assistant
if n < 2:
return 1
return fib(n - 1) + fib(n - 2)
---
user
[[prefix]]
import yaml
import os
import openai
import re
import pandas as pd
import sys

pd.options.display.max_rows = 4000

# Read YAML file
with open("secrets.yaml", 'r') as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
[[suffix]]

openai.organization = ORG_ID
openai.api_key = API_KEY
---
assistant
ORG_ID = cfg['ORG_ID']
API_KEY = cfg['API_KEY']
---
user
[[prefix]]
{{prefix}}
[[suffix]]
{{suffix}}
24 changes: 13 additions & 11 deletions src/vs/editor/contrib/leap/browser/remote/LeapUtils.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import { ICodeEditor } from 'vs/editor/browser/editorBrowser';
import { LeapConfig, LeapUtils, OpenAIRequest, ALeapLogger, ILeapLogger } from 'vs/editor/contrib/leap/browser/LeapInterfaces';
import { parsePromptFile } from 'vs/editor/contrib/leap/browser/LeapUtils';
import { LogEventData, LogResultData } from 'vs/editor/contrib/rtv/browser/RTVInterfaces';

class RemoteUtils implements LeapUtils {
public readonly EOL: string = '\n';
private _requestTemplate?: OpenAIRequest = undefined;
private _requestTemplate = {
model: "gpt-3.5-turbo",
temperature: 0.5,
n: 5,
max_tokens: 512,
stop: [this.EOL + this.EOL],
stream: true,
};

constructor() {
this.fetchRequestTemplate();
Expand Down Expand Up @@ -60,17 +68,10 @@ class RemoteUtils implements LeapUtils {
}

async buildRequest(prefix: string, suffix: string): Promise<OpenAIRequest> {
let template;
if (this._requestTemplate) {
template = this._requestTemplate;
} else {
template = await this.fetchRequestTemplate();
}

const messages = parsePromptFile("implement_it", { prefix, suffix });
return {
prompt: prefix,
suffix: suffix,
...template
...this._requestTemplate,
messages,
};
}

Expand All @@ -84,6 +85,7 @@ class RemoteUtils implements LeapUtils {
}
);
const body: OpenAIRequest = await res.json();
// @ts-ignore
this._requestTemplate = body;
return body;
}
Expand Down
2 changes: 1 addition & 1 deletion src/vs/editor/contrib/rtv/browser/remote/RTVUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class RemoteUtils implements Utils {
logger(editor: ICodeEditor): IRTVLogger {
this._editor = editor;
if (!this._logger) {
this._logger = new RTVLogger();
this._logger = new RTVLogger(editor);
}
return this._logger;
}
Expand Down
Loading

0 comments on commit e99fb2d

Please sign in to comment.