Skip to content

Commit

Permalink
Added support for custom endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimaelQuemerais committed Oct 29, 2024
1 parent 9b73d15 commit dbcc0b3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 9 deletions.
4 changes: 4 additions & 0 deletions paper_ai/lib/providers/chat_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ class ChatNotifier extends StateNotifier<ChatState> {
model: settings.selectedModel!.name,
),
);
case ModelProvider.koboldcpp:
return ChatOpenAI(
baseUrl: settings.customEndpoint,
);
}
}

Expand Down
46 changes: 39 additions & 7 deletions paper_ai/lib/providers/settings_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ final settingsProvider =
return SettingsNotifier();
});

enum ModelProvider { gemini, openai, claude }
enum ModelProvider { gemini, openai, claude, koboldcpp }

class Model {
final String name;
Expand All @@ -31,13 +31,17 @@ class SettingsState {
final String openaiApiKey;
final String claudeApiKey;
final int messageCount;
final String customEndpoint;
final String customModelName;
final Model? selectedModel;

SettingsState({
required this.geminiApiKey,
required this.openaiApiKey,
required this.claudeApiKey,
required this.messageCount,
required this.customEndpoint,
required this.customModelName,
this.selectedModel,
});

Expand All @@ -46,13 +50,17 @@ class SettingsState {
String? openaiApiKey,
String? claudeApiKey,
int? messageCount,
String? customEndpoint,
String? customModelName,
Model? selectedModel,
}) {
return SettingsState(
geminiApiKey: geminiApiKey ?? this.geminiApiKey,
openaiApiKey: openaiApiKey ?? this.openaiApiKey,
claudeApiKey: claudeApiKey ?? this.claudeApiKey,
messageCount: messageCount ?? this.messageCount,
customEndpoint: customEndpoint ?? this.customEndpoint,
customModelName: customModelName ?? this.customModelName,
selectedModel: selectedModel ?? this.selectedModel,
);
}
Expand All @@ -79,6 +87,14 @@ class SettingsState {
Model(name: 'claude-3-opus-latest', provider: ModelProvider.claude),
]);
}
if (customEndpoint.isNotEmpty) {
models.add(
Model(
name: customModelName,
provider: ModelProvider.koboldcpp,
),
);
}

return models;
}
Expand All @@ -92,12 +108,16 @@ class SettingsState {

class SettingsNotifier extends StateNotifier<SettingsState> {
SettingsNotifier()
: super(SettingsState(
geminiApiKey: '',
openaiApiKey: '',
claudeApiKey: '',
messageCount: 5,
)) {
: super(
SettingsState(
geminiApiKey: '',
openaiApiKey: '',
claudeApiKey: '',
messageCount: 5,
customEndpoint: '',
customModelName: 'my-model',
),
) {
loadSettings();
}

Expand All @@ -108,6 +128,8 @@ class SettingsNotifier extends StateNotifier<SettingsState> {
openaiApiKey: prefs.getString('openaiApiKey') ?? '',
claudeApiKey: prefs.getString('claudeApiKey') ?? '',
messageCount: prefs.getInt('messageCount') ?? 5,
customEndpoint: prefs.getString('customEndpoint') ?? '',
customModelName: prefs.getString('customModelName') ?? 'my-model',
);

if (state.selectedModel == null) {
Expand All @@ -128,6 +150,8 @@ class SettingsNotifier extends StateNotifier<SettingsState> {
await prefs.setString('openaiApiKey', state.openaiApiKey);
await prefs.setString('claudeApiKey', state.claudeApiKey);
await prefs.setInt('messageCount', state.messageCount);
await prefs.setString('customEndpoint', state.customEndpoint);
await prefs.setString('customModelName', state.customModelName);

loadSettings();
}
Expand All @@ -148,6 +172,14 @@ class SettingsNotifier extends StateNotifier<SettingsState> {
state = state.copyWith(messageCount: count);
}

void updateCustomEndpoint(String endpoint) {
state = state.copyWith(customEndpoint: endpoint);
}

void updateCustomModelName(String name) {
state = state.copyWith(customModelName: name);
}

void updateSelectedModel(Model model) async {
state = state.copyWith(selectedModel: model);

Expand Down
47 changes: 46 additions & 1 deletion paper_ai/lib/screens/settings_screen.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class _SettingsScreenState extends ConsumerState<SettingsScreen> {
bool _isOpenaiApiKeyVisible = false;
bool _isClaudeApiKeyVisible = false;

final TextEditingController _customEndpointController =
TextEditingController();
final TextEditingController _customModelNameController =
TextEditingController();

@override
void initState() {
super.initState();
Expand All @@ -35,6 +40,8 @@ class _SettingsScreenState extends ConsumerState<SettingsScreen> {
_openaiApiKeyController.text = settings.openaiApiKey;
_claudeApiKeyController.text = settings.claudeApiKey;
_messageCount = settings.messageCount;
_customEndpointController.text = settings.customEndpoint;
_customModelNameController.text = settings.customModelName;
});
}

Expand All @@ -49,6 +56,12 @@ class _SettingsScreenState extends ConsumerState<SettingsScreen> {
.read(settingsProvider.notifier)
.updateClaudeApiKey(_claudeApiKeyController.text);
ref.read(settingsProvider.notifier).updateMessageCount(_messageCount);
ref.read(settingsProvider.notifier).updateCustomEndpoint(
_customEndpointController.text,
);
ref.read(settingsProvider.notifier).updateCustomModelName(
_customModelNameController.text,
);
await ref.read(settingsProvider.notifier).saveSettings();

if (mounted) {
Expand Down Expand Up @@ -155,7 +168,7 @@ class _SettingsScreenState extends ConsumerState<SettingsScreen> {
),
],
),
const SizedBox(height: 16),
const SizedBox(height: 24),
const Divider(),
Row(
children: [
Expand All @@ -180,6 +193,38 @@ class _SettingsScreenState extends ConsumerState<SettingsScreen> {
_messageCount = newValue;
},
),
const SizedBox(height: 24),
const Divider(),
Row(
children: [
Text(
'Custom OpenAI API endpoint',
style: Theme.of(context).textTheme.titleLarge,
),
],
),
Row(
children: [
Text(
'This URL is used to connect to an OpenAI compatible API',
style: Theme.of(context).textTheme.labelLarge,
),
],
),
const SizedBox(height: 8),
TextField(
controller: _customEndpointController,
decoration: const InputDecoration(
labelText: 'Custom endpoint (e.g. https://api.myserver.com/v1/)',
),
),
const SizedBox(height: 8),
TextField(
controller: _customModelNameController,
decoration: const InputDecoration(
labelText: 'Custom model name',
),
),
const Spacer(),
PaperButton(text: 'Save', onPressed: _saveSettings),
],
Expand Down
2 changes: 1 addition & 1 deletion paper_ai/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: paper_ai
description: "An AI assistant for e-ink devices"

publish_to: 'none'
version: 0.0.2+2
version: 0.0.3+3

environment:
sdk: ^3.5.4
Expand Down

0 comments on commit dbcc0b3

Please sign in to comment.