Skip to content

Commit

Permalink
Add logic to generate prompt and parse response
Browse files Browse the repository at this point in the history
  • Loading branch information
taichimaeda committed Apr 19, 2024
1 parent 7c73dca commit b327646
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 13 deletions.
9 changes: 7 additions & 2 deletions src/api/clients/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ import Markpilot from 'src/main';
import { validateURL } from 'src/utils';
import { APIClient } from '..';
import { CostsTracker } from '../costs';
import { PromptGenerator } from '../prompts/generator';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OllamaAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(tracker: CostsTracker, plugin: Markpilot) {
super(tracker, plugin);
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get openai(): OpenAI | undefined {
Expand Down
13 changes: 9 additions & 4 deletions src/api/clients/openai-compatible.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import Markpilot from 'src/main';
import { APIClient } from '..';
import { ChatMessage } from '../../types';
import { CostsTracker } from '../costs';
import { PromptGenerator } from '../prompts/generator';

export abstract class OpenAICompatibleAPIClient implements APIClient {
constructor(
protected generator: PromptGenerator,
protected tracker: CostsTracker,
protected plugin: Markpilot,
) {}
Expand Down Expand Up @@ -69,10 +71,9 @@ export abstract class OpenAICompatibleAPIClient implements APIClient {
}

try {
// TODO:
// Get messages from the prompt generator.
const messages = this.generator.generate(prefix, suffix);
const completions = await this.openai.chat.completions.create({
messages: [],
messages,
model: settings.completions.model,
max_tokens: settings.completions.maxTokens,
temperature: settings.completions.temperature,
Expand All @@ -91,7 +92,11 @@ export abstract class OpenAICompatibleAPIClient implements APIClient {
outputTokens,
);

return completions.choices[0].message.content ?? undefined;
const content = completions.choices[0].message.content;
if (content === null) {
return;
}
return this.generator.parse(content);
} catch (error) {
console.error(error);
new Notice(
Expand Down
9 changes: 7 additions & 2 deletions src/api/clients/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ import OpenAI from 'openai';
import Markpilot from 'src/main';
import { APIClient } from '..';
import { CostsTracker } from '../costs';
import { PromptGenerator } from '../prompts/generator';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OpenAIAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(tracker: CostsTracker, plugin: Markpilot) {
super(tracker, plugin);
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get openai(): OpenAI | undefined {
Expand Down
9 changes: 7 additions & 2 deletions src/api/clients/openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ import OpenAI from 'openai';
import Markpilot from 'src/main';
import { APIClient } from '..';
import { CostsTracker } from '../costs';
import { PromptGenerator } from '../prompts/generator';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OpenRouterAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(tracker: CostsTracker, plugin: Markpilot) {
super(tracker, plugin);
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get openai(): OpenAI | undefined {
Expand Down
5 changes: 5 additions & 0 deletions src/api/prompts/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,9 @@ export class PromptGenerator {
},
] as ChatMessage[];
}

parse(content: string) {
const lines = content.split('\n');
return lines.slice(lines.indexOf('<INSERT>') + 1).join('\n');
}
}
8 changes: 5 additions & 3 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { OllamaAPIClient } from './api/clients/ollama';
import { OpenAIAPIClient } from './api/clients/openai';
import { OpenRouterAPIClient } from './api/clients/openrouter';
import { CostsTracker } from './api/costs';
import { PromptGenerator } from './api/prompts/generator';
import { Provider } from './api/provider';
import { MemoryCacheProxy } from './api/proxies/memory-cache';
import { UsageMonitorProxy } from './api/proxies/usage-monitor';
Expand Down Expand Up @@ -199,15 +200,16 @@ export default class Markpilot extends Plugin {
}

createAPIClient(provider: Provider) {
const generator = new PromptGenerator(this);
const tracker = new CostsTracker(this);
const client = (() => {
switch (provider) {
case 'openai':
return new OpenAIAPIClient(tracker, this);
return new OpenAIAPIClient(generator, tracker, this);
case 'openrouter':
return new OpenRouterAPIClient(tracker, this);
return new OpenRouterAPIClient(generator, tracker, this);
case 'ollama':
return new OllamaAPIClient(tracker, this);
return new OllamaAPIClient(generator, tracker, this);
}
})();
const clientWithMonitor = new UsageMonitorProxy(client, this);
Expand Down

0 comments on commit b327646

Please sign in to comment.