Skip to content

Commit

Permalink
Add Gemini LLM (#11736)
Browse files Browse the repository at this point in the history
* Add Gemini LLM model

* Fix bugs

* Fix privacy policy notice

* Fix table formatting

---------

Co-authored-by: Oliver Kopp <[email protected]>
  • Loading branch information
InAnYan and koppor authored Sep 13, 2024
1 parent ba9de82 commit 2a39416
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 66 deletions.
77 changes: 39 additions & 38 deletions PRIVACY.md

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,13 @@ dependencies {
implementation 'org.yaml:snakeyaml:2.3'

// AI
implementation 'dev.langchain4j:langchain4j:0.33.0'
implementation 'dev.langchain4j:langchain4j:0.34.0'
// Even though we use jvm-openai for LLM connection, we still need this package for tokenization.
implementation('dev.langchain4j:langchain4j-open-ai:0.33.0') {
implementation('dev.langchain4j:langchain4j-open-ai:0.34.0') {
exclude group: 'org.jetbrains.kotlin', module: 'kotlin-stdlib-jdk8'
}
implementation('dev.langchain4j:langchain4j-mistral-ai:0.33.0')
implementation('dev.langchain4j:langchain4j-mistral-ai:0.34.0')
implementation('dev.langchain4j:langchain4j-google-ai-gemini:0.34.0')
implementation('dev.langchain4j:langchain4j-hugging-face:0.34.0')
implementation 'ai.djl:api:0.29.0'
implementation 'ai.djl.pytorch:pytorch-model-zoo:0.29.0'
Expand Down
1 change: 1 addition & 0 deletions src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,6 @@
requires mslinks;
requires org.antlr.antlr4.runtime;
requires org.libreoffice.uno;
requires langchain4j.google.ai.gemini;
// endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
</Label>
<TextFlow fx:id="openAiPrivacyTextFlow">
<children>
<Text text="%If you have chosen the OpenAI as AI provider, the privacy policy of OpenAI applies. You find it at %0.">
<Text text="%If you have chosen %0 as an AI provider, the privacy policy of %0 applies. You find it at %1.">
<font>
<Font size="14.0" />
</font>
Expand All @@ -37,16 +37,25 @@
</TextFlow>
<TextFlow fx:id="mistralAiPrivacyTextFlow">
<children>
<Text text="%If you have chosen the Mistral AI as AI provider, the privacy policy of Mistral AI applies. You find it at %0.">
<Text text="%If you have chosen %0 as an AI provider, the privacy policy of %0 applies. You find it at %1.">
<font>
<Font size="14.0" />
</font>
</Text>
</children>
</TextFlow>
<TextFlow fx:id="geminiPrivacyTextFlow">
<children>
<Text text="%If you have chosen %0 as an AI provider, the privacy policy of %0 applies. You find it at %1.">
<font>
<Font size="14.0" />
</font>
</Text>
</children>
</TextFlow>
<TextFlow fx:id="huggingFacePrivacyTextFlow">
<children>
<Text text="%If you have chosen the Hugging Face as AI provider, the privacy policy of Hugging Face applies. You find it at %0.">
<Text text="%If you have chosen %0 as an AI provider, the privacy policy of %0 applies. You find it at %1.">
<font>
<Font size="14.0" />
</font>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import org.jabref.gui.DialogService;
import org.jabref.gui.desktop.JabRefDesktop;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.preferences.FilePreferences;
import org.jabref.preferences.ai.AiPreferences;
import org.jabref.preferences.ai.AiProvider;

import com.airhacks.afterburner.views.ViewLoader;
import org.slf4j.Logger;
Expand All @@ -22,6 +24,7 @@ public class PrivacyNoticeComponent extends ScrollPane {

@FXML private TextFlow openAiPrivacyTextFlow;
@FXML private TextFlow mistralAiPrivacyTextFlow;
@FXML private TextFlow geminiPrivacyTextFlow;
@FXML private TextFlow huggingFacePrivacyTextFlow;
@FXML private Text embeddingModelText;

Expand All @@ -43,9 +46,10 @@ public PrivacyNoticeComponent(AiPreferences aiPreferences, Runnable onIAgreeButt

@FXML
private void initialize() {
initPrivacyHyperlink(openAiPrivacyTextFlow, "https://openai.com/policies/privacy-policy/");
initPrivacyHyperlink(mistralAiPrivacyTextFlow, "https://mistral.ai/terms/#privacy-policy");
initPrivacyHyperlink(huggingFacePrivacyTextFlow, "https://huggingface.co/privacy");
initPrivacyHyperlink(openAiPrivacyTextFlow, AiProvider.OPEN_AI);
initPrivacyHyperlink(mistralAiPrivacyTextFlow, AiProvider.MISTRAL_AI);
initPrivacyHyperlink(geminiPrivacyTextFlow, AiProvider.GEMINI);
initPrivacyHyperlink(huggingFacePrivacyTextFlow, AiProvider.HUGGING_FACE);

String newEmbeddingModelText = embeddingModelText.getText().replaceAll("%0", aiPreferences.getEmbeddingModel().sizeInfo());
embeddingModelText.setText(newEmbeddingModelText);
Expand All @@ -56,20 +60,19 @@ private void initialize() {
embeddingModelText.wrappingWidthProperty().bind(this.widthProperty());
}

private void initPrivacyHyperlink(TextFlow textFlow, String link) {
private void initPrivacyHyperlink(TextFlow textFlow, AiProvider aiProvider) {
if (textFlow.getChildren().isEmpty() || !(textFlow.getChildren().getFirst() instanceof Text text)) {
return;
}

String[] stringArray = text.getText().split("%0");
String replacedText = text.getText().replaceAll("%0", aiProvider.getLabel()).replace("%1", "");

if (stringArray.length != 2) {
return;
}
replacedText = replacedText.endsWith(".") ? replacedText.substring(0, replacedText.length() - 1) : replacedText;

text.setText(replacedText);
text.wrappingWidthProperty().bind(this.widthProperty());
text.setText(stringArray[0]);

String link = AiDefaultPreferences.PROVIDERS_PRIVACY_POLICIES.get(aiProvider);
Hyperlink hyperlink = new Hyperlink(link);
hyperlink.setWrapText(true);
hyperlink.setFont(text.getFont());
Expand All @@ -79,11 +82,11 @@ private void initPrivacyHyperlink(TextFlow textFlow, String link) {

textFlow.getChildren().add(hyperlink);

Text postText = new Text(stringArray[1]);
postText.setFont(text.getFont());
postText.wrappingWidthProperty().bind(this.widthProperty());
Text dot = new Text(".");
dot.setFont(text.getFont());
dot.wrappingWidthProperty().bind(this.widthProperty());

textFlow.getChildren().add(postText);
textFlow.getChildren().add(dot);
}

@FXML
Expand Down
27 changes: 24 additions & 3 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ public class AiTabViewModel implements PreferenceTabViewModel {

private final StringProperty openAiChatModel = new SimpleStringProperty();
private final StringProperty mistralAiChatModel = new SimpleStringProperty();
private final StringProperty geminiChatModel = new SimpleStringProperty();
private final StringProperty huggingFaceChatModel = new SimpleStringProperty();

private final StringProperty currentApiKey = new SimpleStringProperty();

private final StringProperty openAiApiKey = new SimpleStringProperty();
private final StringProperty mistralAiApiKey = new SimpleStringProperty();
private final StringProperty geminiAiApiKey = new SimpleStringProperty();
private final StringProperty huggingFaceApiKey = new SimpleStringProperty();

private final BooleanProperty customizeExpertSettings = new SimpleBooleanProperty();
Expand All @@ -64,10 +66,11 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final ObjectProperty<EmbeddingModel> selectedEmbeddingModel = new SimpleObjectProperty<>();

private final StringProperty currentApiBaseUrl = new SimpleStringProperty();
private final BooleanProperty disableApiBaseUrl = new SimpleBooleanProperty(true); // {@link HuggingFaceChatModel} doesn't support setting API base URL
private final BooleanProperty disableApiBaseUrl = new SimpleBooleanProperty(true); // {@link HuggingFaceChatModel} and {@link GoogleAiGeminiChatModel} doesn't support setting API base URL

private final StringProperty openAiApiBaseUrl = new SimpleStringProperty();
private final StringProperty mistralAiApiBaseUrl = new SimpleStringProperty();
private final StringProperty geminiApiBaseUrl = new SimpleStringProperty();
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();

private final StringProperty instruction = new SimpleStringProperty();
Expand Down Expand Up @@ -120,7 +123,7 @@ public AiTabViewModel(PreferencesService preferencesService) {
String oldChatModel = currentChatModel.get();
chatModelsList.setAll(models);

disableApiBaseUrl.set(newValue == AiProvider.HUGGING_FACE);
disableApiBaseUrl.set(newValue == AiProvider.HUGGING_FACE || newValue == AiProvider.GEMINI);

if (oldValue != null) {
switch (oldValue) {
Expand All @@ -134,6 +137,11 @@ public AiTabViewModel(PreferencesService preferencesService) {
mistralAiApiKey.set(currentApiKey.get());
mistralAiApiBaseUrl.set(currentApiBaseUrl.get());
}
case GEMINI -> {
geminiChatModel.set(oldChatModel);
geminiAiApiKey.set(currentApiKey.get());
geminiApiBaseUrl.set(currentApiBaseUrl.get());
}
case HUGGING_FACE -> {
huggingFaceChatModel.set(oldChatModel);
huggingFaceApiKey.set(currentApiKey.get());
Expand All @@ -153,6 +161,11 @@ public AiTabViewModel(PreferencesService preferencesService) {
currentApiKey.set(mistralAiApiKey.get());
currentApiBaseUrl.set(mistralAiApiBaseUrl.get());
}
case GEMINI -> {
currentChatModel.set(geminiChatModel.get());
currentApiKey.set(geminiAiApiKey.get());
currentApiBaseUrl.set(geminiApiBaseUrl.get());
}
case HUGGING_FACE -> {
currentChatModel.set(huggingFaceChatModel.get());
currentApiKey.set(huggingFaceApiKey.get());
Expand All @@ -165,6 +178,7 @@ public AiTabViewModel(PreferencesService preferencesService) {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiChatModel.set(newValue);
case MISTRAL_AI -> mistralAiChatModel.set(newValue);
case GEMINI -> geminiChatModel.set(newValue);
case HUGGING_FACE -> huggingFaceChatModel.set(newValue);
}

Expand All @@ -182,6 +196,7 @@ public AiTabViewModel(PreferencesService preferencesService) {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiApiKey.set(newValue);
case MISTRAL_AI -> mistralAiApiKey.set(newValue);
case GEMINI -> geminiAiApiKey.set(newValue);
case HUGGING_FACE -> huggingFaceApiKey.set(newValue);
}
});
Expand All @@ -190,6 +205,7 @@ public AiTabViewModel(PreferencesService preferencesService) {
switch (selectedAiProvider.get()) {
case OPEN_AI -> openAiApiBaseUrl.set(newValue);
case MISTRAL_AI -> mistralAiApiBaseUrl.set(newValue);
case GEMINI -> geminiApiBaseUrl.set(newValue);
case HUGGING_FACE -> huggingFaceApiBaseUrl.set(newValue);
}
});
Expand Down Expand Up @@ -265,14 +281,17 @@ public AiTabViewModel(PreferencesService preferencesService) {
public void setValues() {
openAiApiKey.setValue(aiPreferences.getApiKeyForAiProvider(AiProvider.OPEN_AI));
mistralAiApiKey.setValue(aiPreferences.getApiKeyForAiProvider(AiProvider.MISTRAL_AI));
geminiAiApiKey.setValue(aiPreferences.getApiKeyForAiProvider(AiProvider.GEMINI));
huggingFaceApiKey.setValue(aiPreferences.getApiKeyForAiProvider(AiProvider.HUGGING_FACE));

openAiApiBaseUrl.setValue(aiPreferences.getOpenAiApiBaseUrl());
mistralAiApiBaseUrl.setValue(aiPreferences.getMistralAiApiBaseUrl());
geminiApiBaseUrl.setValue(aiPreferences.getGeminiApiBaseUrl());
huggingFaceApiBaseUrl.setValue(aiPreferences.getHuggingFaceApiBaseUrl());

openAiChatModel.setValue(aiPreferences.getOpenAiChatModel());
mistralAiChatModel.setValue(aiPreferences.getMistralAiChatModel());
geminiChatModel.setValue(aiPreferences.getGeminiChatModel());
huggingFaceChatModel.setValue(aiPreferences.getHuggingFaceChatModel());

enableAi.setValue(aiPreferences.getEnableAi());
Expand All @@ -282,7 +301,6 @@ public void setValues() {
customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());

selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());

instruction.setValue(aiPreferences.getInstruction());
temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
Expand All @@ -300,10 +318,12 @@ public void storeSettings() {

aiPreferences.setOpenAiChatModel(openAiChatModel.get() == null ? "" : openAiChatModel.get());
aiPreferences.setMistralAiChatModel(mistralAiChatModel.get() == null ? "" : mistralAiChatModel.get());
aiPreferences.setGeminiChatModel(geminiChatModel.get() == null ? "" : geminiChatModel.get());
aiPreferences.setHuggingFaceChatModel(huggingFaceChatModel.get() == null ? "" : huggingFaceChatModel.get());

aiPreferences.storeAiApiKeyInKeyring(AiProvider.OPEN_AI, openAiApiKey.get() == null ? "" : openAiApiKey.get());
aiPreferences.storeAiApiKeyInKeyring(AiProvider.MISTRAL_AI, mistralAiApiKey.get() == null ? "" : mistralAiApiKey.get());
aiPreferences.storeAiApiKeyInKeyring(AiProvider.GEMINI, geminiAiApiKey.get() == null ? "" : geminiAiApiKey.get());
aiPreferences.storeAiApiKeyInKeyring(AiProvider.HUGGING_FACE, huggingFaceApiKey.get() == null ? "" : huggingFaceApiKey.get());
// We notify in all cases without a real check if something was changed
aiPreferences.apiKeyUpdated();
Expand All @@ -314,6 +334,7 @@ public void storeSettings() {

aiPreferences.setOpenAiApiBaseUrl(openAiApiBaseUrl.get() == null ? "" : openAiApiBaseUrl.get());
aiPreferences.setMistralAiApiBaseUrl(mistralAiApiBaseUrl.get() == null ? "" : mistralAiApiBaseUrl.get());
aiPreferences.setGeminiApiBaseUrl(geminiApiBaseUrl.get() == null ? "" : geminiApiBaseUrl.get());
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());

aiPreferences.setInstruction(instruction.get());
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ public class AiDefaultPreferences {
AiProvider.OPEN_AI, List.of("gpt-4o-mini", "gpt-4o", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"),
// "mistral" and "mixtral" are not language mistakes.
AiProvider.MISTRAL_AI, List.of("open-mistral-nemo", "open-mistral-7b", "open-mixtral-8x7b", "open-mixtral-8x22b", "mistral-large-latest"),
AiProvider.GEMINI, List.of("gemini-1.5-flash", "gemini-1.5-pro", "gemini-1.0-pro"),
AiProvider.HUGGING_FACE, List.of()
);

public static final Map<AiProvider, String> PROVIDERS_PRIVACY_POLICIES = Map.of(
AiProvider.OPEN_AI, "https://openai.com/policies/privacy-policy/",
AiProvider.MISTRAL_AI, "https://mistral.ai/terms/#privacy-policy",
AiProvider.GEMINI, "https://ai.google.dev/gemini-api/terms",
AiProvider.HUGGING_FACE, "https://huggingface.co/privacy"
);

public static final Map<AiProvider, String> PROVIDERS_API_URLS = Map.of(
AiProvider.OPEN_AI, "https://api.openai.com/v1",
AiProvider.MISTRAL_AI, "https://api.mistral.ai/v1",
AiProvider.GEMINI, "https://generativelanguage.googleapis.com/v1beta/",
AiProvider.HUGGING_FACE, "https://huggingface.co/api"
);

Expand All @@ -34,6 +43,11 @@ public class AiDefaultPreferences {
"open-mistral-7b", 32000,
"open-mixtral-8x7b", 32000,
"open-mixtral-8x22b", 64000
),
AiProvider.GEMINI, Map.of(
"gemini-1.5-flash", 1048576,
"gemini-1.5-pro", 2097152,
"gemini-1.0-pro", 32000
)
);

Expand All @@ -44,6 +58,7 @@ public class AiDefaultPreferences {
public static final Map<AiProvider, String> CHAT_MODELS = Map.of(
AiProvider.OPEN_AI, "gpt-4o-mini",
AiProvider.MISTRAL_AI, "open-mixtral-8x22b",
AiProvider.GEMINI, "gemini-1.5-flash",
AiProvider.HUGGING_FACE, ""
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.output.Response;
Expand Down Expand Up @@ -76,8 +77,20 @@ private void rebuild() {
);
}

case GEMINI -> {
// NOTE: {@link GoogleAiGeminiChatModel} doesn't support API base url.
langchainChatModel = Optional.of(GoogleAiGeminiChatModel
.builder()
.apiKey(apiKey)
.modelName(aiPreferences.getSelectedChatModel())
.temperature(aiPreferences.getTemperature())
.logRequestsAndResponses(true)
.build()
);
}

case HUGGING_FACE -> {
// NOTE: {@link HuggingFaceChatModel} doesn't support API base url :(
// NOTE: {@link HuggingFaceChatModel} doesn't support API base url.
langchainChatModel = Optional.of(HuggingFaceChatModel
.builder()
.accessToken(apiKey)
Expand Down
Loading

0 comments on commit 2a39416

Please sign in to comment.