diff --git a/pom.xml b/pom.xml
index 0f674bf..eb757e8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -209,8 +209,9 @@
<groupId>org.apache.spark</groupId>
- <artifactId>spark-sql_2.13</artifactId>
- <version>3.2.0</version>
+ <artifactId>spark-sql_2.12</artifactId>
+ <version>3.1.2</version>
+ <scope>provided</scope>
@@ -227,6 +228,7 @@
io/atoti/spark/condition/**/*.java
io/atoti/spark/aggregation/**/*.java
io/atoti/spark/join/**/*.java
+ io/atoti/spark/operation/**/*.java
io/atoti/spark/ListQuery.java
@@ -257,6 +259,7 @@
io/atoti/spark/condition/**/*.java
io/atoti/spark/aggregation/**/*.java
io/atoti/spark/join/**/*.java
+ io/atoti/spark/operation/**/*.java
io/atoti/spark/ListQuery.java
diff --git a/src/main/java/io/atoti/spark/AggregateQuery.java b/src/main/java/io/atoti/spark/AggregateQuery.java
index 496614a..cf08871 100644
--- a/src/main/java/io/atoti/spark/AggregateQuery.java
+++ b/src/main/java/io/atoti/spark/AggregateQuery.java
@@ -3,8 +3,11 @@
import io.atoti.spark.aggregation.AggregatedValue;
import io.atoti.spark.condition.QueryCondition;
import io.atoti.spark.condition.TrueCondition;
+import io.atoti.spark.operation.Operation;
import java.util.Arrays;
import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
@@ -29,27 +32,81 @@ public static Dataset<Row> aggregate(
List<String> groupByColumns,
List<AggregatedValue> aggregations,
QueryCondition condition) {
- if (aggregations.isEmpty()) {
+ return aggregate(dataframe, groupByColumns, aggregations, Arrays.asList(), condition);
+ }
+
+ public static Dataset<Row> aggregate(
+ Dataset<Row> dataframe,
+ List<String> groupByColumns,
+ List<AggregatedValue> aggregations,
+ List<Operation> operations) {
+ return aggregate(dataframe, groupByColumns, aggregations, operations, TrueCondition.value());
+ }
+
+ public static Dataset<Row> aggregate(
+ Dataset<Row> dataframe,
+ List<String> groupByColumns,
+ List<AggregatedValue> aggregations,
+ List<Operation> operations,
+ QueryCondition condition) {
+ if (aggregations.isEmpty() && operations.isEmpty()) {
throw new IllegalArgumentException(
- "#aggregate can only be called with at least one AggregatedValue");
+ "#aggregate can only be called with at least one AggregatedValue or Operation");
}
final Column[] columns = groupByColumns.stream().map(functions::col).toArray(Column[]::new);
- final Column[] createdColumns =
+ final Column[] createdAggregatedColumns =
aggregations.stream().map(AggregatedValue::toColumn).toArray(Column[]::new);
- final Column[] columnsToSelect = Arrays.copyOf(columns, columns.length + createdColumns.length);
- System.arraycopy(createdColumns, 0, columnsToSelect, columns.length, createdColumns.length);
+ final Column[] createdOperationColumns =
+ operations.stream().map(Operation::toColumn).toArray(Column[]::new);
+ final Column[] columnsToSelect =
+ Arrays.copyOf(
+ columns,
+ columns.length + createdAggregatedColumns.length + createdOperationColumns.length);
+ System.arraycopy(
+ createdAggregatedColumns,
+ 0,
+ columnsToSelect,
+ columns.length,
+ createdAggregatedColumns.length);
+ System.arraycopy(
+ createdOperationColumns,
+ 0,
+ columnsToSelect,
+ columns.length + createdAggregatedColumns.length,
+ createdOperationColumns.length);
+
+ // Add needed aggregations for operations to the `aggregations` list
+ final List<AggregatedValue> neededAggregations =
+ Stream.concat(
+ operations.stream().flatMap(Operation::getNeededAggregations),
+ aggregations.stream())
+ .distinct()
+ .collect(Collectors.toList());
+
+ // Aggregations
+ if (!neededAggregations.isEmpty()) {
+ final Column firstAggColumn = neededAggregations.get(0).toAggregateColumn();
+ final Column[] nextAggColumns =
+ neededAggregations.subList(1, neededAggregations.size()).stream()
+ .map(AggregatedValue::toAggregateColumn)
+ .toArray(Column[]::new);
+ dataframe =
+ dataframe
+ .filter(condition.getCondition())
+ .groupBy(columns)
+ .agg(firstAggColumn, nextAggColumns);
+ }
+
+ // Operations
+ if (!operations.isEmpty()) {
+ for (Operation op :
+ operations.stream().flatMap(Operation::getAllOperations).distinct().collect(Collectors.toList())) {
+ dataframe = dataframe.withColumn(op.getName(), op.toAggregateColumn());
+ }
+ }
- final Column firstAggColumn = aggregations.get(0).toAggregateColumn();
- final Column[] nextAggColumns =
- aggregations.subList(1, aggregations.size()).stream()
- .map(AggregatedValue::toAggregateColumn)
- .toArray(Column[]::new);
- return dataframe
- .filter(condition.getCondition())
- .groupBy(columns)
- .agg(firstAggColumn, nextAggColumns)
- .select(columnsToSelect);
+ return dataframe.select(columnsToSelect);
}
public static Dataset aggregateSql(
diff --git a/src/main/java/io/atoti/spark/CsvReader.java b/src/main/java/io/atoti/spark/CsvReader.java
index 0eacb95..6dc41ca 100644
--- a/src/main/java/io/atoti/spark/CsvReader.java
+++ b/src/main/java/io/atoti/spark/CsvReader.java
@@ -6,8 +6,11 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
public class CsvReader {
+
public static Dataset read(String path, SparkSession session) {
return CsvReader.read(path, session, ";");
}
@@ -27,6 +30,20 @@ public static Dataset<Row> read(String path, SparkSession session, String separa
.option("timestampFormat", "dd/MM/yyyy")
.option("inferSchema", true)
.load(url.toURI().getPath());
+ for (int i = 0; i < dataframe.columns().length; i++) {
+ if (dataframe.dtypes()[i]._2.equals(DataTypes.StringType.toString())) {
+ String col = dataframe.columns()[i];
+ if (dataframe
+ .filter(functions.col(col).$eq$eq$eq("").$bar$bar(functions.col(col).rlike(",")))
+ .count()
+ == dataframe.count()) {
+ // This prototype only supports arrays of integers
+ dataframe =
+ dataframe.withColumn(
+ col, functions.split(functions.col(col), ",").cast("array<long>"));
+ }
+ }
+ }
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Failed to read csv " + path, e);
}
diff --git a/src/main/java/io/atoti/spark/Discovery.java b/src/main/java/io/atoti/spark/Discovery.java
index bd90321..e670012 100644
--- a/src/main/java/io/atoti/spark/Discovery.java
+++ b/src/main/java/io/atoti/spark/Discovery.java
@@ -30,6 +30,7 @@ private Discovery() {}
DataTypes.ShortType,
DataTypes.StringType,
DataTypes.TimestampType,
+ DataTypes.createArrayType(DataTypes.LongType),
})
.collect(Collectors.toMap(DataType::toString, Function.identity()));
diff --git a/src/main/java/io/atoti/spark/Main.java b/src/main/java/io/atoti/spark/Main.java
new file mode 100644
index 0000000..3f2f9f5
--- /dev/null
+++ b/src/main/java/io/atoti/spark/Main.java
@@ -0,0 +1,36 @@
+package io.atoti.spark;
+
+import io.atoti.spark.aggregation.SumArray;
+
+import java.util.Arrays;
+import java.util.List;
+
+import io.atoti.spark.operation.Quantile;
+import io.github.cdimascio.dotenv.Dotenv;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+
+public class Main {
+
+ static Dotenv dotenv = Dotenv.load();
+ static SparkSession spark =
+ SparkSession.builder()
+ .appName("Spark Atoti")
+ .config("spark.master", "local")
+ .config("spark.databricks.service.clusterId", dotenv.get("clusterId"))
+ .getOrCreate();
+
+ public static void main(String[] args) {
+ spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+ final Dataset dataframe = spark.read().table("array");
+ SumArray price_simulations = new SumArray(
+ "price_simulations_sum", "price_simulations"
+ );
+ Quantile quantile = new Quantile("quantile", price_simulations, 95f);
+ List rows = AggregateQuery.aggregate(
+ dataframe, Arrays.asList("id"), Arrays.asList(price_simulations), Arrays.asList(quantile))
+ .collectAsList();
+ System.out.println(rows);
+ }
+}
diff --git a/src/main/java/io/atoti/spark/Utils.java b/src/main/java/io/atoti/spark/Utils.java
new file mode 100644
index 0000000..f79aa11
--- /dev/null
+++ b/src/main/java/io/atoti/spark/Utils.java
@@ -0,0 +1,83 @@
+package io.atoti.spark;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+class ArrayElement {
+ int index;
+ long value;
+
+ public ArrayElement(int index, long value) {
+ this.index = index;
+ this.value = value;
+ }
+}
+
+public class Utils {
+ public static ArrayList<Long> convertScalaArrayToArray(Seq<Long> arr) {
+ return new ArrayList<>(JavaConverters.asJavaCollectionConverter(arr).asJavaCollection());
+ }
+
+ public static Seq<Long> convertToArrayListToScalaArraySeq(List<Long> arr) {
+ return JavaConverters.asScalaBuffer(arr).iterator().toSeq();
+ }
+
+ public static long t(ArrayElement a, ArrayElement b) {
+ return b.value - a.value;
+ }
+
+ public static PriorityQueue<ArrayElement> constructMaxHeap(ArrayList<Long> arr) {
+ PriorityQueue<ArrayElement> pq =
+ new PriorityQueue<>(
+ (ArrayElement a, ArrayElement b) -> Long.compare(a.value, b.value));
+ pq.addAll(
+ IntStream.range(0, arr.size())
+ .mapToObj((int k) -> new ArrayElement(k, arr.get(k)))
+ .collect(Collectors.toList()));
+ return pq;
+ }
+
+ public static long quantile(ArrayList<Long> arr, float percent) {
+ PriorityQueue<ArrayElement> pq = constructMaxHeap(arr);
+ int index = (int) Math.floor(arr.size() * (100 - percent) / 100);
+
+ for (int i = arr.size() - 1; i > index; i--) {
+ pq.poll();
+ }
+
+ return pq.poll().value;
+ }
+
+ public static int quantileIndex(ArrayList<Long> arr, float percent) {
+ PriorityQueue<ArrayElement> pq = constructMaxHeap(arr);
+ int index = (int) Math.floor(arr.size() * (100 - percent) / 100);
+
+ for (int i = arr.size() - 1; i > index; i--) {
+ pq.poll();
+ }
+
+ return pq.poll().index;
+ }
+
+ public static long findKthLargestElement(ArrayList<Long> arr, int k) {
+ if (k > arr.size()) {
+ throw new ArrayIndexOutOfBoundsException();
+ }
+
+ PriorityQueue<Long> pq = new PriorityQueue<>(Comparator.reverseOrder());
+
+ pq.addAll(arr);
+
+ for (int i = 0; i < k; i++) {
+ pq.poll();
+ }
+
+ return pq.peek();
+ }
+}
diff --git a/src/main/java/io/atoti/spark/aggregation/SumArray.java b/src/main/java/io/atoti/spark/aggregation/SumArray.java
new file mode 100644
index 0000000..9bc0b8b
--- /dev/null
+++ b/src/main/java/io/atoti/spark/aggregation/SumArray.java
@@ -0,0 +1,136 @@
+package io.atoti.spark.aggregation;
+
+import static org.apache.spark.sql.functions.col;
+
+import io.atoti.spark.Utils;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+import java.util.stream.IntStream;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.collection.IndexedSeq;
+
+public final class SumArray implements AggregatedValue, Serializable {
+ private static final long serialVersionUID = 8932076027241294986L;
+
+ public String name;
+ public String column;
+ private Aggregator<Row, ?, ?> udaf;
+
+ private static long[] sum(long[] a, long[] b) {
+ if (a.length == 0) {
+ return b;
+ }
+ if (b.length == 0) {
+ return a;
+ }
+ if (a.length != b.length) {
+ throw new UnsupportedOperationException("Cannot sum arrays of different size");
+ }
+ return IntStream.range(0, a.length).mapToLong((int i) -> a[i] + b[i]).toArray();
+ }
+
+ private static long[] sum(long[] buffer, IndexedSeq<Long> value) {
+ if (buffer.length == 0) {
+ return Utils.convertScalaArrayToArray(value).stream().mapToLong(Long::longValue).toArray();
+ }
+ if (value.length() == 0) {
+ return buffer;
+ }
+ if (buffer.length != value.length()) {
+ throw new UnsupportedOperationException("Cannot sum arrays of different size");
+ }
+ return IntStream.range(0, buffer.length).mapToLong((int i) -> buffer[i] + value.apply(i)).toArray();
+ }
+
+ public SumArray(String name, String column) {
+ Objects.requireNonNull(name, "No name provided");
+ Objects.requireNonNull(column, "No column provided");
+ this.name = name;
+ this.column = column;
+ this.udaf =
+ new Aggregator<Row, long[], long[]>() {
+ private static final long serialVersionUID = -6760989932234595260L;
+
+ @Override
+ public Encoder<long[]> bufferEncoder() {
+ return SparkSession.active().implicits().newLongArrayEncoder();
+ }
+
+ @Override
+ public long[] finish(long[] reduction) {
+ return reduction;
+ }
+
+ @Override
+ public long[] merge(long[] b1, long[] b2) {
+ return sum(b1, b2);
+ }
+
+ @Override
+ public Encoder<long[]> outputEncoder() {
+ return SparkSession.active().implicits().newLongArrayEncoder();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public long[] reduce(long[] buffer, Row row) {
+ IndexedSeq<Long> arraySeq;
+ try {
+ arraySeq = row.getAs(column);
+ } catch (ClassCastException e) {
+ throw new UnsupportedOperationException("Column did not contains only arrays", e);
+ }
+ System.err.println("[coucou] received: " + arraySeq);
+ final long[] result = sum(buffer, arraySeq);
+ System.err.println("[coucou] result: " + Arrays.toString(result));
+ return result;
+ }
+
+ @Override
+ public long[] zero() {
+ return new long[0];
+ }
+ };
+ }
+
+ public Column toAggregateColumn() {
+ return udaf.toColumn().as(this.name);
+ }
+
+ public Column toColumn() {
+ return col(this.name);
+ }
+
+ public String toSqlQuery() {
+ throw new UnsupportedOperationException("TODO");
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+
+ if (obj.getClass() != this.getClass()) {
+ return false;
+ }
+
+ final SumArray sumArray = (SumArray) obj;
+ return sumArray.name.equals(this.name) && sumArray.column.equals(this.column);
+ }
+
+ @Override
+ public String toString() {
+ return name + " | " + column;
+ }
+
+ @Override
+ public int hashCode() {
+ return this.toString().hashCode();
+ }
+}
diff --git a/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java b/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java
new file mode 100644
index 0000000..4c08a59
--- /dev/null
+++ b/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java
@@ -0,0 +1,104 @@
+package io.atoti.spark.aggregation;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.collection.mutable.IndexedSeq;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import static org.apache.spark.sql.functions.col;
+
+public final class SumArrayLength implements AggregatedValue, Serializable {
+ private static final long serialVersionUID = 20220330_0933L;
+
+ public String name;
+ public String column;
+ private Aggregator<Row, Long, Long> udaf;
+
+ public SumArrayLength(final String name, final String column) {
+ Objects.requireNonNull(name, "No name provided");
+ Objects.requireNonNull(column, "No column provided");
+ this.name = name;
+ this.column = column;
+ this.udaf =
+ new Aggregator<Row, Long, Long>() {
+ private static final long serialVersionUID = 20220330_1005L;
+
+ @Override
+ public Encoder<Long> bufferEncoder() {
+ return Encoders.LONG();
+ }
+
+ @Override
+ public Long finish(final Long reduction) {
+ return reduction;
+ }
+
+ @Override
+ public Long merge(final Long a, final Long b) {
+ return a + b;
+ }
+
+ @Override
+ public Encoder<Long> outputEncoder() {
+ return Encoders.LONG();
+ }
+
+ @Override
+ public Long reduce(final Long result, final Row row) {
+ final IndexedSeq<Long> arraySeq;
+ try {
+ arraySeq = row.getAs(column);
+ } catch (final ClassCastException e) {
+ throw new UnsupportedOperationException("Column did not contains only arrays", e);
+ }
+ return result + arraySeq.length();
+ }
+
+ @Override
+ public Long zero() {
+ return 0L;
+ }
+ };
+ }
+
+ public Column toAggregateColumn() {
+ return this.udaf.toColumn().as(this.name);
+ }
+
+ public Column toColumn() {
+ return col(this.name);
+ }
+
+ public String toSqlQuery() {
+ throw new UnsupportedOperationException("TODO");
+ }
+
+ @Override
+ public boolean equals(final Object obj) {
+ if (obj == null) {
+ return false;
+ }
+
+ if (obj.getClass() != this.getClass()) {
+ return false;
+ }
+
+ final SumArrayLength sumArray = (SumArrayLength) obj;
+ return sumArray.name.equals(this.name) && sumArray.column.equals(this.column);
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + "[" + this.name + " | " + this.column + "]";
+ }
+
+ @Override
+ public int hashCode() {
+ return this.toString().hashCode();
+ }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Multiply.java b/src/main/java/io/atoti/spark/operation/Multiply.java
new file mode 100644
index 0000000..309bcb0
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Multiply.java
@@ -0,0 +1,86 @@
+package io.atoti.spark.operation;
+
+import static io.atoti.spark.Utils.convertScalaArrayToArray;
+import static io.atoti.spark.Utils.convertToArrayListToScalaArraySeq;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.collection.Seq;
+
+public final class Multiply extends Operation {
+
+ private static UserDefinedFunction udf =
+ udf(
+ (Long x, Seq<Long> s) -> {
+ ArrayList<Long> list = convertScalaArrayToArray(s);
+ return convertToArrayListToScalaArraySeq(
+ list.stream().map((Long value) -> value * x).collect(Collectors.toList()));
+ },
+ DataTypes.createArrayType(DataTypes.LongType));
+
+ public Multiply(String name, String scalarColumn, String arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(col(scalarColumn), col(arrayColumn)).alias(name);
+ }
+
+ public Multiply(String name, String scalarColumn, AggregatedValue arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(col(scalarColumn), arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(arrayColumn);
+ }
+
+ public Multiply(String name, String scalarColumn, Operation arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(col(scalarColumn), arrayColumn.toColumn()).alias(name);
+ this.neededOperations = Arrays.asList(arrayColumn);
+ }
+
+ public Multiply(String name, AggregatedValue scalarColumn, String arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), col(arrayColumn)).alias(name);
+ this.neededAggregations = Arrays.asList(scalarColumn);
+ }
+
+ public Multiply(String name, AggregatedValue scalarColumn, AggregatedValue arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(scalarColumn, arrayColumn);
+ }
+
+ public Multiply(String name, AggregatedValue scalarColumn, Operation arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(scalarColumn);
+ this.neededOperations = Arrays.asList(arrayColumn);
+ }
+
+ public Multiply(String name, Operation scalarColumn, String arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), col(arrayColumn)).alias(name);
+ this.neededOperations = Arrays.asList(scalarColumn);
+ }
+
+ public Multiply(String name, Operation scalarColumn, AggregatedValue arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(arrayColumn);
+ this.neededOperations = Arrays.asList(scalarColumn);
+ }
+
+ public Multiply(String name, Operation scalarColumn, Operation arrayColumn) {
+ super(name);
+ this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+ this.neededOperations = Arrays.asList(scalarColumn, arrayColumn);
+ }
+
+ public Column toAggregateColumn() {
+ return this.column;
+ }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Operation.java b/src/main/java/io/atoti/spark/operation/Operation.java
new file mode 100644
index 0000000..566d347
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Operation.java
@@ -0,0 +1,75 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.spark.sql.Column;
+
+public abstract class Operation implements Serializable {
+
+ private static final long serialVersionUID = 8932076027291294986L;
+ protected String name;
+ protected Column column;
+ protected List<AggregatedValue> neededAggregations;
+ protected List<Operation> neededOperations;
+
+ public Operation(String name) {
+ this.name = name;
+ this.neededAggregations = Arrays.asList();
+ this.neededOperations = Arrays.asList();
+ }
+
+ public Column toAggregateColumn() {
+ return this.column;
+ }
+
+ public Column toColumn() {
+ return col(name);
+ }
+
+ public String getName() {
+ return this.name;
+ }
+
+ public Stream<AggregatedValue> getNeededAggregations() {
+ return Stream.concat(
+ this.neededAggregations.stream(),
+ this.neededOperations.stream().flatMap(Operation::getNeededAggregations));
+ }
+
+ public Stream<Operation> getAllOperations() {
+ return Stream.concat(
+ this.neededOperations.stream().flatMap(Operation::getAllOperations), Stream.of(this));
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null) {
+ return false;
+ }
+
+ if (obj.getClass() != this.getClass()) {
+ return false;
+ }
+
+ final Operation op = (Operation) obj;
+ return op.name.equals(this.name)
+ && op.neededAggregations.equals(this.neededAggregations)
+ && op.neededOperations.equals(this.neededOperations);
+ }
+
+ @Override
+ public String toString() {
+ return name + " | " + neededAggregations + " | " + neededOperations;
+ }
+
+ @Override
+ public int hashCode() {
+ return this.toString().hashCode();
+ }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Quantile.java b/src/main/java/io/atoti/spark/operation/Quantile.java
new file mode 100644
index 0000000..0eb2cae
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Quantile.java
@@ -0,0 +1,43 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.Utils;
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public final class Quantile extends Operation {
+
+ static UserDefinedFunction quantileUdf(float percent) {
+ return udf(
+ (Seq<Long> arr) -> {
+ ArrayList<Long> javaArr = Utils.convertScalaArrayToArray(arr);
+ return Utils.quantile(javaArr, percent);
+ },
+ DataTypes.LongType);
+ }
+
+ public Quantile(String name, String arrayColumn, float percent) {
+ super(name);
+ this.column = quantileUdf(percent).apply(col(arrayColumn)).alias(name);
+ }
+
+ public Quantile(String name, AggregatedValue arrayColumn, float percent) {
+ super(name);
+ this.column = quantileUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(arrayColumn);
+ }
+
+ public Quantile(String name, Operation arrayColumn, float percent) {
+ super(name);
+ this.column = quantileUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+ this.neededOperations = Arrays.asList(arrayColumn);
+ }
+}
diff --git a/src/main/java/io/atoti/spark/operation/QuantileIndex.java b/src/main/java/io/atoti/spark/operation/QuantileIndex.java
new file mode 100644
index 0000000..0507453
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/QuantileIndex.java
@@ -0,0 +1,57 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.Utils;
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import org.apache.spark.sql.api.java.UDF1;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.Serializable;
+import scala.collection.Seq;
+
+public final class QuantileIndex extends Operation {
+
+ static UserDefinedFunction quantileIndexUdf(float percent) {
+ return udf(
+ new QuantileUdf(percent),
+ DataTypes.IntegerType);
+ }
+
+ public QuantileIndex(String name, String arrayColumn, float percent) {
+ super(name);
+ this.column = quantileIndexUdf(percent).apply(col(arrayColumn)).alias(name);
+ }
+
+ public QuantileIndex(String name, AggregatedValue arrayColumn, float percent) {
+ super(name);
+ this.column = quantileIndexUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+ this.neededAggregations = Arrays.asList(arrayColumn);
+ }
+
+ public QuantileIndex(String name, Operation arrayColumn, float percent) {
+ super(name);
+ this.column = quantileIndexUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+ this.neededOperations = Arrays.asList(arrayColumn);
+ }
+
+ private static class QuantileUdf implements UDF1<Seq<Long>, Integer>, Serializable {
+ private static final long serialVersionUID = 20220330_1217L;
+
+ private final float percent;
+
+ public QuantileUdf(float percent) {
+ this.percent = percent;
+ }
+
+ @Override
+ public Integer call(Seq<Long> arr) {
+ ArrayList<Long> javaArr = Utils.convertScalaArrayToArray(arr);
+ return Utils.quantileIndex(javaArr, percent) + 1;
+ }
+ }
+}
diff --git a/src/main/java/io/atoti/spark/operation/VectorAt.java b/src/main/java/io/atoti/spark/operation/VectorAt.java
new file mode 100644
index 0000000..37e6f04
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/VectorAt.java
@@ -0,0 +1,29 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.element_at;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.Arrays;
+import java.util.List;
+
+public final class VectorAt extends Operation {
+
+ public VectorAt(String name, String arrayColumn, int position) {
+ super(name);
+ this.column = element_at(col(arrayColumn), position).alias(name);
+ }
+
+ public VectorAt(String name, AggregatedValue arrayColumn, int position) {
+ super(name);
+ this.column = element_at(arrayColumn.toColumn(), position).alias(name);
+ this.neededAggregations = Arrays.asList(arrayColumn);
+ }
+
+ public VectorAt(String name, Operation arrayColumn, int position) {
+ super(name);
+ this.column = element_at(arrayColumn.toColumn(), position).alias(name);
+ this.neededOperations = Arrays.asList(arrayColumn);
+ }
+}
diff --git a/src/main/resources/csv/array.csv b/src/main/resources/csv/array.csv
new file mode 100644
index 0000000..7deb9b3
--- /dev/null
+++ b/src/main/resources/csv/array.csv
@@ -0,0 +1,3 @@
+id;quantity;price;price_simulations
+1;10;4.98;5,7,1
+2;20;1.95;2,3,1
diff --git a/src/test/java/io/atoti/spark/TestDiscovery.java b/src/test/java/io/atoti/spark/TestDiscovery.java
index 53b8fe7..4490419 100644
--- a/src/test/java/io/atoti/spark/TestDiscovery.java
+++ b/src/test/java/io/atoti/spark/TestDiscovery.java
@@ -112,4 +112,15 @@ void testDiscoveryCalculate() {
assertTrue(dTypes.containsKey("val1_equal_val2"));
assertThat(dTypes.get("val1_equal_val2")).isEqualTo(DataTypes.BooleanType);
}
+
+ @Test
+ void testDiscoveryArray() {
+ Dataset dataframe = spark.read().table("array");
+ final Map dTypes = Discovery.discoverDataframe(dataframe);
+
+ assertThat(dataframe).isNotNull();
+ assertThat(dTypes).isNotNull();
+ assertThat(dTypes.get("price_simulations"))
+ .isEqualTo(DataTypes.createArrayType(DataTypes.LongType));
+ }
}
diff --git a/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java b/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java
new file mode 100644
index 0000000..696b0bd
--- /dev/null
+++ b/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java
@@ -0,0 +1,66 @@
+/*
+ * (C) ActiveViam 2022
+ * ALL RIGHTS RESERVED. This material is the CONFIDENTIAL and PROPRIETARY
+ * property of ActiveViam. Any unauthorized use,
+ * reproduction or transfer of this material is strictly prohibited
+ */
+package io.atoti.spark;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.atoti.spark.aggregation.SumArray;
+import io.atoti.spark.aggregation.SumArrayLength;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.Test;
+
+public class TestLocalVectorAggregation {
+
+ SparkSession spark =
+ SparkSession.builder().appName("Spark Atoti").config("spark.master", "local").getOrCreate();
+
+ public TestLocalVectorAggregation() {
+ spark.sparkContext().setLogLevel("ERROR");
+ spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+ }
+
+ @Test
+ void sumVector() {
+ final Dataset dataframe = CsvReader.read("csv/array.csv", spark);
+ var price_simulations =
+ new SumArray(
+ "price_simulations_bis", "price_simulations");
+ var rows =
+ AggregateQuery.aggregate(dataframe, List.of("id"), List.of(), List.of())
+ .collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+
+ final var rowsById =
+ rows.stream().collect(Collectors.toUnmodifiableMap(row -> (row.getAs("id")), row -> (row)));
+
+ assertThat((long) rowsById.get(1).getAs("vector-at")).isEqualTo(3);
+
+ assertThat((long) rowsById.get(2).getAs("vector-at")).isEqualTo(1);
+ }
+
+ @Test
+ void testSumArrayLength() {
+ final Dataset dataframe = CsvReader.read("csv/array.csv", spark);
+ var price_simulations =
+ new SumArrayLength(
+ "price_simulations_sum", "price_simulations");
+ var rows =
+ AggregateQuery.aggregate(
+ dataframe, List.of(), List.of(price_simulations), List.of())
+ .collectAsList();
+ System.out.println(rows);
+ assertThat(rows).hasSize(1)
+ .extracting(row -> row.getLong(0))
+ .containsExactlyInAnyOrder(6L);
+ }
+}
diff --git a/src/test/java/io/atoti/spark/TestVectorAggregation.java b/src/test/java/io/atoti/spark/TestVectorAggregation.java
new file mode 100644
index 0000000..5e37b4d
--- /dev/null
+++ b/src/test/java/io/atoti/spark/TestVectorAggregation.java
@@ -0,0 +1,323 @@
+/*
+ * (C) ActiveViam 2022
+ * ALL RIGHTS RESERVED. This material is the CONFIDENTIAL and PROPRIETARY
+ * property of ActiveViam. Any unauthorized use,
+ * reproduction or transfer of this material is strictly prohibited
+ */
+package io.atoti.spark;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.atoti.spark.aggregation.Sum;
+import io.atoti.spark.aggregation.SumArray;
+import io.atoti.spark.aggregation.SumArrayLength;
+import io.atoti.spark.operation.Multiply;
+import io.atoti.spark.operation.Quantile;
+import io.atoti.spark.operation.QuantileIndex;
+import io.atoti.spark.operation.VectorAt;
+import io.github.cdimascio.dotenv.Dotenv;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.Test;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public class TestVectorAggregation {
+
+ static Dotenv dotenv = Dotenv.load();
+ static SparkSession spark =
+ SparkSession.builder()
+ .appName("Spark Atoti")
+ .config("spark.master", "local")
+ .config("spark.databricks.service.clusterId", dotenv.get("clusterId", "local-id"))
+ .getOrCreate();
+
+ public TestVectorAggregation() {
+ spark.sparkContext().setLogLevel("ERROR");
+ spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+ }
+
+ private static List<Long> convertScalaArrayToArray(Seq<Long> arr) {
+ return new ArrayList<>(JavaConverters.asJavaCollectionConverter(arr).asJavaCollection());
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ void quantile() {
+
+ final Dataset dataframe = spark.read().table("array");
+ var price_simulations =
+ new SumArray(
+ "price_simulations_sum", "price_simulations");
+ var quantile = new Quantile("quantile", price_simulations, 95f);
+ var rows =
+ AggregateQuery.aggregate(
+ dataframe, List.of("id"), List.of(price_simulations), List.of(quantile))
+ .collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+
+ final var rowsById =
+ rows.stream().collect(Collectors.toUnmodifiableMap(
+ row -> row.getAs("id").intValue(),
+ row -> (row)));
+
+ assertThat((long) rowsById.get(1).getAs("quantile")).isEqualTo(7L);
+ assertThat((long) rowsById.get(2).getAs("quantile")).isEqualTo(3L);
+ for (int i = 0; i < 3; i++) {
+ assertThat(
+ convertScalaArrayToArray(
+ rowsById.get(1).>getAs("price_simulations_sum"))
+ .get(i))
+ .isEqualTo(List.of(3L, 7L, 5L).get(i));
+ assertThat(
+ convertScalaArrayToArray(
+ rowsById.get(2).>getAs("price_simulations_sum"))
+ .get(i))
+ .isEqualTo(List.of(1L, 3L, 2L).get(i));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ void quantileIndex() {
+
+ final Dataset dataframe = spark.read().table("array");
+ var price_simulations =
+ new SumArray(
+ "price_simulations_sum", "price_simulations");
+ var quantile = new QuantileIndex("quantile", price_simulations, 95f);
+ var rows =
+ AggregateQuery.aggregate(
+ dataframe, List.of("id"), List.of(price_simulations), List.of(quantile))
+ .collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+ System.out.println(rows);
+ }
+
+ //
+ @Test
+ void vectorAt() {
+ final Dataset dataframe = spark.read().table("array");
+ var price_simulations =
+ new SumArray(
+ "price_simulations_bis", "price_simulations");
+ var vectorAt = new VectorAt("vector-at", price_simulations, 1);
+ var rows =
+ AggregateQuery.aggregate(dataframe, List.of("id"), List.of(), List.of(vectorAt))
+ .collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+
+ final var rowsById =
+ rows.stream().collect(Collectors.toUnmodifiableMap(row -> row.getAs("id").intValue(), row -> (row)));
+
+ assertThat((long) rowsById.get(1).getAs("vector-at")).isEqualTo(3);
+
+ assertThat((long) rowsById.get(2).getAs("vector-at")).isEqualTo(1);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ void simpleAggregation() {
+ final Dataset dataframe = spark.read().table("array");
+ var sumVector =
+ new SumArray("sum(vector)", "price_simulations");
+ var rows =
+ AggregateQuery.aggregate(dataframe, List.of("id"), List.of(sumVector)).collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+
+ final var rowsById =
+ rows.stream().collect(Collectors.toUnmodifiableMap(
+ row -> row.getAs("id").intValue(), row -> (row)));
+
+ for (int i = 0; i < 3; i++) {
+ assertThat(
+ convertScalaArrayToArray(rowsById.get(1).>getAs("sum(vector)"))
+ .get(i))
+ .isEqualTo(List.of(3L, 7L, 5L).get(i));
+ assertThat(
+ convertScalaArrayToArray((Seq) rowsById.get(2).getAs("sum(vector)"))
+ .get(i))
+ .isEqualTo(List.of(1L, 3L, 2L).get(i));
+ }
+ }
+
+ //
+ @SuppressWarnings("unchecked")
+ @Test
+ void vectorScaling() {
+ final Dataset dataframe = spark.read().table("array");
+ final var f_vector =
+ new Multiply(
+ "f * vector",
+ new Sum("f", "price"),
+ new SumArray(
+ "sum(vector)", "price_simulations"));
+ var rows =
+ AggregateQuery.aggregate(dataframe, List.of("id"), List.of(), List.of(f_vector))
+ .collectAsList();
+
+ // result must have 2 values
+ assertThat(rows).hasSize(2);
+
+ final var rowsById =
+ rows.stream().collect(Collectors.toUnmodifiableMap(row -> (row.getAs("id").intValue()), row -> (row)));
+
+ for (int i = 0; i < 3; i++) {
+ assertThat(
+ convertScalaArrayToArray((Seq) rowsById.get(1).getAs("f * vector")).get(i))
+ .isEqualTo(List.of(15L, 35L, 25L).get(i));
+ assertThat(
+ convertScalaArrayToArray((Seq) rowsById.get(2).getAs("f * vector")).get(i))
+ .isEqualTo(List.of(2L, 6L, 4L).get(i));
+ }
+ }
+
+ //
+ @Test
+ void vectorQuantile() {
+ final Dataset dataframe = spark.read().table("simulations");
+ var rows =
+ AggregateQuery.aggregate(
+ dataframe,
+ List.of("simulation"),
+ List.of(),
+ List.of(
+ new QuantileIndex(
+ "i95%",
+ new Multiply(
+ "f * vector",
+ new Sum("f", "factor-field"),
+ new SumArray(
+ "sum(vector)",
+ "vector-field"
+ )),
+ 95f)))
+ .collectAsList();
+
+ // result must have 3 values
+ assertThat(rows).hasSize(3);
+
+ final var rowsById =
+ rows.stream()
+ .collect(Collectors.toUnmodifiableMap(row -> (row.getAs("simulation").intValue()), row -> (row)));
+
+ assertThat((int) rowsById.get(1).getAs("i95%")).isEqualTo(7);
+
+ assertThat((int) rowsById.get(2).getAs("i95%")).isEqualTo(7);
+
+ assertThat((int) rowsById.get(3).getAs("i95%")).isEqualTo(7);
+ }
+
+ //
+ @Test
+ void simulationExplorationAtQuantile() {
+ final Dataset dataframe = spark.read().table("simulations");
+ final var revenues =
+ new Multiply(
+ "f * vector",
+ new Sum("f", "factor-field"),
+ new SumArray("sum(vector)", "vector-field"));
+ final List<Row> rows =
+ AggregateQuery.aggregate(
+ dataframe,
+ List.of("simulation"),
+ List.of(),
+ List.of(
+ new Quantile("v95%", revenues, 95f), new QuantileIndex("i95%", revenues, 95f)))
+ .collectAsList();
+ final var compareByV95 = Comparator.comparingLong(row -> (Long) (row.getAs("v95%")));
+ final int bestSimulation =
+ rows.stream()
+ .sorted(compareByV95.reversed())
+ .mapToInt(row -> (int) (row.getAs("i95%")))
+ .findFirst()
+ .orElseThrow(() -> new IllegalStateException("No data to look at"));
+ final int worstSimulation =
+ rows.stream()
+ .sorted(compareByV95)
+ .mapToInt(row -> (int) (row.getAs("i95%")))
+ .findFirst()
+ .orElseThrow(() -> new IllegalStateException("No data to look at"));
+ // Study the chosen simulations on the prices per category
+ var result =
+ AggregateQuery.aggregate(
+ dataframe,
+ List.of("simulation"),
+ List.of(),
+ List.of(
+ new VectorAt("revenue-at-best", revenues, bestSimulation),
+ new VectorAt("revenue-at-worst", revenues, worstSimulation)))
+ .collectAsList();
+
+ // result must have 3 values
+ assertThat(result).hasSize(3);
+
+ final var rowsById =
+ result.stream()
+ .collect(Collectors.toUnmodifiableMap(row -> (row.getAs("simulation").intValue()), row -> (row)));
+
+ assertThat((long) rowsById.get(1).getAs("revenue-at-best")).isEqualTo(840L);
+ assertThat((long) rowsById.get(1).getAs("revenue-at-worst")).isEqualTo(840L);
+
+ assertThat((long) rowsById.get(2).getAs("revenue-at-best")).isEqualTo(560L);
+ assertThat((long) rowsById.get(2).getAs("revenue-at-worst")).isEqualTo(560L);
+
+ assertThat((long) rowsById.get(3).getAs("revenue-at-best")).isEqualTo(690L);
+ assertThat((long) rowsById.get(3).getAs("revenue-at-worst")).isEqualTo(690L);
+ }
+
+ @Test
+ void testSumArrayLength() {
+ final Dataset dataframe = spark.read().table("array");
+ var price_simulations =
+ new SumArrayLength(
+ "price_simulations_sum", "price_simulations");
+ {
+ var totalRows =
+ AggregateQuery.aggregate(
+ dataframe, List.of(), List.of(price_simulations), List.of())
+ .collectAsList();
+ System.out.println(totalRows);
+ assertThat(totalRows).hasSize(1)
+ .extracting(row -> row.getLong(0))
+ .containsExactlyInAnyOrder(6L);
+ }
+ {
+ var rowsById =
+ AggregateQuery.aggregate(
+ dataframe, List.of("id"), List.of(price_simulations), List.of())
+ .collectAsList();
+ System.out.println(rowsById);
+ assertThat(rowsById).hasSize(2)
+ .extracting(row -> row.getLong(1))
+ .containsExactlyInAnyOrder(3L, 3L);
+ }
+ }
+
+ @Test
+ void sumVector() {
+ final Dataset dataframe = spark.read().table("array");
+ var price_simulations =
+ new SumArray(
+ "price_simulations_bis", "price_simulations");
+ var rows =
+ AggregateQuery.aggregate(dataframe, List.of("id"), List.of(price_simulations), List.of())
+ .collectAsList();
+
+ System.out.println(rows);
+ }
+}
diff --git a/src/test/resources/csv/array.csv b/src/test/resources/csv/array.csv
new file mode 100644
index 0000000..5eed92c
--- /dev/null
+++ b/src/test/resources/csv/array.csv
@@ -0,0 +1,3 @@
+id;quantity;price;price_simulations
+1;10;5;3,7,5
+2;20;2;1,3,2
diff --git a/src/test/resources/csv/simulations.csv b/src/test/resources/csv/simulations.csv
new file mode 100644
index 0000000..f4f702e
--- /dev/null
+++ b/src/test/resources/csv/simulations.csv
@@ -0,0 +1,4 @@
+simulation;factor-field;vector-field
+1;10;5,7,3,5,3,6,84
+2;20;2,3,2,4,2,3,28
+3;30;2,3,1,2,1,6,23