diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java index 726a27bed8..14d9273f88 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java @@ -22,6 +22,7 @@ import org.jetbrains.annotations.NotNull; import java.util.Iterator; +import java.util.OptionalLong; import java.util.PrimitiveIterator; public final class Messages implements Iterable { @@ -33,7 +34,12 @@ public Iterator iterator() { } public interface MessageIterator extends PrimitiveIterator.OfDouble { + boolean isEmpty(); + + default OptionalLong sender() { + return OptionalLong.empty(); + } } private final MessageIterator iterator; @@ -42,12 +48,32 @@ public interface MessageIterator extends PrimitiveIterator.OfDouble { this.iterator = iterator; } + /** + * Returns a iterator that can be used to iterate over the messages. + */ @NotNull public PrimitiveIterator.OfDouble doubleIterator() { - return iterator; + return this.iterator; } + /** + * Indicates if there are messages present. + */ public boolean isEmpty() { - return iterator.isEmpty(); + return this.iterator.isEmpty(); + } + + /** + * If the computation defined a {@link org.neo4j.gds.beta.pregel.Reducer}, this method will + * return the sender of the aggregated message. Depending on the reducer implementation, the + * sender is deterministically defined by the reducer, e.g., for Max or Min. In any other case, + * the sender will be one of the node ids that sent messages to that node. + *

+ * Note, that {@link PregelConfig#trackSender()} must return true to enable sender tracking. + * + * @return the sender of an aggregated message or an empty optional if no reducer is defined + */ + public OptionalLong sender() { + return this.iterator.sender(); } } diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java index 2c12d5f0fa..1ef3ec7198 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java @@ -19,6 +19,8 @@ */ package org.neo4j.gds.beta.pregel; +import java.util.OptionalLong; + public interface Messenger { void initIteration(int iteration); @@ -29,5 +31,9 @@ public interface Messenger { void initMessageIterator(ITERATOR messageIterator, long nodeId, boolean isFirstIteration); + default OptionalLong sender(long nodeId) { + return OptionalLong.empty(); + } + void release(); } diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java index 320704fb7e..a1bc3a4bf9 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java @@ -110,6 +110,15 @@ public static MemoryEstimation memoryEstimation( Map propertiesMap, boolean isQueueBased, boolean isAsync + ) { + return memoryEstimation(propertiesMap, isQueueBased, isAsync, false); + } + + public static MemoryEstimation memoryEstimation( + Map propertiesMap, + boolean isQueueBased, + boolean isAsync, + boolean isTrackingSender ) { var estimationBuilder = MemoryEstimations.builder(Pregel.class) .perNode("vote bits", HugeAtomicBitSet::memoryEstimation) @@ -123,7 +132,7 @@ public static MemoryEstimation memoryEstimation( estimationBuilder.add("message queues", SyncQueueMessenger.memoryEstimation()); } } else { - estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation()); + estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation(isTrackingSender)); } return estimationBuilder.build(); @@ -169,7 +178,7 @@ private Pregel( var reducer = computation.reducer(); this.messenger = reducer.isPresent() - ? new ReducingMessenger(graph, config, reducer.get()) + ? ReducingMessenger.create(graph, config, reducer.get()) : config.isAsynchronous() ? new AsyncQueueMessenger(graph.nodeCount()) : new SyncQueueMessenger(graph.nodeCount()); diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/PregelConfig.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/PregelConfig.java index 33456decb7..52a3426004 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/PregelConfig.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/PregelConfig.java @@ -47,4 +47,9 @@ default Partitioning partitioning() { default boolean useForkJoin() { return partitioning() == Partitioning.AUTO; } + + @Configuration.Ignore + default boolean trackSender() { + return false; + } } diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java index 5232aa90e2..b3404f1f4d 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java @@ -20,12 +20,15 @@ package org.neo4j.gds.beta.pregel; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray; import org.neo4j.gds.core.concurrency.ParallelUtil; -import org.neo4j.gds.termination.TerminationFlag; +import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator; import org.neo4j.gds.mem.MemoryEstimation; import org.neo4j.gds.mem.MemoryEstimations; -import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator; +import org.neo4j.gds.termination.TerminationFlag; + +import java.util.OptionalLong; /** * A messenger implementation that is backed by two double arrays used @@ -33,36 +36,54 @@ * combination with a {@link Reducer} * which atomically reduces all incoming messages into a single one. */ -public class ReducingMessenger implements Messenger { +class ReducingMessenger implements Messenger { private final Graph graph; private final PregelConfig config; - private final Reducer reducer; + final Reducer reducer; - private HugeAtomicDoubleArray sendArray; - private HugeAtomicDoubleArray receiveArray; + HugeAtomicDoubleArray sendArray; + HugeAtomicDoubleArray receiveArray; - ReducingMessenger(Graph graph, PregelConfig config, Reducer reducer) { - assert !Double.isNaN(reducer.identity()): "identity element must not be NaN"; + static ReducingMessenger create(Graph graph, PregelConfig config, Reducer reducer) { + return config.trackSender() + ? new WithSender(graph, config, reducer) + : new ReducingMessenger(graph, config, reducer); + } + + private ReducingMessenger(Graph graph, PregelConfig config, Reducer reducer) { + assert !Double.isNaN(reducer.identity()) : "identity element must not be NaN"; this.graph = graph; this.config = config; this.reducer = reducer; - this.receiveArray = HugeAtomicDoubleArray.of(graph.nodeCount(), ParallelDoublePageCreator.passThrough(config.concurrency())); - this.sendArray = HugeAtomicDoubleArray.of(graph.nodeCount(), ParallelDoublePageCreator.passThrough(config.concurrency())); + this.receiveArray = HugeAtomicDoubleArray.of( + graph.nodeCount(), + ParallelDoublePageCreator.passThrough(config.concurrency()) + ); + this.sendArray = HugeAtomicDoubleArray.of( + graph.nodeCount(), + ParallelDoublePageCreator.passThrough(config.concurrency()) + ); } - static MemoryEstimation memoryEstimation() { - return MemoryEstimations.builder(ReducingMessenger.class) + static MemoryEstimation memoryEstimation(boolean withSender) { + var builder = MemoryEstimations.builder(ReducingMessenger.class) .perNode("send array", HugeAtomicDoubleArray::memoryEstimation) - .perNode("receive array", HugeAtomicDoubleArray::memoryEstimation) + .perNode("receive array", HugeAtomicDoubleArray::memoryEstimation); + + if (withSender) { + builder + .perNode("send sender array", HugeLongArray::memoryEstimation) + .perNode("receive sender array", HugeLongArray::memoryEstimation); + } + return builder .build(); } @Override public void initIteration(int iteration) { - // Swap arrays var tmp = receiveArray; this.receiveArray = sendArray; this.sendArray = tmp; @@ -96,7 +117,7 @@ public void initMessageIterator( boolean isInitialIteration ) { var message = receiveArray.getAndReplace(nodeId, reducer.identity()); - messageIterator.init(message, message != reducer.identity()); + messageIterator.init(message, message != reducer.identity(), OptionalLong.empty()); } @Override @@ -105,14 +126,73 @@ public void release() { receiveArray.release(); } + static class WithSender extends ReducingMessenger { + private HugeLongArray sendSenderArray; + private HugeLongArray receiveSenderArray; + + WithSender(Graph graph, PregelConfig config, Reducer reducer) { + super(graph, config, reducer); + this.sendSenderArray = HugeLongArray.newArray(graph.nodeCount()); + this.receiveSenderArray = HugeLongArray.newArray(graph.nodeCount()); + } + + @Override + public void initIteration(int iteration) { + super.initIteration(iteration); + // Swap sender arrays + var tmp = receiveSenderArray; + this.receiveSenderArray = sendSenderArray; + this.sendSenderArray = tmp; + } + + @Override + public void initMessageIterator( + ReducingMessenger.SingleMessageIterator messageIterator, + long nodeId, + boolean isInitialIteration + ) { + var message = receiveArray.getAndReplace(nodeId, reducer.identity()); + var sender = receiveSenderArray.get(nodeId); + messageIterator.init(message, message != reducer.identity(), OptionalLong.of(sender)); + } + + @Override + public void sendTo(long sourceNodeId, long targetNodeId, double message) { + sendArray.update( + targetNodeId, + currentMessage -> { + var reducedMessage = reducer.reduce(currentMessage, message); + if (Double.compare(reducedMessage, currentMessage) != 0) { + sendSenderArray.set(targetNodeId, sourceNodeId); + } + return reducedMessage; + } + ); + } + + @Override + public OptionalLong sender(long nodeId) { + return OptionalLong.of(receiveSenderArray.get(nodeId)); + } + + @Override + public void release() { + sendSenderArray.release(); + receiveSenderArray.release(); + super.release(); + } + } + static class SingleMessageIterator implements Messages.MessageIterator { boolean hasNext; double message; + OptionalLong sender; - void init(double value, boolean hasNext) { + void init(double value, boolean hasNext, OptionalLong sender) { this.message = value; this.hasNext = hasNext; + this.sender = sender; } @Override @@ -130,5 +210,10 @@ public double nextDouble() { hasNext = false; return message; } + + @Override + public OptionalLong sender() { + return this.sender; + } } } diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/ComputeContext.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/ComputeContext.java index 5508d99958..227e27e49b 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/ComputeContext.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/ComputeContext.java @@ -86,6 +86,15 @@ public long longNodeValue(String key) { return nodeValue.longValue(key, nodeId); } + /** + * Returns the node value for the given node-id and node schema key. + * + * @throws IllegalArgumentException if the key does not exist or the value is not a long + */ + public long longNodeValue(String key, long nodeId) { + return nodeValue.longValue(key, nodeId); + } + /** * Returns the node value for the given node schema key. * diff --git a/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/NodeCentricContext.java b/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/NodeCentricContext.java index ebc29ed3bf..e10fcef20e 100644 --- a/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/NodeCentricContext.java +++ b/pregel/src/main/java/org/neo4j/gds/beta/pregel/context/NodeCentricContext.java @@ -25,6 +25,7 @@ import org.neo4j.gds.beta.pregel.PregelConfig; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import java.util.OptionalLong; import java.util.function.LongConsumer; public abstract class NodeCentricContext extends PregelContext { @@ -33,8 +34,11 @@ public abstract class NodeCentricContext extends Pr protected final Graph graph; + private OptionalLong sender = OptionalLong.empty(); + long nodeId; + NodeCentricContext(Graph graph, CONFIG config, NodeValue nodeValue, ProgressTracker progressTracker) { super(config, progressTracker); this.graph = graph; diff --git a/pregel/src/test/java/org/neo4j/gds/beta/pregel/PregelTest.java b/pregel/src/test/java/org/neo4j/gds/beta/pregel/PregelTest.java index 96204bf56d..185377714f 100644 --- a/pregel/src/test/java/org/neo4j/gds/beta/pregel/PregelTest.java +++ b/pregel/src/test/java/org/neo4j/gds/beta/pregel/PregelTest.java @@ -395,8 +395,24 @@ void testMasterComputeStepWithConvergence(Partitioning partitioning) { static Stream estimations() { return Stream.of( // queue based sync - Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), true, false, 7441752L), - Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), true, false, 7442256L), + Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + true, + false, + false, + 7441752L + ), + Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + true, + false, + true, // sender tracking must not make a difference + 7441752L + ), + Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + true, + false, + false, + 7442256L + ), Arguments.of(1, new PregelSchema.Builder() .add("key1", ValueType.LONG) .add("key2", ValueType.DOUBLE) @@ -405,6 +421,7 @@ static Stream estimations() { .build(), true, false, + false, 9441824L ), Arguments.of(10, new PregelSchema.Builder() @@ -415,12 +432,23 @@ static Stream estimations() { .build(), true, false, + false, 9442328L ), // queue based async - Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), true, true, 3841688L), - Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), true, true, 3842192L), + Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + true, + true, + false, + 3841688L + ), + Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + true, + true, + false, + 3842192L + ), Arguments.of(1, new PregelSchema.Builder() .add("key1", ValueType.LONG) .add("key2", ValueType.DOUBLE) @@ -429,6 +457,7 @@ static Stream estimations() { .build(), true, true, + false, 5841760L ), Arguments.of(10, new PregelSchema.Builder() @@ -439,12 +468,29 @@ static Stream estimations() { .build(), true, true, + false, 5842264L ), - // array based - Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), false, false, 241584L), - Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), false, false, 242088L), + // reducer (array) based + Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + false, + false, + false, + 241584L + ), + Arguments.of(1, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + false, + false, + true, + 401664L + ), + Arguments.of(10, new PregelSchema.Builder().add("key", ValueType.LONG).build(), + false, + false, + false, + 242088L + ), Arguments.of(1, new PregelSchema.Builder() .add("key1", ValueType.LONG) .add("key2", ValueType.DOUBLE) @@ -453,6 +499,7 @@ static Stream estimations() { .build(), false, false, + false, 2241656L ), Arguments.of(10, new PregelSchema.Builder() @@ -463,6 +510,7 @@ static Stream estimations() { .build(), false, false, + false, 2242160L ) ); @@ -538,6 +586,7 @@ void memoryEstimation( PregelSchema pregelSchema, boolean isQueueBased, boolean isAsync, + boolean isTrackingSender, long expectedBytes ) { var dimensions = ImmutableGraphDimensions.builder() @@ -548,7 +597,7 @@ void memoryEstimation( assertEquals( MemoryRange.of(expectedBytes).max, Pregel - .memoryEstimation(pregelSchema.propertiesMap(), isQueueBased, isAsync) + .memoryEstimation(pregelSchema.propertiesMap(), isQueueBased, isAsync, isTrackingSender) .estimate(dimensions, new Concurrency(concurrency)) .memoryUsage().max ); diff --git a/pregel/src/test/java/org/neo4j/gds/beta/pregel/ReducingMessengerWithSenderTest.java b/pregel/src/test/java/org/neo4j/gds/beta/pregel/ReducingMessengerWithSenderTest.java new file mode 100644 index 0000000000..d5890113be --- /dev/null +++ b/pregel/src/test/java/org/neo4j/gds/beta/pregel/ReducingMessengerWithSenderTest.java @@ -0,0 +1,160 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.beta.pregel; + +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.neo4j.gds.annotation.Configuration; +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.beta.pregel.context.ComputeContext; +import org.neo4j.gds.beta.pregel.context.InitContext; +import org.neo4j.gds.core.concurrency.DefaultPool; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; +import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.mem.MemoryEstimateDefinition; +import org.neo4j.gds.mem.MemoryEstimations; +import org.neo4j.gds.termination.TerminationFlag; + +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +@GdlExtension +@ExtendWith(SoftAssertionsExtension.class) +class ReducingMessengerWithSenderTest { + + @GdlGraph + static final String GRAPH = """ + (a)-[{w: 0.42}]->(b) + (a)-[{w: 0.12}]->(c) + (b)-[{w: 0.84}]->(c) + (c)-[{w: 0.23}]->(d) + """; + + @Inject + private TestGraph graph; + + static Stream expectedSenders() { + return Stream.of( + Arguments.of(new Reducer.Max(), Map.of( + "a", "a", + "b", "a", + "c", "b", + "d", "c" + )), + Arguments.of(new Reducer.Min(), Map.of( + "a", "a", + "b", "a", + "c", "a", + "d", "c" + )) + ); + } + + @ParameterizedTest + @MethodSource("expectedSenders") + void test(Reducer reducer, Map expectedTargets, SoftAssertions softly) { + var config = TrackingConfigImpl.builder() + .relationshipWeightProperty("w") + .maxIterations(10) + .build(); + + var result = Pregel.create( + this.graph, + config, + new TestComputation(softly, reducer), + DefaultPool.INSTANCE, + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE + ).run().nodeValues().longProperties(TestComputation.SENDER); + + expectedTargets.forEach((node, expectedSender) -> { + softly.assertThat(result.get(graph.toMappedNodeId(node))).isEqualTo(graph.toMappedNodeId(expectedSender)); + }); + } + + static class TestComputation implements PregelComputation { + + static final String SENDER = "sender"; + + private final SoftAssertions softly; + private final Reducer reducer; + + TestComputation(SoftAssertions softly, Reducer reducer) { + this.softly = softly; + this.reducer = reducer; + } + + @Override + public void init(InitContext context) { + context.setNodeValue(SENDER, context.nodeId()); + } + + @Override + public void compute(ComputeContext context, Messages messages) { + if (context.isInitialSuperstep()) { + context.sendToNeighbors(1.0); + } else { + softly.assertThat(messages.sender()).isPresent(); + context.setNodeValue(SENDER, messages.sender().orElseThrow()); + context.voteToHalt(); + } + } + + @Override + public Optional reducer() { + return Optional.of(this.reducer); + } + + @Override + public double applyRelationshipWeight(double message, double relationshipWeight) { + return message * relationshipWeight; + } + + @Override + public PregelSchema schema(TrackingConfig config) { + return new PregelSchema.Builder() + .add(SENDER, ValueType.LONG) + .build(); + } + + @Override + public MemoryEstimateDefinition estimateDefinition(boolean isAsynchronous) { + return MemoryEstimations::empty; + } + } + + @Configuration("TrackingConfigImpl") + interface TrackingConfig extends PregelConfig { + @Override + @Configuration.Ignore + default boolean trackSender() { + // this will trigger the Pregel framework to use a messenger that tracks the sender + return true; + } + } +}