diff --git a/core/src/main/java/org/apache/spark/util/collection/GuavaOrderingSnippet.java b/core/src/main/java/org/apache/spark/util/collection/GuavaOrderingSnippet.java new file mode 100644 index 0000000000000..e7585e086b290 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/GuavaOrderingSnippet.java @@ -0,0 +1,328 @@ +package org.apache.spark.util.collection; + +import com.google.common.collect.Lists; +import com.google.common.collect.Ordering; +import com.google.common.math.IntMath; + +import javax.annotation.CheckForNull; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static org.apache.spark.util.collection.GuavaOrderingSnippet.NullnessCasts.uncheckedCastNullableTToT; + +/** + * Code snippet of {@code Ordering} from guava-32.1.3, kept here to resolve a performance issue in the {@code leastOf} method. + * + *

See AL-9236 for more details. + */ +public abstract class GuavaOrderingSnippet implements Comparator { + + /** + * Returns the {@code k} least elements of the given iterable according to this ordering, in order + * from least to greatest. If there are fewer than {@code k} elements present, all will be + * included. + * + *

The implementation does not necessarily use a stable sorting algorithm; when multiple + * elements are equivalent, it is undefined which will come first. + * + *

Java 8 users: Use {@code Streams.stream(iterable).collect(Comparators.least(k, + * thisComparator))} instead. + * + * @return an immutable {@code RandomAccess} list of the {@code k} least elements in ascending + * order + * @throws IllegalArgumentException if {@code k} is negative + * @since 8.0 + */ + public List leastOf(Iterable iterable, int k) { + if (iterable instanceof Collection) { + Collection collection = (Collection) iterable; + if (collection.size() <= 2L * k) { + // In this case, just dumping the collection to an array and sorting is + // faster than using the implementation for Iterator, which is + // specialized for k much smaller than n. + + @SuppressWarnings("unchecked") // collection only contains E's and doesn't escape + E[] array = (E[]) collection.toArray(); + Arrays.sort(array, this); + if (array.length > k) { + array = Arrays.copyOf(array, k); + } + return Collections.unmodifiableList(Arrays.asList(array)); + } + } + return leastOf(iterable.iterator(), k); + } + + /** + * Returns the {@code k} least elements from the given iterator according to this ordering, in + * order from least to greatest. If there are fewer than {@code k} elements present, all will be + * included. + * + *

The implementation does not necessarily use a stable sorting algorithm; when multiple + * elements are equivalent, it is undefined which will come first. + * + *

Java 8 users: Use {@code Streams.stream(iterator).collect(Comparators.least(k, + * thisComparator))} instead. + * + * @return an immutable {@code RandomAccess} list of the {@code k} least elements in ascending + * order + * @throws IllegalArgumentException if {@code k} is negative + * @since 14.0 + */ + public List leastOf(Iterator iterator, int k) { + checkNotNull(iterator); + checkNonnegative(k, "k"); + + if (k == 0 || !iterator.hasNext()) { + return Collections.emptyList(); + } else if (k >= Integer.MAX_VALUE / 2) { + // k is really large; just do a straightforward sorted-copy-and-sublist + ArrayList list = Lists.newArrayList(iterator); + Collections.sort(list, this); + if (list.size() > k) { + list.subList(k, list.size()).clear(); + } + list.trimToSize(); + return Collections.unmodifiableList(list); + } else { + TopKSelector selector = TopKSelector.least(k, this); + selector.offerAll(iterator); + return selector.topK(); + } + } + + static int checkNonnegative(int value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " cannot be negative but was: " + value); + } + return value; + } + + static final class TopKSelector { + + /** + * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, + * relative to the specified comparator, and returns them via {@link #topK} in ascending order. + * + * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} + */ + public static TopKSelector least( + int k, Comparator comparator) { + return new TopKSelector(comparator, k); + } + + private final int k; + private final Comparator comparator; + + /* + * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates + * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the + * range [0, k) and ignore the remaining elements. 
+ */ + private final T[] buffer; + private int bufferSize; + + /** + * The largest of the lowest k elements we've seen so far relative to this comparator. If + * bufferSize ≥ k, then we can ignore any elements greater than this value. + */ + @CheckForNull + private T threshold; + + private TopKSelector(Comparator comparator, int k) { + this.comparator = checkNotNull(comparator, "comparator"); + this.k = k; + checkArgument(k >= 0, "k (%s) must be >= 0", k); + checkArgument(k <= Integer.MAX_VALUE / 2, "k (%s) must be <= Integer.MAX_VALUE / 2", k); + this.buffer = (T[]) new Object[IntMath.checkedMultiply(k, 2)]; + this.bufferSize = 0; + this.threshold = null; + } + + /** + * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized + * O(1) time. + */ + public void offer(T elem) { + if (k == 0) { + return; + } else if (bufferSize == 0) { + buffer[0] = elem; + threshold = elem; + bufferSize = 1; + } else if (bufferSize < k) { + buffer[bufferSize++] = elem; + // uncheckedCastNullableTToT is safe because bufferSize > 0. + if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) > 0) { + threshold = elem; + } + // uncheckedCastNullableTToT is safe because bufferSize > 0. + } else if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) < 0) { + // Otherwise, we can ignore elem; we've seen k better elements. + buffer[bufferSize++] = elem; + if (bufferSize == 2 * k) { + trim(); + } + } + } + + /** + * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log + * k) worst case. Slots at index k and above are not cleared here; only bufferSize is reset to k. + */ + private void trim() { + int left = 0; + int right = 2 * k - 1; + + int minThresholdPosition = 0; + // The leftmost position at which the greatest of the k lower elements + // -- the new value of threshold -- might be found. 
+ + int iterations = 0; + int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3; + while (left < right) { + int pivotIndex = (left + right + 1) >>> 1; + + int pivotNewIndex = partition(left, right, pivotIndex); + + if (pivotNewIndex > k) { + right = pivotNewIndex - 1; + } else if (pivotNewIndex < k) { + left = Math.max(pivotNewIndex, left + 1); + minThresholdPosition = pivotNewIndex; + } else { + break; + } + iterations++; + if (iterations >= maxIterations) { + @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts + T[] castBuffer = (T[]) buffer; + // We've already taken O(k log k), let's make sure we don't take longer than O(k log k). + Arrays.sort(castBuffer, left, right + 1, comparator); + break; + } + } + bufferSize = k; + + threshold = uncheckedCastNullableTToT(buffer[minThresholdPosition]); + for (int i = minThresholdPosition + 1; i < k; i++) { + if (comparator.compare( + uncheckedCastNullableTToT(buffer[i]), uncheckedCastNullableTToT(threshold)) + > 0) { + threshold = buffer[i]; + } + } + } + + /** + * Partitions the contents of buffer in the range [left, right] around the pivot element + * previously stored in buffer[pivotIndex]. Returns the new index of the pivot element, + * pivotNewIndex, so that everything in [left, pivotNewIndex) is less than pivotValue and + * everything in (pivotNewIndex, right] is greater than or equal to pivotValue. 
+ */ + private int partition(int left, int right, int pivotIndex) { + T pivotValue = uncheckedCastNullableTToT(buffer[pivotIndex]); + buffer[pivotIndex] = buffer[right]; + + int pivotNewIndex = left; + for (int i = left; i < right; i++) { + if (comparator.compare(uncheckedCastNullableTToT(buffer[i]), pivotValue) < 0) { + swap(pivotNewIndex, i); + pivotNewIndex++; + } + } + buffer[right] = buffer[pivotNewIndex]; + buffer[pivotNewIndex] = pivotValue; + return pivotNewIndex; + } + + private void swap(int i, int j) { + T tmp = buffer[i]; + buffer[i] = buffer[j]; + buffer[j] = tmp; + } + + TopKSelector combine(TopKSelector other) { + for (int i = 0; i < other.bufferSize; i++) { + this.offer(uncheckedCastNullableTToT(other.buffer[i])); + } + return this; + } + + /** + * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This + * operation takes amortized linear time in the length of {@code elements}. + * + *

If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer + * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case. + */ + public void offerAll(Iterable elements) { + offerAll(elements.iterator()); + } + + /** + * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This + * operation takes amortized linear time in the length of {@code elements}. The iterator is + * consumed after this operation completes. + * + *

If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer + * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case. + */ + public void offerAll(Iterator elements) { + while (elements.hasNext()) { + offer(elements.next()); + } + } + + /** + * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if + * fewer than {@code k} have been offered, in the order specified by the factory used to create + * this {@code TopKSelector}. + * + *

The returned list is an unmodifiable copy and will not be affected by further changes to + * this {@code TopKSelector}. This method returns in O(k log k) time. + */ + public List topK() { + @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts + T[] castBuffer = (T[]) buffer; + Arrays.sort(castBuffer, 0, bufferSize, comparator); + if (bufferSize > k) { + Arrays.fill(buffer, k, buffer.length, null); + bufferSize = k; + threshold = buffer[k - 1]; + } + // Up to bufferSize, all elements of buffer are real Ts (not null unless T includes null) + T[] topK = Arrays.copyOf(castBuffer, bufferSize); + // we have to support null elements, so no ImmutableList for us + return Collections.unmodifiableList(Arrays.asList(topK)); + } + } + + static final class NullnessCasts { + + @SuppressWarnings("nullness") + static T uncheckedCastNullableTToT(@CheckForNull T t) { + return t; + } + + /** + * Returns {@code null} as any type, even one that does not include {@code null}. + */ + @SuppressWarnings({"nullness", "TypeParameterUnusedInFormals", "ReturnMissingNullable"}) + // The warnings are legitimate. Each time we use this method, we document why. + static T unsafeNull() { + return null; + } + + private NullnessCasts() { + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index 8b543f1642a05..072f7d004eb0d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.collection import scala.collection.JavaConverters._ -import com.google.common.collect.{Ordering => GuavaOrdering} +import org.apache.spark.util.collection.{GuavaOrderingSnippet => GuavaOrdering} /** * Utility functions for collections.