diff --git a/paper_ai/lib/providers/chat_provider.dart b/paper_ai/lib/providers/chat_provider.dart index f20a833..6adb61a 100644 --- a/paper_ai/lib/providers/chat_provider.dart +++ b/paper_ai/lib/providers/chat_provider.dart @@ -78,6 +78,10 @@ class ChatNotifier extends StateNotifier { model: settings.selectedModel!.name, ), ); + case ModelProvider.koboldcpp: + return ChatOpenAI( + baseUrl: settings.customEndpoint, + ); } } diff --git a/paper_ai/lib/providers/settings_provider.dart b/paper_ai/lib/providers/settings_provider.dart index c22a437..2625df3 100644 --- a/paper_ai/lib/providers/settings_provider.dart +++ b/paper_ai/lib/providers/settings_provider.dart @@ -7,7 +7,7 @@ final settingsProvider = return SettingsNotifier(); }); -enum ModelProvider { gemini, openai, claude } +enum ModelProvider { gemini, openai, claude, koboldcpp } class Model { final String name; @@ -31,6 +31,8 @@ class SettingsState { final String openaiApiKey; final String claudeApiKey; final int messageCount; + final String customEndpoint; + final String customModelName; final Model? selectedModel; SettingsState({ @@ -38,6 +40,8 @@ class SettingsState { required this.openaiApiKey, required this.claudeApiKey, required this.messageCount, + required this.customEndpoint, + required this.customModelName, this.selectedModel, }); @@ -46,6 +50,8 @@ class SettingsState { String? openaiApiKey, String? claudeApiKey, int? messageCount, + String? customEndpoint, + String? customModelName, Model? selectedModel, }) { return SettingsState( @@ -53,6 +59,8 @@ class SettingsState { openaiApiKey: openaiApiKey ?? this.openaiApiKey, claudeApiKey: claudeApiKey ?? this.claudeApiKey, messageCount: messageCount ?? this.messageCount, + customEndpoint: customEndpoint ?? this.customEndpoint, + customModelName: customModelName ?? this.customModelName, selectedModel: selectedModel ?? this.selectedModel, ); } @@ -79,6 +87,14 @@ class SettingsState { Model(name: 'claude-3-opus-latest', provider: ModelProvider.claude), ]); } + if (customEndpoint.isNotEmpty) { + models.add( + Model( + name: customModelName, + provider: ModelProvider.koboldcpp, + ), + ); + } return models; } @@ -92,12 +108,16 @@ class SettingsState { class SettingsNotifier extends StateNotifier { SettingsNotifier() - : super(SettingsState( - geminiApiKey: '', - openaiApiKey: '', - claudeApiKey: '', - messageCount: 5, - )) { + : super( + SettingsState( + geminiApiKey: '', + openaiApiKey: '', + claudeApiKey: '', + messageCount: 5, + customEndpoint: '', + customModelName: 'my-model', + ), + ) { loadSettings(); } @@ -108,6 +128,8 @@ class SettingsNotifier extends StateNotifier { openaiApiKey: prefs.getString('openaiApiKey') ?? '', claudeApiKey: prefs.getString('claudeApiKey') ?? '', messageCount: prefs.getInt('messageCount') ?? 5, + customEndpoint: prefs.getString('customEndpoint') ?? '', + customModelName: prefs.getString('customModelName') ?? 'my-model', ); if (state.selectedModel == null) { @@ -128,6 +150,8 @@ class SettingsNotifier extends StateNotifier { await prefs.setString('openaiApiKey', state.openaiApiKey); await prefs.setString('claudeApiKey', state.claudeApiKey); await prefs.setInt('messageCount', state.messageCount); + await prefs.setString('customEndpoint', state.customEndpoint); + await prefs.setString('customModelName', state.customModelName); loadSettings(); } @@ -148,6 +172,14 @@ class SettingsNotifier extends StateNotifier { state = state.copyWith(messageCount: count); } + void updateCustomEndpoint(String endpoint) { + state = state.copyWith(customEndpoint: endpoint); + } + + void updateCustomModelName(String name) { + state = state.copyWith(customModelName: name); + } + void updateSelectedModel(Model model) async { state = state.copyWith(selectedModel: model); diff --git a/paper_ai/lib/screens/settings_screen.dart b/paper_ai/lib/screens/settings_screen.dart index 33aa6db..e55ba57 100644 --- a/paper_ai/lib/screens/settings_screen.dart +++ b/paper_ai/lib/screens/settings_screen.dart @@ -21,6 +21,11 @@ class _SettingsScreenState extends ConsumerState { bool _isOpenaiApiKeyVisible = false; bool _isClaudeApiKeyVisible = false; + final TextEditingController _customEndpointController = + TextEditingController(); + final TextEditingController _customModelNameController = + TextEditingController(); + @override void initState() { super.initState(); @@ -35,6 +40,8 @@ class _SettingsScreenState extends ConsumerState { _openaiApiKeyController.text = settings.openaiApiKey; _claudeApiKeyController.text = settings.claudeApiKey; _messageCount = settings.messageCount; + _customEndpointController.text = settings.customEndpoint; + _customModelNameController.text = settings.customModelName; }); } @@ -49,6 +56,12 @@ class _SettingsScreenState extends ConsumerState { .read(settingsProvider.notifier) .updateClaudeApiKey(_claudeApiKeyController.text); ref.read(settingsProvider.notifier).updateMessageCount(_messageCount); + ref.read(settingsProvider.notifier).updateCustomEndpoint( + _customEndpointController.text, + ); + ref.read(settingsProvider.notifier).updateCustomModelName( + _customModelNameController.text, + ); await ref.read(settingsProvider.notifier).saveSettings(); if (mounted) { @@ -155,7 +168,7 @@ class _SettingsScreenState extends ConsumerState { ), ], ), - const SizedBox(height: 16), + const SizedBox(height: 24), const Divider(), Row( children: [ @@ -180,6 +193,38 @@ class _SettingsScreenState extends ConsumerState { _messageCount = newValue; }, ), + const SizedBox(height: 24), + const Divider(), + Row( + children: [ + Text( + 'Custom OpenAI API endpoint', + style: Theme.of(context).textTheme.titleLarge, + ), + ], + ), + Row( + children: [ + Text( + 'This URL is used to connect to an OpenAI compatible API', + style: Theme.of(context).textTheme.labelLarge, + ), + ], + ), + const SizedBox(height: 8), + TextField( + controller: _customEndpointController, + decoration: const InputDecoration( + labelText: 'Custom endpoint (e.g. https://api.myserver.com/v1/)', + ), + ), + const SizedBox(height: 8), + TextField( + controller: _customModelNameController, + decoration: const InputDecoration( + labelText: 'Custom model name', + ), + ), const Spacer(), PaperButton(text: 'Save', onPressed: _saveSettings), ], diff --git a/paper_ai/pubspec.yaml b/paper_ai/pubspec.yaml index 4d51331..aac3f55 100644 --- a/paper_ai/pubspec.yaml +++ b/paper_ai/pubspec.yaml @@ -2,7 +2,7 @@ name: paper_ai description: "An AI assistant for e-ink devices" publish_to: 'none' -version: 0.0.2+2 +version: 0.0.3+3 environment: sdk: ^3.5.4