From b327646648fb2eab8e7b8cb66776b58b73465a45 Mon Sep 17 00:00:00 2001 From: Taichi Maeda Date: Sat, 20 Apr 2024 02:20:52 +0900 Subject: [PATCH] Add logic to generate prompt and parse response --- src/api/clients/ollama.ts | 9 +++++++-- src/api/clients/openai-compatible.ts | 13 +++++++++---- src/api/clients/openai.ts | 9 +++++++-- src/api/clients/openrouter.ts | 9 +++++++-- src/api/prompts/generator.ts | 5 +++++ src/main.ts | 8 +++++--- 6 files changed, 40 insertions(+), 13 deletions(-) diff --git a/src/api/clients/ollama.ts b/src/api/clients/ollama.ts index 05ce474..f3afd0d 100644 --- a/src/api/clients/ollama.ts +++ b/src/api/clients/ollama.ts @@ -4,14 +4,19 @@ import Markpilot from 'src/main'; import { validateURL } from 'src/utils'; import { APIClient } from '..'; import { CostsTracker } from '../costs'; +import { PromptGenerator } from '../prompts/generator'; import { OpenAICompatibleAPIClient } from './openai-compatible'; export class OllamaAPIClient extends OpenAICompatibleAPIClient implements APIClient { - constructor(tracker: CostsTracker, plugin: Markpilot) { - super(tracker, plugin); + constructor( + generator: PromptGenerator, + tracker: CostsTracker, + plugin: Markpilot, + ) { + super(generator, tracker, plugin); } get openai(): OpenAI | undefined { diff --git a/src/api/clients/openai-compatible.ts b/src/api/clients/openai-compatible.ts index 2435853..113796a 100644 --- a/src/api/clients/openai-compatible.ts +++ b/src/api/clients/openai-compatible.ts @@ -5,9 +5,11 @@ import Markpilot from 'src/main'; import { APIClient } from '..'; import { ChatMessage } from '../../types'; import { CostsTracker } from '../costs'; +import { PromptGenerator } from '../prompts/generator'; export abstract class OpenAICompatibleAPIClient implements APIClient { constructor( + protected generator: PromptGenerator, protected tracker: CostsTracker, protected plugin: Markpilot, ) {} @@ -69,10 +71,9 @@ export abstract class OpenAICompatibleAPIClient implements APIClient { } try { - // TODO: - // Get messages from the prompt generator. + const messages = this.generator.generate(prefix, suffix); const completions = await this.openai.chat.completions.create({ - messages: [], + messages, model: settings.completions.model, max_tokens: settings.completions.maxTokens, temperature: settings.completions.temperature, @@ -91,7 +92,11 @@ export abstract class OpenAICompatibleAPIClient implements APIClient { outputTokens, ); - return completions.choices[0].message.content ?? undefined; + const content = completions.choices[0].message.content; + if (content === null) { + return; + } + return this.generator.parse(content); } catch (error) { console.error(error); new Notice( diff --git a/src/api/clients/openai.ts b/src/api/clients/openai.ts index cfc558b..2d3cbdf 100644 --- a/src/api/clients/openai.ts +++ b/src/api/clients/openai.ts @@ -3,14 +3,19 @@ import OpenAI from 'openai'; import Markpilot from 'src/main'; import { APIClient } from '..'; import { CostsTracker } from '../costs'; +import { PromptGenerator } from '../prompts/generator'; import { OpenAICompatibleAPIClient } from './openai-compatible'; export class OpenAIAPIClient extends OpenAICompatibleAPIClient implements APIClient { - constructor(tracker: CostsTracker, plugin: Markpilot) { - super(tracker, plugin); + constructor( + generator: PromptGenerator, + tracker: CostsTracker, + plugin: Markpilot, + ) { + super(generator, tracker, plugin); } get openai(): OpenAI | undefined { diff --git a/src/api/clients/openrouter.ts b/src/api/clients/openrouter.ts index 6be34fd..1002b88 100644 --- a/src/api/clients/openrouter.ts +++ b/src/api/clients/openrouter.ts @@ -3,14 +3,19 @@ import OpenAI from 'openai'; import Markpilot from 'src/main'; import { APIClient } from '..'; import { CostsTracker } from '../costs'; +import { PromptGenerator } from '../prompts/generator'; import { OpenAICompatibleAPIClient } from './openai-compatible'; export class OpenRouterAPIClient extends OpenAICompatibleAPIClient implements APIClient { - constructor(tracker: CostsTracker, plugin: Markpilot) { - super(tracker, plugin); + constructor( + generator: PromptGenerator, + tracker: CostsTracker, + plugin: Markpilot, + ) { + super(generator, tracker, plugin); } get openai(): OpenAI | undefined { diff --git a/src/api/prompts/generator.ts b/src/api/prompts/generator.ts index 944579d..33556a7 100644 --- a/src/api/prompts/generator.ts +++ b/src/api/prompts/generator.ts @@ -59,4 +59,9 @@ export class PromptGenerator { }, ] as ChatMessage[]; } + + parse(content: string) { + const lines = content.split('\n'); + return lines.slice(lines.indexOf('') + 1).join('\n'); + } } diff --git a/src/main.ts b/src/main.ts index c35fcfd..0754f08 100644 --- a/src/main.ts +++ b/src/main.ts @@ -13,6 +13,7 @@ import { OllamaAPIClient } from './api/clients/ollama'; import { OpenAIAPIClient } from './api/clients/openai'; import { OpenRouterAPIClient } from './api/clients/openrouter'; import { CostsTracker } from './api/costs'; +import { PromptGenerator } from './api/prompts/generator'; import { Provider } from './api/provider'; import { MemoryCacheProxy } from './api/proxies/memory-cache'; import { UsageMonitorProxy } from './api/proxies/usage-monitor'; @@ -199,15 +200,16 @@ export default class Markpilot extends Plugin { } createAPIClient(provider: Provider) { + const generator = new PromptGenerator(this); const tracker = new CostsTracker(this); const client = (() => { switch (provider) { case 'openai': - return new OpenAIAPIClient(tracker, this); + return new OpenAIAPIClient(generator, tracker, this); case 'openrouter': - return new OpenRouterAPIClient(tracker, this); + return new OpenRouterAPIClient(generator, tracker, this); case 'ollama': - return new OllamaAPIClient(tracker, this); + return new OllamaAPIClient(generator, tracker, this); } })(); const clientWithMonitor = new UsageMonitorProxy(client, this);