Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate embeddings and summaries of a group in context menu #11832

Merged
merged 10 commits into from
Oct 6, 2024
2 changes: 2 additions & 0 deletions src/main/java/org/jabref/gui/actions/StandardActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ public enum StandardActions implements Action {
GROUP_REMOVE_WITH_SUBGROUPS(Localization.lang("Also remove subgroups")),
GROUP_CHAT(Localization.lang("Chat with group")),
GROUP_EDIT(Localization.lang("Edit group")),
GROUP_GENERATE_SUMMARIES(Localization.lang("Generate summaries for entries in the group")),
GROUP_GENERATE_EMBEDDINGS(Localization.lang("Generate embeddings for linked files in the group")),
GROUP_SUBGROUP_ADD(Localization.lang("Add subgroup")),
GROUP_SUBGROUP_REMOVE(Localization.lang("Remove subgroups")),
GROUP_SUBGROUP_SORT(Localization.lang("Sort subgroups A-Z")),
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/jabref/gui/frame/JabRefFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ public JabRefFrame(Stage mainStage,
Injector.instantiateModelOrService(JournalAbbreviationRepository.class),
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/jabref/gui/groups/GroupTreeView.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.jabref.gui.util.RecursiveTreeItem;
import org.jabref.gui.util.ViewModelTreeTableCellFactory;
import org.jabref.gui.util.ViewModelTreeTableRowFactory;
import org.jabref.logic.ai.AiService;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.entry.BibEntry;
Expand All @@ -82,6 +83,7 @@ public class GroupTreeView extends BorderPane {

private final StateManager stateManager;
private final DialogService dialogService;
private final AiService aiService;
private final ChatHistoryService chatHistoryService;
private final TaskExecutor taskExecutor;
private final GuiPreferences preferences;
Expand All @@ -108,12 +110,14 @@ public GroupTreeView(TaskExecutor taskExecutor,
StateManager stateManager,
GuiPreferences preferences,
DialogService dialogService,
AiService aiService,
ChatHistoryService chatHistoryService
) {
this.taskExecutor = taskExecutor;
this.stateManager = stateManager;
this.preferences = preferences;
this.dialogService = dialogService;
this.aiService = aiService;
this.chatHistoryService = chatHistoryService;

createNodes();
Expand Down Expand Up @@ -164,7 +168,7 @@ private void createNodes() {

private void initialize() {
this.localDragboard = stateManager.getLocalDragboard();
viewModel = new GroupTreeViewModel(stateManager, dialogService, chatHistoryService, preferences, taskExecutor, localDragboard);
viewModel = new GroupTreeViewModel(stateManager, dialogService, aiService, chatHistoryService, preferences, taskExecutor, localDragboard);

// Set-up groups tree
groupTree.getSelectionModel().setSelectionMode(SelectionMode.MULTIPLE);
Expand Down Expand Up @@ -555,6 +559,8 @@ private ContextMenu createContextMenuForGroup(GroupNodeViewModel group) {

contextMenu.getItems().addAll(
factory.createMenuItem(StandardActions.GROUP_EDIT, new ContextAction(StandardActions.GROUP_EDIT, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_EMBEDDINGS, new ContextAction(StandardActions.GROUP_GENERATE_EMBEDDINGS, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_SUMMARIES, new ContextAction(StandardActions.GROUP_GENERATE_SUMMARIES, group)),
removeGroup,
new SeparatorMenuItem(),
factory.createMenuItem(StandardActions.GROUP_SUBGROUP_ADD, new ContextAction(StandardActions.GROUP_SUBGROUP_ADD, group)),
Expand Down Expand Up @@ -668,6 +674,10 @@ public void execute() {
viewModel.editGroup(group);
groupTree.refresh();
}
case GROUP_GENERATE_EMBEDDINGS ->
viewModel.generateEmbeddings(group);
case GROUP_GENERATE_SUMMARIES ->
viewModel.generateSummaries(group);
case GROUP_CHAT ->
viewModel.chatWithGroup(group);
case GROUP_SUBGROUP_ADD ->
Expand Down
63 changes: 57 additions & 6 deletions src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
import org.jabref.model.entry.LinkedFile;
import org.jabref.model.groups.AbstractGroup;
import org.jabref.model.groups.AutomaticKeywordGroup;
import org.jabref.model.groups.AutomaticPersonsGroup;
Expand All @@ -54,6 +55,7 @@ public class GroupTreeViewModel extends AbstractViewModel {
private final ListProperty<GroupNodeViewModel> selectedGroups = new SimpleListProperty<>(FXCollections.observableArrayList());
private final StateManager stateManager;
private final DialogService dialogService;
private final AiService aiService;
private final ChatHistoryService chatHistoryService;
private final GuiPreferences preferences;
private final TaskExecutor taskExecutor;
Expand All @@ -78,9 +80,17 @@ public class GroupTreeViewModel extends AbstractViewModel {
};
private Optional<BibDatabaseContext> currentDatabase = Optional.empty();

public GroupTreeViewModel(StateManager stateManager, DialogService dialogService, ChatHistoryService chatHistoryService, GuiPreferences preferences, TaskExecutor taskExecutor, CustomLocalDragboard localDragboard) {
public GroupTreeViewModel(StateManager stateManager,
DialogService dialogService,
AiService aiService,
ChatHistoryService chatHistoryService,
GuiPreferences preferences,
TaskExecutor taskExecutor,
CustomLocalDragboard localDragboard
) {
this.stateManager = Objects.requireNonNull(stateManager);
this.dialogService = Objects.requireNonNull(dialogService);
this.aiService = aiService;
this.chatHistoryService = Objects.requireNonNull(chatHistoryService);
this.preferences = Objects.requireNonNull(preferences);
this.taskExecutor = Objects.requireNonNull(taskExecutor);
Expand Down Expand Up @@ -386,11 +396,7 @@ public void editGroup(GroupNodeViewModel oldGroup) {
}

public void chatWithGroup(GroupNodeViewModel group) {
// This should probably be done some other way. Please don't blame, it's just a thing to make it quick and fast.
if (currentDatabase.isEmpty()) {
dialogService.showErrorDialogAndWait(Localization.lang("Unable to chat with group"), Localization.lang("No library is selected."));
return;
}
assert currentDatabase.isPresent();

StringProperty groupNameProperty = group.getGroupNode().getGroup().nameProperty();

Expand Down Expand Up @@ -430,6 +436,51 @@ private void openAiChat(StringProperty name, ObservableList<ChatMessage> chatHis
}
}

public void generateEmbeddings(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<LinkedFile> linkedFiles = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.flatMap(entry -> entry.getFiles().stream())
.toList();

aiService.getIngestionService().ingest(
group.nameProperty(),
linkedFiles,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Ingestion started for group \"%0\".", group.getName()));
}

public void generateSummaries(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<BibEntry> entries = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.toList();

aiService.getSummariesService().summarize(
group.nameProperty(),
entries,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Summarization started for group \"%0\".", group.getName()));
}

public void removeSubgroups(GroupNodeViewModel group) {
boolean confirmation = dialogService.showConfirmationDialogAndWait(
Localization.lang("Remove subgroups"),
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/jabref/gui/sidepane/SidePane.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.jabref.gui.actions.SimpleCommand;
import org.jabref.gui.ai.chatting.chathistory.ChatHistoryService;
import org.jabref.gui.preferences.GuiPreferences;
import org.jabref.logic.ai.AiService;
import org.jabref.logic.journals.JournalAbbreviationRepository;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.entry.BibEntryTypesManager;
Expand All @@ -37,6 +38,7 @@ public SidePane(LibraryTabContainer tabContainer,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
Expand All @@ -52,6 +54,7 @@ public SidePane(LibraryTabContainer tabContainer,
stateManager,
taskExecutor,
dialogService,
aiService,
fileUpdateMonitor,
entryTypesManager,
clipBoardManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.jabref.gui.openoffice.OpenOfficePanel;
import org.jabref.gui.preferences.GuiPreferences;
import org.jabref.gui.util.UiTaskExecutor;
import org.jabref.logic.ai.AiService;
import org.jabref.logic.journals.JournalAbbreviationRepository;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.entry.BibEntryTypesManager;
Expand All @@ -26,6 +27,7 @@ public class SidePaneContentFactory {
private final JournalAbbreviationRepository abbreviationRepository;
private final TaskExecutor taskExecutor;
private final DialogService dialogService;
private final AiService aiService;
private final StateManager stateManager;
private final FileUpdateMonitor fileUpdateMonitor;
private final BibEntryTypesManager entryTypesManager;
Expand All @@ -38,6 +40,7 @@ public SidePaneContentFactory(LibraryTabContainer tabContainer,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
Expand All @@ -49,6 +52,7 @@ public SidePaneContentFactory(LibraryTabContainer tabContainer,
this.abbreviationRepository = abbreviationRepository;
this.taskExecutor = taskExecutor;
this.dialogService = dialogService;
this.aiService = aiService;
this.stateManager = stateManager;
this.fileUpdateMonitor = fileUpdateMonitor;
this.entryTypesManager = entryTypesManager;
Expand All @@ -63,6 +67,7 @@ public Node create(SidePaneType sidePaneType) {
stateManager,
preferences,
dialogService,
aiService,
chatHistoryService);
case OPEN_OFFICE -> new OpenOfficePanel(
tabContainer,
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/jabref/gui/sidepane/SidePaneViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.jabref.gui.ai.chatting.chathistory.ChatHistoryService;
import org.jabref.gui.frame.SidePanePreferences;
import org.jabref.gui.preferences.GuiPreferences;
import org.jabref.logic.ai.AiService;
import org.jabref.logic.journals.JournalAbbreviationRepository;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.entry.BibEntryTypesManager;
Expand All @@ -47,6 +48,7 @@ public SidePaneViewModel(LibraryTabContainer tabContainer,
StateManager stateManager,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
ClipBoardManager clipBoardManager,
Expand All @@ -61,6 +63,7 @@ public SidePaneViewModel(LibraryTabContainer tabContainer,
abbreviationRepository,
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsForSeveralTask.class);

private final StringProperty name;
private final StringProperty groupName;
private final List<ProcessingInfo<LinkedFile, Void>> linkedFiles;
private final FileEmbeddingsManager fileEmbeddingsManager;
private final BibDatabaseContext bibDatabaseContext;
Expand All @@ -42,23 +42,23 @@ public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private String currentFile = "";

public GenerateEmbeddingsForSeveralTask(
StringProperty name,
StringProperty groupName,
List<ProcessingInfo<LinkedFile, Void>> linkedFiles,
FileEmbeddingsManager fileEmbeddingsManager,
BibDatabaseContext bibDatabaseContext,
FilePreferences filePreferences,
TaskExecutor taskExecutor,
ReadOnlyBooleanProperty shutdownSignal
) {
this.name = name;
this.groupName = groupName;
this.linkedFiles = linkedFiles;
this.fileEmbeddingsManager = fileEmbeddingsManager;
this.bibDatabaseContext = bibDatabaseContext;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;
this.shutdownSignal = shutdownSignal;

configure(name);
configure(groupName);
}

private void configure(StringProperty name) {
Expand All @@ -73,9 +73,10 @@ private void configure(StringProperty name) {

@Override
public Void call() throws Exception {
LOGGER.debug("Starting embeddings generation of several files for {}", name.get());
LOGGER.debug("Starting embeddings generation of several files for {}", groupName.get());

List<Pair<? extends Future<?>, String>> futures = new ArrayList<>();

linkedFiles
.stream()
.map(processingInfo -> {
Expand All @@ -88,6 +89,7 @@ public Void call() throws Exception {
filePreferences,
shutdownSignal
)
.showToUser(false)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.onFinished(() -> progressCounter.increaseWorkDone(1))
Expand All @@ -101,7 +103,7 @@ public Void call() throws Exception {
pair.getKey().get();
}

LOGGER.debug("Finished embeddings generation task of several files for {}", name.get());
LOGGER.debug("Finished embeddings generation task of several files for {}", groupName.get());
progressCounter.stop();
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ public GenerateEmbeddingsTask(LinkedFile linkedFile,
this.filePreferences = filePreferences;
this.shutdownSignal = shutdownSignal;

configure(linkedFile);
configure();
}

private void configure(LinkedFile linkedFile) {
private void configure() {
showToUser(true);
titleProperty().set(Localization.lang("Generating embeddings for file '%0'", linkedFile.getLink()));

progressCounter.listenToAllProperties(this::updateProgress);
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,34 @@ public List<ProcessingInfo<LinkedFile, Void>> getProcessingInfo(List<LinkedFile>
return linkedFiles.stream().map(this::getProcessingInfo).toList();
}

public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty name, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty groupName, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
List<ProcessingInfo<LinkedFile, Void>> result = getProcessingInfo(linkedFiles);

if (listsUnderIngestion.contains(linkedFiles)) {
return result;
}

listsUnderIngestion.add(linkedFiles);

List<ProcessingInfo<LinkedFile, Void>> needToProcess = result.stream().filter(processingInfo -> processingInfo.getState() == ProcessingState.STOPPED).toList();
startEmbeddingsGenerationTask(name, needToProcess, bibDatabaseContext);
startEmbeddingsGenerationTask(groupName, needToProcess, bibDatabaseContext);

return result;
}

private void startEmbeddingsGenerationTask(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext, ProcessingInfo<LinkedFile, Void> processingInfo) {
processingInfo.setState(ProcessingState.PROCESSING);

new GenerateEmbeddingsTask(linkedFile, fileEmbeddingsManager, bibDatabaseContext, filePreferences, shutdownSignal)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.executeWith(taskExecutor);
}

private void startEmbeddingsGenerationTask(StringProperty name, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
new GenerateEmbeddingsForSeveralTask(name, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
private void startEmbeddingsGenerationTask(StringProperty groupName, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
linkedFiles.forEach(processingInfo -> processingInfo.setState(ProcessingState.PROCESSING));

new GenerateEmbeddingsForSeveralTask(groupName, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
.executeWith(taskExecutor);
}

Expand Down
Loading
Loading