Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate embeddings and summaries of a group in context menu #11832

Merged
merged 10 commits into from
Oct 6, 2024
2 changes: 2 additions & 0 deletions src/main/java/org/jabref/gui/actions/StandardActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ public enum StandardActions implements Action {
GROUP_REMOVE_WITH_SUBGROUPS(Localization.lang("Also remove subgroups")),
GROUP_CHAT(Localization.lang("Chat with group")),
GROUP_EDIT(Localization.lang("Edit group")),
GROUP_GENERATE_SUMMARIES(Localization.lang("Generate summaries for entries in the group")),
GROUP_GENERATE_EMBEDDINGS(Localization.lang("Generate embeddings for linked files in the group")),
GROUP_SUBGROUP_ADD(Localization.lang("Add subgroup")),
GROUP_SUBGROUP_REMOVE(Localization.lang("Remove subgroups")),
GROUP_SUBGROUP_SORT(Localization.lang("Sort subgroups A-Z")),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/gui/frame/JabRefFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ public JabRefFrame(Stage mainStage,
this.sidePane = new SidePane(
this,
this.preferences,
aiService,
Injector.instantiateModelOrService(JournalAbbreviationRepository.class),
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/jabref/gui/groups/GroupTreeView.java
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ private ContextMenu createContextMenuForGroup(GroupNodeViewModel group) {

contextMenu.getItems().addAll(
factory.createMenuItem(StandardActions.GROUP_EDIT, new ContextAction(StandardActions.GROUP_EDIT, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_EMBEDDINGS, new ContextAction(StandardActions.GROUP_GENERATE_EMBEDDINGS, group)),
factory.createMenuItem(StandardActions.GROUP_GENERATE_SUMMARIES, new ContextAction(StandardActions.GROUP_GENERATE_SUMMARIES, group)),
removeGroup,
new SeparatorMenuItem(),
factory.createMenuItem(StandardActions.GROUP_SUBGROUP_ADD, new ContextAction(StandardActions.GROUP_SUBGROUP_ADD, group)),
Expand Down Expand Up @@ -668,6 +670,10 @@ public void execute() {
viewModel.editGroup(group);
groupTree.refresh();
}
case GROUP_GENERATE_EMBEDDINGS ->
viewModel.generateEmbeddings(group);
case GROUP_GENERATE_SUMMARIES ->
viewModel.generateSummaries(group);
case GROUP_CHAT ->
viewModel.chatWithGroup(group);
case GROUP_SUBGROUP_ADD ->
Expand Down
52 changes: 47 additions & 5 deletions src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
import org.jabref.model.entry.LinkedFile;
import org.jabref.model.groups.AbstractGroup;
import org.jabref.model.groups.AutomaticKeywordGroup;
import org.jabref.model.groups.AutomaticPersonsGroup;
Expand Down Expand Up @@ -390,11 +391,7 @@ public void editGroup(GroupNodeViewModel oldGroup) {
}

public void chatWithGroup(GroupNodeViewModel group) {
// This should probably be done some other way. Please don't blame, it's just a thing to make it quick and fast.
if (currentDatabase.isEmpty()) {
dialogService.showErrorDialogAndWait(Localization.lang("Unable to chat with group"), Localization.lang("No library is selected."));
return;
}
assert currentDatabase.isPresent();

StringProperty groupNameProperty = group.getGroupNode().getGroup().nameProperty();

Expand Down Expand Up @@ -434,6 +431,51 @@ private void openAiChat(StringProperty name, ObservableList<ChatMessage> chatHis
}
}

public void generateEmbeddings(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<LinkedFile> linkedFiles = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.flatMap(entry -> entry.getFiles().stream())
.toList();

aiService.getIngestionService().ingest(
group.nameProperty(),
linkedFiles,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Ingestion started for group \"%0\".", group.getName()));
}

public void generateSummaries(GroupNodeViewModel groupNode) {
assert currentDatabase.isPresent();

AbstractGroup group = groupNode.getGroupNode().getGroup();

List<BibEntry> entries = currentDatabase
.get()
.getDatabase()
.getEntries()
.stream()
.filter(group::isMatch)
.toList();

aiService.getSummariesService().summarize(
group.nameProperty(),
entries,
currentDatabase.get()
);

dialogService.notify(Localization.lang("Summarization started for group \"%0\".", group.getName()));
}

public void removeSubgroups(GroupNodeViewModel group) {
boolean confirmation = dialogService.showConfirmationDialogAndWait(
Localization.lang("Remove subgroups"),
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/jabref/gui/sidepane/SidePane.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public class SidePane extends VBox {

public SidePane(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
Expand All @@ -47,11 +47,11 @@ public SidePane(LibraryTabContainer tabContainer,
this.viewModel = new SidePaneViewModel(
tabContainer,
preferences,
aiService,
abbreviationRepository,
stateManager,
taskExecutor,
dialogService,
aiService,
fileUpdateMonitor,
entryTypesManager,
clipBoardManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
public class SidePaneContentFactory {
private final LibraryTabContainer tabContainer;
private final GuiPreferences preferences;
private final AiService aiService;
private final JournalAbbreviationRepository abbreviationRepository;
private final TaskExecutor taskExecutor;
private final DialogService dialogService;
private final AiService aiService;
private final StateManager stateManager;
private final FileUpdateMonitor fileUpdateMonitor;
private final BibEntryTypesManager entryTypesManager;
Expand All @@ -34,21 +34,21 @@ public class SidePaneContentFactory {

public SidePaneContentFactory(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
StateManager stateManager,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
ClipBoardManager clipBoardManager,
UndoManager undoManager) {
this.tabContainer = tabContainer;
this.preferences = preferences;
this.aiService = aiService;
this.abbreviationRepository = abbreviationRepository;
this.taskExecutor = taskExecutor;
this.dialogService = dialogService;
this.aiService = aiService;
this.stateManager = stateManager;
this.fileUpdateMonitor = fileUpdateMonitor;
this.entryTypesManager = entryTypesManager;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/jabref/gui/sidepane/SidePaneViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ public class SidePaneViewModel extends AbstractViewModel {

public SidePaneViewModel(LibraryTabContainer tabContainer,
GuiPreferences preferences,
AiService aiService,
JournalAbbreviationRepository abbreviationRepository,
StateManager stateManager,
TaskExecutor taskExecutor,
DialogService dialogService,
AiService aiService,
FileUpdateMonitor fileUpdateMonitor,
BibEntryTypesManager entryTypesManager,
ClipBoardManager clipBoardManager,
Expand All @@ -57,10 +57,10 @@ public SidePaneViewModel(LibraryTabContainer tabContainer,
this.sidePaneContentFactory = new SidePaneContentFactory(
tabContainer,
preferences,
aiService,
abbreviationRepository,
taskExecutor,
dialogService,
aiService,
stateManager,
fileUpdateMonitor,
entryTypesManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsForSeveralTask.class);

private final StringProperty name;
private final StringProperty groupName;
private final List<ProcessingInfo<LinkedFile, Void>> linkedFiles;
private final FileEmbeddingsManager fileEmbeddingsManager;
private final BibDatabaseContext bibDatabaseContext;
Expand All @@ -42,23 +42,23 @@ public class GenerateEmbeddingsForSeveralTask extends BackgroundTask<Void> {
private String currentFile = "";

public GenerateEmbeddingsForSeveralTask(
StringProperty name,
StringProperty groupName,
List<ProcessingInfo<LinkedFile, Void>> linkedFiles,
FileEmbeddingsManager fileEmbeddingsManager,
BibDatabaseContext bibDatabaseContext,
FilePreferences filePreferences,
TaskExecutor taskExecutor,
ReadOnlyBooleanProperty shutdownSignal
) {
this.name = name;
this.groupName = groupName;
this.linkedFiles = linkedFiles;
this.fileEmbeddingsManager = fileEmbeddingsManager;
this.bibDatabaseContext = bibDatabaseContext;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;
this.shutdownSignal = shutdownSignal;

configure(name);
configure(groupName);
}

private void configure(StringProperty name) {
Expand All @@ -73,9 +73,10 @@ private void configure(StringProperty name) {

@Override
public Void call() throws Exception {
LOGGER.debug("Starting embeddings generation of several files for {}", name.get());
LOGGER.debug("Starting embeddings generation of several files for {}", groupName.get());

List<Pair<? extends Future<?>, String>> futures = new ArrayList<>();

linkedFiles
.stream()
.map(processingInfo -> {
Expand All @@ -88,6 +89,7 @@ public Void call() throws Exception {
filePreferences,
shutdownSignal
)
.showToUser(false)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.onFinished(() -> progressCounter.increaseWorkDone(1))
Expand All @@ -101,7 +103,7 @@ public Void call() throws Exception {
pair.getKey().get();
}

LOGGER.debug("Finished embeddings generation task of several files for {}", name.get());
LOGGER.debug("Finished embeddings generation task of several files for {}", groupName.get());
progressCounter.stop();
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ public GenerateEmbeddingsTask(LinkedFile linkedFile,
this.filePreferences = filePreferences;
this.shutdownSignal = shutdownSignal;

configure(linkedFile);
configure();
}

private void configure(LinkedFile linkedFile) {
private void configure() {
showToUser(true);
titleProperty().set(Localization.lang("Generating embeddings for file '%0'", linkedFile.getLink()));

progressCounter.listenToAllProperties(this::updateProgress);
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,35 @@ public List<ProcessingInfo<LinkedFile, Void>> getProcessingInfo(List<LinkedFile>
return linkedFiles.stream().map(this::getProcessingInfo).toList();
}

public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty name, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
public List<ProcessingInfo<LinkedFile, Void>> ingest(StringProperty groupName, List<LinkedFile> linkedFiles, BibDatabaseContext bibDatabaseContext) {
List<ProcessingInfo<LinkedFile, Void>> result = getProcessingInfo(linkedFiles);

if (listsUnderIngestion.contains(linkedFiles)) {
return result;
}

listsUnderIngestion.add(linkedFiles);

List<ProcessingInfo<LinkedFile, Void>> needToProcess = result.stream().filter(processingInfo -> processingInfo.getState() == ProcessingState.STOPPED).toList();
startEmbeddingsGenerationTask(name, needToProcess, bibDatabaseContext);
startEmbeddingsGenerationTask(groupName, needToProcess, bibDatabaseContext);

return result;
}

private void startEmbeddingsGenerationTask(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext, ProcessingInfo<LinkedFile, Void> processingInfo) {
processingInfo.setState(ProcessingState.PROCESSING);

new GenerateEmbeddingsTask(linkedFile, fileEmbeddingsManager, bibDatabaseContext, filePreferences, shutdownSignal)
.showToUser(true)
.onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
.onFailure(processingInfo::setException)
.executeWith(taskExecutor);
}

private void startEmbeddingsGenerationTask(StringProperty name, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
new GenerateEmbeddingsForSeveralTask(name, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
private void startEmbeddingsGenerationTask(StringProperty groupName, List<ProcessingInfo<LinkedFile, Void>> linkedFiles, BibDatabaseContext bibDatabaseContext) {
linkedFiles.forEach(processingInfo -> processingInfo.setState(ProcessingState.PROCESSING));

new GenerateEmbeddingsForSeveralTask(groupName, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
.executeWith(taskExecutor);
}

Expand Down
Loading
Loading