From 8f01520de000ceb1971c704f847c7d9e8c644b94 Mon Sep 17 00:00:00 2001 From: Technici4n <13494793+Technici4n@users.noreply.github.com> Date: Fri, 3 May 2024 22:21:17 +0200 Subject: [PATCH] Implement dependency-based (partial) ordering for parallel initialization tasks (#123) --- .../java/net/neoforged/fml/ModLoader.java | 54 ++++++++++++++++--- .../neoforged/fml/loading/LoadingModList.java | 20 +++++-- .../net/neoforged/fml/loading/ModSorter.java | 47 ++++++++-------- 3 files changed, 85 insertions(+), 36 deletions(-) diff --git a/loader/src/main/java/net/neoforged/fml/ModLoader.java b/loader/src/main/java/net/neoforged/fml/ModLoader.java index 75efe1804..7284afcd4 100644 --- a/loader/src/main/java/net/neoforged/fml/ModLoader.java +++ b/loader/src/main/java/net/neoforged/fml/ModLoader.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -190,6 +191,12 @@ public static void waitForTask(String name, Runnable periodicTask, CompletableFu } } + /** + * Exception that is fired when a mod loading future cannot be executed because a dependent future failed. + * It is only used for control flow and easy filtering out, but never logged or propagated further. + */ + private static class DependentFutureFailedException extends RuntimeException {} + /** * Dispatches a task across all mod containers in parallel, with progress displayed on the loading screen. */ @@ -197,15 +204,44 @@ public static void dispatchParallelTask(String name, Executor parallelExecutor, var progress = StartupNotificationManager.addProgressBar(name, modList.size()); try { periodicTask.run(); + Map> modFutures = new IdentityHashMap<>(modList.size()); var futureList = modList.getSortedMods().stream() .map(modContainer -> { - return CompletableFuture.runAsync(() -> { - ModLoadingContext.get().setActiveContainer(modContainer); - task.accept(modContainer); - }, parallelExecutor).whenComplete((result, exception) -> { - progress.increment(); - ModLoadingContext.get().setActiveContainer(null); - }); + // Collect futures for all dependencies first + var depFutures = LoadingModList.get().getDependencies(modContainer.getModInfo()).stream() + .map(modInfo -> { + var future = modFutures.get(modInfo); + if (future == null) { + throw new IllegalStateException("Dependency future for mod %s which is a dependency of %s not found!".formatted( + modInfo.getModId(), modContainer.getModId())); + } + return future; + }) + .toArray(CompletableFuture[]::new); + + // Build the future for this container + var future = CompletableFuture.allOf(depFutures) + .handleAsync((void_, exception) -> { + if (exception != null) { + // If there was any exception, short circuit. + // The exception will already be handled by `waitForFuture` since it comes from another mod. + LOGGER.debug("Skipping {} task for mod {} because a dependency threw an exception.", name, modContainer.getModId()); + progress.increment(); + // Throw a marker exception to make sure that dependencies of *this* task don't get executed. + throw new DependentFutureFailedException(); + } + + try { + ModLoadingContext.get().setActiveContainer(modContainer); + task.accept(modContainer); + } finally { + progress.increment(); + ModLoadingContext.get().setActiveContainer(null); + } + return null; + }, parallelExecutor); + modFutures.put(modContainer.getModInfo(), future); + return future; }) .toList(); var singleFuture = ModList.gather(futureList) @@ -226,7 +262,9 @@ private static void waitForFuture(String name, Runnable periodicTask, Completabl // Merge all potential modloading issues var errorCount = 0; for (var error : e.getCause().getSuppressed()) { - if (error instanceof ModLoadingException modLoadingException) { + if (error instanceof DependentFutureFailedException) { + continue; + } else if (error instanceof ModLoadingException modLoadingException) { loadingIssues.addAll(modLoadingException.getIssues()); } else { loadingIssues.add(ModLoadingIssue.error("fml.modloading.uncaughterror", name).withCause(e)); diff --git a/loader/src/main/java/net/neoforged/fml/loading/LoadingModList.java b/loader/src/main/java/net/neoforged/fml/loading/LoadingModList.java index 0a7c58e24..8f7df4df9 100644 --- a/loader/src/main/java/net/neoforged/fml/loading/LoadingModList.java +++ b/loader/src/main/java/net/neoforged/fml/loading/LoadingModList.java @@ -23,6 +23,7 @@ import net.neoforged.fml.loading.moddiscovery.ModFileInfo; import net.neoforged.fml.loading.moddiscovery.ModInfo; import net.neoforged.fml.loading.modscan.BackgroundScanHandler; +import net.neoforged.neoforgespi.language.IModInfo; /** * Master list of all mods in the loading context. This class cannot refer outside the @@ -32,10 +33,11 @@ public class LoadingModList { private static LoadingModList INSTANCE; private final List modFiles; private final List sortedList; + private final Map> modDependencies; private final Map fileById; private final List modLoadingIssues; - private LoadingModList(final List modFiles, final List sortedList) { + private LoadingModList(final List modFiles, final List sortedList, Map> modDependencies) { this.modFiles = modFiles.stream() .map(ModFile::getModFileInfo) .map(ModFileInfo.class::cast) @@ -43,6 +45,7 @@ private LoadingModList(final List modFiles, final List sortedL this.sortedList = sortedList.stream() .map(ModInfo.class::cast) .collect(Collectors.toList()); + this.modDependencies = modDependencies; this.fileById = this.modFiles.stream() .map(ModFileInfo::getMods) .flatMap(Collection::stream) @@ -51,8 +54,8 @@ private LoadingModList(final List modFiles, final List sortedL this.modLoadingIssues = new ArrayList<>(); } - public static LoadingModList of(List modFiles, List sortedList, List issues) { - INSTANCE = new LoadingModList(modFiles, sortedList); + public static LoadingModList of(List modFiles, List sortedList, List issues, Map> modDependencies) { + INSTANCE = new LoadingModList(modFiles, sortedList, modDependencies); INSTANCE.modLoadingIssues.addAll(issues); return INSTANCE; } @@ -154,6 +157,17 @@ public List getMods() { return this.sortedList; } + /** + * Returns all direct loading dependencies of the given mod. + * + *

This means: all the mods that are directly specified to be loaded before the given mod, + * either because the given mod has an {@link IModInfo.Ordering#AFTER} constraint on the dependency, + * or because the dependency has a {@link IModInfo.Ordering#BEFORE} constraint on the given mod. + */ + public List getDependencies(IModInfo mod) { + return this.modDependencies.getOrDefault(mod, List.of()); + } + public boolean hasErrors() { return modLoadingIssues.stream().noneMatch(issue -> issue.severity() == ModLoadingIssue.Severity.ERROR); } diff --git a/loader/src/main/java/net/neoforged/fml/loading/ModSorter.java b/loader/src/main/java/net/neoforged/fml/loading/ModSorter.java index 162089711..da7b52a70 100644 --- a/loader/src/main/java/net/neoforged/fml/loading/ModSorter.java +++ b/loader/src/main/java/net/neoforged/fml/loading/ModSorter.java @@ -46,6 +46,7 @@ public class ModSorter { private final UniqueModListBuilder uniqueModListBuilder; private List modFiles; private List sortedList; + private Map> modDependencies; private Map modIdNameLookup; private List systemMods; @@ -59,7 +60,7 @@ public static LoadingModList sort(List mods, final List (ModInfo) mf.getModInfos().get(0)).collect(toList()), e.getIssues()); + return LoadingModList.of(ms.systemMods, ms.systemMods.stream().map(mf -> (ModInfo) mf.getModInfos().get(0)).collect(toList()), e.getIssues(), Map.of()); } // try and validate dependencies @@ -69,7 +70,7 @@ public static LoadingModList sort(List mods, final List (ModInfo) mf.getModInfos().get(0)).collect(toList()), concat(issues, resolutionResult.buildErrorMessages())); + list = LoadingModList.of(ms.systemMods, ms.systemMods.stream().map(mf -> (ModInfo) mf.getModInfos().get(0)).collect(toList()), concat(issues, resolutionResult.buildErrorMessages()), Map.of()); } else { // Otherwise, lets try and sort the modlist and proceed ModLoadingException modLoadingException = null; @@ -79,9 +80,9 @@ public static LoadingModList sort(List mods, final List List concat(List... lists) { @SuppressWarnings("UnstableApiUsage") private void sort() { // lambdas are identity based, so sorting them is impossible unless you hold reference to them - final MutableGraph graph = GraphBuilder.directed().build(); + final MutableGraph graph = GraphBuilder.directed().build(); AtomicInteger counter = new AtomicInteger(); - Map infos = modFiles.stream() - .map(ModFile::getModFileInfo) - .filter(ModFileInfo.class::isInstance) - .map(ModFileInfo.class::cast) + Map infos = modFiles.stream() + .flatMap(mf -> mf.getModInfos().stream()) + .map(ModInfo.class::cast) .collect(toMap(Function.identity(), e -> counter.incrementAndGet())); infos.keySet().forEach(graph::addNode); modFiles.stream() @@ -120,7 +120,7 @@ private void sort() { .map(IModInfo::getDependencies).mapMulti(Iterable::forEach) .forEach(dep -> addDependency(graph, dep)); - final List sorted; + final List sorted; try { sorted = TopologicalSort.topologicalSort(graph, Comparator.comparing(infos::get)); } catch (CyclePresentException e) { @@ -136,22 +136,21 @@ private void sort() { .toList(); throw new ModLoadingException(dataList); } - this.sortedList = sorted.stream() - .map(ModFileInfo::getMods) - .mapMulti(Iterable::forEach) - .map(ModInfo.class::cast) - .collect(toList()); + this.sortedList = List.copyOf(sorted); + this.modDependencies = sorted.stream() + .collect(Collectors.toMap(modInfo -> modInfo, modInfo -> List.copyOf(graph.predecessors(modInfo)))); this.modFiles = sorted.stream() - .map(ModFileInfo::getFile) - .collect(toList()); + .map(mi -> mi.getOwningFile().getFile()) + .distinct() + .toList(); } @SuppressWarnings("UnstableApiUsage") - private void addDependency(MutableGraph topoGraph, IModInfo.ModVersion dep) { - final ModFileInfo self = (ModFileInfo) dep.getOwner().getOwningFile(); + private void addDependency(MutableGraph topoGraph, IModInfo.ModVersion dep) { + final ModInfo self = (ModInfo) dep.getOwner(); final IModInfo targetModInfo = modIdNameLookup.get(dep.getModId()); // soft dep that doesn't exist. Just return. No edge required. - if (targetModInfo == null || !(targetModInfo.getOwningFile() instanceof final ModFileInfo target)) return; + if (!(targetModInfo instanceof ModInfo target)) return; if (self == target) return; // in case a jar has two mods that have dependencies between switch (dep.getOrdering()) { @@ -167,11 +166,9 @@ private void buildUniqueList() { detectSystemMods(uniqueModListData.modFilesByFirstId()); - modIdNameLookup = uniqueModListData.modFilesByFirstId().entrySet().stream() - .filter(e -> !e.getValue().get(0).getModInfos().isEmpty()) - .collect(Collectors.toMap( - Map.Entry::getKey, - e -> e.getValue().get(0).getModInfos().get(0))); + modIdNameLookup = uniqueModListData.modFiles().stream() + .flatMap(mf -> mf.getModInfos().stream()) + .collect(Collectors.toMap(IModInfo::getModId, mi -> mi)); } private void detectSystemMods(final Map> modFilesByFirstId) {