Skip to content

Commit

Permalink
Merge pull request #4 from taichimaeda/add-support-for-other-providers
Browse files Browse the repository at this point in the history
Add support for other providers
  • Loading branch information
taichimaeda authored Apr 20, 2024
2 parents 55bd9e5 + b553a89 commit 446e2ca
Show file tree
Hide file tree
Showing 78 changed files with 4,074 additions and 789 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
id: setup-node
uses: actions/setup-node@v4
with:
node-version: 16
node-version: 18

- name: Install dependencies
id: install
Expand All @@ -45,7 +45,7 @@ jobs:
id: setup-node
uses: actions/setup-node@v4
with:
node-version: 16
node-version: 18

- name: Install dependencies
id: install
Expand All @@ -68,7 +68,7 @@ jobs:
id: setup-node
uses: actions/setup-node@v4
with:
node-version: 16
node-version: 18

- name: Install dependencies
id: install
Expand All @@ -91,7 +91,7 @@ jobs:
id: setup-node
uses: actions/setup-node@v4
with:
node-version: 16
node-version: 18

- name: Install dependencies
id: install
Expand Down
7 changes: 5 additions & 2 deletions esbuild.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ const context = await esbuild.context({
sourcemap: prod ? false : "inline",
treeShaking: true,
outfile: "main.js",
// For loading custom icons:
loader: { ".svg": "text" },
loader: {
".txt": "text",
".md": "text",
".svg": "text", // For custom icons,
},
});

if (prod) {
Expand Down
8 changes: 7 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"format:fix": "prettier src --write",
"lint": "eslint src --max-warnings 0",
"lint:fix": "eslint src --max-warnings 0 --fix",
"test": "echo \"Error: no test specified\""
"test": "echo \"Error: no test specified\"",
"scrape": "ts-node src/scripts/scrape.ts"
},
"keywords": [],
"author": "",
Expand All @@ -23,11 +24,15 @@
"@types/react-dom": "^18.2.22",
"@typescript-eslint/eslint-plugin": "5.29.0",
"@typescript-eslint/parser": "5.29.0",
"axios": "^1.6.8",
"builtin-modules": "3.3.0",
"cheerio": "^1.0.0-rc.12",
"commander": "^12.0.0",
"esbuild": "0.17.3",
"eslint": "^8.57.0",
"obsidian": "latest",
"prettier": "^3.2.5",
"ts-node": "^10.9.2",
"tslib": "2.4.0",
"typescript": "4.7.4"
},
Expand All @@ -37,6 +42,7 @@
"chart.js": "^4.4.2",
"js-tiktoken": "^1.0.10",
"lucide-react": "^0.363.0",
"minimatch": "^9.0.4",
"openai": "^4.30.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
Expand Down
19 changes: 19 additions & 0 deletions src/api/clients/gemini.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { APIClient, ChatMessage } from '..';

// TODO:
// Implement API client for Gemini.

export class GeminiAPIClient implements APIClient {
fetchChat(messages: ChatMessage[]): AsyncGenerator<string | undefined> {
throw new Error('Method not implemented.');
}
fetchCompletions(
prefix: string,
suffix: string,
): Promise<string | undefined> {
throw new Error('Method not implemented.');
}
testConnection(): Promise<boolean> {
throw new Error('Method not implemented.');
}
}
46 changes: 46 additions & 0 deletions src/api/clients/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { Notice } from 'obsidian';
import OpenAI from 'openai';
import Markpilot from 'src/main';
import { validateURL } from 'src/utils';
import { APIClient } from '..';
import { PromptGenerator } from '../prompts/generator';
import { Provider } from '../providers';
import { CostsTracker } from '../providers/costs';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OllamaAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get provider(): Provider {
return 'ollama';
}

get openai(): OpenAI | undefined {
const { settings } = this.plugin;

const apiUrl = settings.providers.ollama.apiUrl;
if (apiUrl === undefined) {
new Notice('Ollama API URL is not set.');
return;
}
if (!validateURL(apiUrl)) {
new Notice('Ollama API URL is invalid.');
return;
}

return new OpenAI({
baseURL: apiUrl,
apiKey: 'ollama', // Required but ignored.
dangerouslyAllowBrowser: true,
});
}
}
136 changes: 136 additions & 0 deletions src/api/clients/openai-compatible.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import { getEncoding } from 'js-tiktoken';
import { Notice } from 'obsidian';
import OpenAI from 'openai';
import Markpilot from 'src/main';
import { APIClient, ChatMessage } from '..';
import { PromptGenerator } from '../prompts/generator';
import { Provider } from '../providers';
import { CostsTracker } from '../providers/costs';
import { DEFAULT_MODELS } from '../providers/models';

export abstract class OpenAICompatibleAPIClient implements APIClient {
constructor(
protected generator: PromptGenerator,
protected tracker: CostsTracker,
protected plugin: Markpilot,
) {}

abstract get provider(): Provider;

abstract get openai(): OpenAI | undefined;

async *fetchChat(messages: ChatMessage[]) {
if (this.openai === undefined) {
return;
}

const { settings } = this.plugin;
try {
const stream = await this.openai.chat.completions.create({
messages,
model: settings.chat.model,
max_tokens: settings.chat.maxTokens,
temperature: settings.chat.temperature,
top_p: 1,
n: 1,
stream: true,
});

const contents = [];
for await (const chunk of stream) {
const content = chunk.choices[0].delta.content ?? '';
contents.push(content);
yield content;
}

// Update usage cost estimates.
const enc = getEncoding('gpt2'); // Assume GPT-2 encoding
const inputMessage = messages
.map((message) => message.content)
.join('\n');
const outputMessage = contents.join('');
const inputTokens = enc.encode(inputMessage).length;
const outputTokens = enc.encode(outputMessage).length;
await this.tracker.add(
settings.chat.provider,
settings.chat.model,
inputTokens,
outputTokens,
);
} catch (error) {
console.error(error);
new Notice(
'Failed to fetch chat completions. Make sure your API key or API URL is correct.',
);
}
}

async fetchCompletions(prefix: string, suffix: string) {
if (this.openai === undefined) {
return;
}

const { settings } = this.plugin;
try {
const messages = this.generator.generate(prefix, suffix);
const completions = await this.openai.chat.completions.create({
messages,
model: settings.completions.model,
max_tokens: settings.completions.maxTokens,
temperature: settings.completions.temperature,
top_p: 1,
n: 1,
stop: ['\n\n\n'],
});

// Update usage cost estimates.
const inputTokens = completions.usage?.prompt_tokens ?? 0;
const outputTokens = completions.usage?.completion_tokens ?? 0;
await this.tracker.add(
settings.completions.provider,
settings.completions.model,
inputTokens,
outputTokens,
);

const content = completions.choices[0].message.content;
if (content === null) {
return;
}
return this.generator.parse(content);
} catch (error) {
console.error(error);
console.log(JSON.stringify(error));
new Notice(
'Failed to fetch completions. Make sure your API key or API URL is correct.',
);
}
}

async testConnection() {
if (this.openai === undefined) {
return false;
}

try {
const response = await this.openai.chat.completions.create({
messages: [
{
role: 'user',
content: 'Say this is a test',
},
],
model: DEFAULT_MODELS[this.provider],
max_tokens: 1,
temperature: 0,
top_p: 1,
n: 1,
});

return response.choices[0].message.content !== '';
} catch (error) {
console.error(error);
return false;
}
}
}
44 changes: 44 additions & 0 deletions src/api/clients/openai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Notice } from 'obsidian';
import OpenAI from 'openai';
import Markpilot from 'src/main';
import { APIClient } from '..';
import { PromptGenerator } from '../prompts/generator';
import { Provider } from '../providers';
import { CostsTracker } from '../providers/costs';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OpenAIAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get provider(): Provider {
return 'openai';
}

get openai(): OpenAI | undefined {
const { settings } = this.plugin;

const apiKey = settings.providers.openai.apiKey;
if (apiKey === undefined) {
new Notice('OpenAI API key is not set.');
return;
}
if (!apiKey.startsWith('sk')) {
new Notice('OpenAI API key is invalid.');
return;
}

return new OpenAI({
apiKey,
dangerouslyAllowBrowser: true,
});
}
}
45 changes: 45 additions & 0 deletions src/api/clients/openrouter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Notice } from 'obsidian';
import OpenAI from 'openai';
import Markpilot from 'src/main';
import { APIClient } from '..';
import { PromptGenerator } from '../prompts/generator';
import { Provider } from '../providers';
import { CostsTracker } from '../providers/costs';
import { OpenAICompatibleAPIClient } from './openai-compatible';

export class OpenRouterAPIClient
extends OpenAICompatibleAPIClient
implements APIClient
{
constructor(
generator: PromptGenerator,
tracker: CostsTracker,
plugin: Markpilot,
) {
super(generator, tracker, plugin);
}

get provider(): Provider {
return 'openrouter';
}

get openai(): OpenAI | undefined {
const { settings } = this.plugin;

const apiKey = settings.providers.openrouter.apiKey;
if (apiKey === undefined) {
new Notice('OpenRouter API key is not set.');
return;
}
if (!apiKey.startsWith('sk')) {
new Notice('OpenRouter API key is invalid.');
return;
}

return new OpenAI({
apiKey,
baseURL: 'https://openrouter.ai/api/v1',
dangerouslyAllowBrowser: true,
});
}
}
17 changes: 17 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export interface APIClient {
fetchChat(messages: ChatMessage[]): AsyncGenerator<string | undefined>;
fetchCompletions(prefix: string, suffix: string): Promise<string | undefined>;
testConnection(): Promise<boolean>;
}

export type ChatRole = 'system' | 'assistant' | 'user';

export interface ChatMessage {
role: ChatRole;
content: string;
}

export interface ChatHistory {
messages: ChatMessage[];
response: string;
}
Loading

0 comments on commit 446e2ca

Please sign in to comment.