diff --git a/pom.xml b/pom.xml
index 0f674bf..eb757e8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -209,8 +209,9 @@
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-sql_2.13</artifactId>
-      <version>3.2.0</version>
+      <artifactId>spark-sql_2.12</artifactId>
+      <version>3.1.2</version>
+      <scope>provided</scope>
@@ -227,6 +228,7 @@
             <include>io/atoti/spark/condition/**/*.java</include>
             <include>io/atoti/spark/aggregation/**/*.java</include>
             <include>io/atoti/spark/join/**/*.java</include>
+            <include>io/atoti/spark/operation/**/*.java</include>
             <include>io/atoti/spark/ListQuery.java</include>
@@ -257,6 +259,7 @@
             <include>io/atoti/spark/condition/**/*.java</include>
             <include>io/atoti/spark/aggregation/**/*.java</include>
             <include>io/atoti/spark/join/**/*.java</include>
+            <include>io/atoti/spark/operation/**/*.java</include>
             <include>io/atoti/spark/ListQuery.java</include>
diff --git a/src/main/java/io/atoti/spark/AggregateQuery.java b/src/main/java/io/atoti/spark/AggregateQuery.java
index 496614a..cf08871 100644
--- a/src/main/java/io/atoti/spark/AggregateQuery.java
+++ b/src/main/java/io/atoti/spark/AggregateQuery.java
@@ -3,8 +3,11 @@
 import io.atoti.spark.aggregation.AggregatedValue;
 import io.atoti.spark.condition.QueryCondition;
 import io.atoti.spark.condition.TrueCondition;
+import io.atoti.spark.operation.Operation;
 import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -29,27 +32,81 @@ public static Dataset<Row> aggregate(
       List<String> groupByColumns,
       List<AggregatedValue> aggregations,
       QueryCondition condition) {
-    if (aggregations.isEmpty()) {
+    return aggregate(dataframe, groupByColumns, aggregations, Arrays.asList(), condition);
+  }
+
+  public static Dataset<Row> aggregate(
+      Dataset<Row> dataframe,
+      List<String> groupByColumns,
+      List<AggregatedValue> aggregations,
+      List<Operation> operations) {
+    return aggregate(dataframe, groupByColumns, aggregations, operations, TrueCondition.value());
+  }
+
+  public static Dataset<Row> aggregate(
+      Dataset<Row> dataframe,
+      List<String> groupByColumns,
+      List<AggregatedValue> aggregations,
+      List<Operation> operations,
+      QueryCondition condition) {
+    if (aggregations.isEmpty() && operations.isEmpty()) {
       throw new IllegalArgumentException(
-          "#aggregate can only be called with at least one AggregatedValue");
+          "#aggregate can only be called with at least one AggregatedValue or Operation");
     }
     final Column[] columns = groupByColumns.stream().map(functions::col).toArray(Column[]::new);
-    final Column[] createdColumns =
+    final Column[] createdAggregatedColumns =
         aggregations.stream().map(AggregatedValue::toColumn).toArray(Column[]::new);
-    final Column[] columnsToSelect = Arrays.copyOf(columns, columns.length + createdColumns.length);
-    System.arraycopy(createdColumns, 0, columnsToSelect, columns.length, createdColumns.length);
+    final Column[] createdOperationColumns =
+        operations.stream().map(Operation::toColumn).toArray(Column[]::new);
+    final Column[] columnsToSelect =
+        Arrays.copyOf(
+            columns,
+            columns.length + createdAggregatedColumns.length + createdOperationColumns.length);
+    System.arraycopy(
+        createdAggregatedColumns,
+        0,
+        columnsToSelect,
+        columns.length,
+        createdAggregatedColumns.length);
+    System.arraycopy(
+        createdOperationColumns,
+        0,
+        columnsToSelect,
+        columns.length + createdAggregatedColumns.length,
+        createdOperationColumns.length);
+
+    // Add needed aggregations for operations to the `aggregations` list
+    final List<AggregatedValue> neededAggregations =
+        Stream.concat(
+                operations.stream().flatMap(Operation::getNeededAggregations),
+                aggregations.stream())
+            .distinct()
+            .collect(Collectors.toList());
+
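+    // The query is evaluated in two phases: the group-by below materialises every aggregation
+    // required by the operations, then each operation is appended as a derived column before
+    // the final projection onto `columnsToSelect`.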
+    // Aggregations
+    if (!neededAggregations.isEmpty()) {
+      final Column firstAggColumn = neededAggregations.get(0).toAggregateColumn();
+      final Column[] nextAggColumns =
+          neededAggregations.subList(1, neededAggregations.size()).stream()
+              .map(AggregatedValue::toAggregateColumn)
+              .toArray(Column[]::new);
+      dataframe =
+          dataframe
+              .filter(condition.getCondition())
+              .groupBy(columns)
+              .agg(firstAggColumn, nextAggColumns);
+    }
+
+    // Operations
+    if (!operations.isEmpty()) {
+      for (Operation op :
+          operations.stream()
+              .flatMap(Operation::getAllOperations)
+              .distinct()
+              .collect(Collectors.toList())) {
+        dataframe = dataframe.withColumn(op.getName(), op.toAggregateColumn());
+      }
+    }
-    final Column firstAggColumn = aggregations.get(0).toAggregateColumn();
-    final Column[] nextAggColumns =
-        aggregations.subList(1, aggregations.size()).stream()
-            .map(AggregatedValue::toAggregateColumn)
-            .toArray(Column[]::new);
-    return dataframe
-        .filter(condition.getCondition())
-        .groupBy(columns)
-        .agg(firstAggColumn, nextAggColumns)
-        .select(columnsToSelect);
+
+    return dataframe.select(columnsToSelect);
   }
 
   public static Dataset<Row> aggregateSql(
diff --git a/src/main/java/io/atoti/spark/CsvReader.java b/src/main/java/io/atoti/spark/CsvReader.java
index 0eacb95..6dc41ca 100644
--- a/src/main/java/io/atoti/spark/CsvReader.java
+++ b/src/main/java/io/atoti/spark/CsvReader.java
@@ -6,8 +6,11 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;
 
 public class CsvReader {
+
   public static Dataset<Row> read(String path, SparkSession session) {
     return CsvReader.read(path, session, ";");
   }
@@ -27,6 +30,20 @@ public static Dataset<Row> read(String path, SparkSession session, String separa
               .option("timestampFormat", "dd/MM/yyyy")
               .option("inferSchema", true)
               .load(url.toURI().getPath());
+      for (int i = 0; i < dataframe.columns().length; i++) {
+        if (dataframe.dtypes()[i]._2.equals(DataTypes.StringType.toString())) {
+          String col = dataframe.columns()[i];
+          if (dataframe
+                  .filter(functions.col(col).$eq$eq$eq("").$bar$bar(functions.col(col).rlike(",")))
+                  .count()
+              == dataframe.count()) {
+            // This prototype only supports arrays of integers
+            dataframe =
+                dataframe.withColumn(
+                    col, functions.split(functions.col(col), ",").cast("array<long>"));
+          }
+        }
+      }
     } catch (URISyntaxException e) {
       throw new IllegalArgumentException("Failed to read csv " + path, e);
     }
diff --git a/src/main/java/io/atoti/spark/Discovery.java b/src/main/java/io/atoti/spark/Discovery.java
index bd90321..e670012 100644
--- a/src/main/java/io/atoti/spark/Discovery.java
+++ b/src/main/java/io/atoti/spark/Discovery.java
@@ -30,6 +30,7 @@ private Discovery() {}
             DataTypes.ShortType,
             DataTypes.StringType,
             DataTypes.TimestampType,
+            DataTypes.createArrayType(DataTypes.LongType),
           })
       .collect(Collectors.toMap(DataType::toString, Function.identity()));
diff --git a/src/main/java/io/atoti/spark/Main.java b/src/main/java/io/atoti/spark/Main.java
new file mode 100644
index 0000000..3f2f9f5
--- /dev/null
+++ b/src/main/java/io/atoti/spark/Main.java
@@ -0,0 +1,36 @@
+package io.atoti.spark;
+
+import io.atoti.spark.aggregation.SumArray;
+
+import java.util.Arrays;
+import java.util.List;
+
+import io.atoti.spark.operation.Quantile;
+import io.github.cdimascio.dotenv.Dotenv;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+
+public class Main {
+
+  static Dotenv dotenv = Dotenv.load();
+  static SparkSession spark =
+      SparkSession.builder()
+          .appName("Spark Atoti")
+          .config("spark.master", "local")
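+          // Databricks Connect: the target cluster id is read from a .env file (Dotenv.load() above)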
.config("spark.databricks.service.clusterId", dotenv.get("clusterId")) + .getOrCreate(); + + public static void main(String[] args) { + spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar"); + final Dataset dataframe = spark.read().table("array"); + SumArray price_simulations = new SumArray( + "price_simulations_sum", "price_simulations" + ); + Quantile quantile = new Quantile("quantile", price_simulations, 95f); + List rows = AggregateQuery.aggregate( + dataframe, Arrays.asList("id"), Arrays.asList(price_simulations), Arrays.asList(quantile)) + .collectAsList(); + System.out.println(rows); + } +} diff --git a/src/main/java/io/atoti/spark/Utils.java b/src/main/java/io/atoti/spark/Utils.java new file mode 100644 index 0000000..f79aa11 --- /dev/null +++ b/src/main/java/io/atoti/spark/Utils.java @@ -0,0 +1,83 @@ +package io.atoti.spark; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +class ArrayElement { + int index; + long value; + + public ArrayElement(int index, long value) { + this.index = index; + this.value = value; + } +} + +public class Utils { + public static ArrayList convertScalaArrayToArray(Seq arr) { + return new ArrayList(JavaConverters.asJavaCollectionConverter(arr).asJavaCollection()); + } + + public static Seq convertToArrayListToScalaArraySeq(List arr) { + return JavaConverters.asScalaBuffer(arr).iterator().toSeq(); + } + + public static long t(ArrayElement a, ArrayElement b) { + return b.value - a.value; + } + + public static PriorityQueue constructMaxHeap(ArrayList arr) { + PriorityQueue pq = + new PriorityQueue( + (ArrayElement a, ArrayElement b) -> Long.compare(a.value, b.value)); + pq.addAll( + IntStream.range(0, arr.size()) + .mapToObj((int k) -> new ArrayElement(k, arr.get(k))) + .collect(Collectors.toList())); + return pq; + } + + public static long quantile(ArrayList arr, float percent) { + PriorityQueue pq = constructMaxHeap(arr); + int index = (int) Math.floor(arr.size() * (100 - percent) / 100); + + for (int i = arr.size() - 1; i > index; i--) { + pq.poll(); + } + + return pq.poll().value; + } + + public static int quantileIndex(ArrayList arr, float percent) { + PriorityQueue pq = constructMaxHeap(arr); + int index = (int) Math.floor(arr.size() * (100 - percent) / 100); + + for (int i = arr.size() - 1; i > index; i--) { + pq.poll(); + } + + return pq.poll().index; + } + + public static int findKthLargestElement(ArrayList arr, int k) { + if (k < arr.size()) { + throw new ArrayIndexOutOfBoundsException(); + } + + PriorityQueue pq = new PriorityQueue(Comparator.reverseOrder()); + + pq.addAll(arr); + + for (int i = 0; i < k; i++) { + pq.poll(); + } + + return pq.peek(); + } +} diff --git a/src/main/java/io/atoti/spark/aggregation/SumArray.java b/src/main/java/io/atoti/spark/aggregation/SumArray.java new file mode 100644 index 0000000..9bc0b8b --- /dev/null +++ b/src/main/java/io/atoti/spark/aggregation/SumArray.java @@ -0,0 +1,136 @@ +package io.atoti.spark.aggregation; + +import static org.apache.spark.sql.functions.col; + +import io.atoti.spark.Utils; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Objects; +import java.util.stream.IntStream; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import 
+public final class SumArray implements AggregatedValue, Serializable {
+  private static final long serialVersionUID = 8932076027241294986L;
+
+  public String name;
+  public String column;
+  private Aggregator<Row, long[], long[]> udaf;
+
+  private static long[] sum(long[] a, long[] b) {
+    if (a.length == 0) {
+      return b;
+    }
+    if (b.length == 0) {
+      return a;
+    }
+    if (a.length != b.length) {
+      throw new UnsupportedOperationException("Cannot sum arrays of different size");
+    }
+    return IntStream.range(0, a.length).mapToLong((int i) -> a[i] + b[i]).toArray();
+  }
+
+  private static long[] sum(long[] buffer, IndexedSeq<Long> value) {
+    if (buffer.length == 0) {
+      return Utils.convertScalaArrayToArray(value).stream().mapToLong(Long::longValue).toArray();
+    }
+    if (value.length() == 0) {
+      return buffer;
+    }
+    if (buffer.length != value.length()) {
+      throw new UnsupportedOperationException("Cannot sum arrays of different size");
+    }
+    return IntStream.range(0, buffer.length).mapToLong((int i) -> buffer[i] + value.apply(i)).toArray();
+  }
+
+  public SumArray(String name, String column) {
+    Objects.requireNonNull(name, "No name provided");
+    Objects.requireNonNull(column, "No column provided");
+    this.name = name;
+    this.column = column;
+    this.udaf =
+        new Aggregator<Row, long[], long[]>() {
+          private static final long serialVersionUID = -6760989932234595260L;
+
+          @Override
+          public Encoder<long[]> bufferEncoder() {
+            return SparkSession.active().implicits().newLongArrayEncoder();
+          }
+
+          @Override
+          public long[] finish(long[] reduction) {
+            return reduction;
+          }
+
+          @Override
+          public long[] merge(long[] b1, long[] b2) {
+            return sum(b1, b2);
+          }
+
+          @Override
+          public Encoder<long[]> outputEncoder() {
+            return SparkSession.active().implicits().newLongArrayEncoder();
+          }
+
+          @SuppressWarnings("unchecked")
+          @Override
+          public long[] reduce(long[] buffer, Row row) {
+            IndexedSeq<Long> arraySeq;
+            try {
+              arraySeq = row.getAs(column);
+            } catch (ClassCastException e) {
+              throw new UnsupportedOperationException("Column did not contain only arrays", e);
+            }
+            System.err.println("[debug] received: " + arraySeq);
+            final long[] result = sum(buffer, arraySeq);
+            System.err.println("[debug] result: " + Arrays.toString(result));
+            return result;
+          }
+
+          @Override
+          public long[] zero() {
+            return new long[0];
+          }
+        };
+  }
+
+  public Column toAggregateColumn() {
+    return udaf.toColumn().as(this.name);
+  }
+
+  public Column toColumn() {
+    return col(this.name);
+  }
+
+  public String toSqlQuery() {
+    throw new UnsupportedOperationException("TODO");
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == null) {
+      return false;
+    }
+
+    if (obj.getClass() != this.getClass()) {
+      return false;
+    }
+
+    final SumArray sumArray = (SumArray) obj;
+    return sumArray.name.equals(this.name) && sumArray.column.equals(this.column);
+  }
+
+  @Override
+  public String toString() {
+    return name + " | " + column;
+  }
+
+  @Override
+  public int hashCode() {
+    return this.toString().hashCode();
+  }
+}
diff --git a/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java b/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java
new file mode 100644
index 0000000..4c08a59
--- /dev/null
+++ b/src/main/java/io/atoti/spark/aggregation/SumArrayLength.java
@@ -0,0 +1,104 @@
+package io.atoti.spark.aggregation;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.collection.mutable.IndexedSeq;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import static org.apache.spark.sql.functions.col;
+
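+/** Aggregation summing the lengths of the arrays stored in a column, producing one Long per group. */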
+public final class SumArrayLength implements AggregatedValue, Serializable {
+  private static final long serialVersionUID = 20220330_0933L;
+
+  public String name;
+  public String column;
+  private Aggregator<Row, Long, Long> udaf;
+
+  public SumArrayLength(final String name, final String column) {
+    Objects.requireNonNull(name, "No name provided");
+    Objects.requireNonNull(column, "No column provided");
+    this.name = name;
+    this.column = column;
+    this.udaf =
+        new Aggregator<Row, Long, Long>() {
+          private static final long serialVersionUID = 20220330_1005L;
+
+          @Override
+          public Encoder<Long> bufferEncoder() {
+            return Encoders.LONG();
+          }
+
+          @Override
+          public Long finish(final Long reduction) {
+            return reduction;
+          }
+
+          @Override
+          public Long merge(final Long a, final Long b) {
+            return a + b;
+          }
+
+          @Override
+          public Encoder<Long> outputEncoder() {
+            return Encoders.LONG();
+          }
+
+          @Override
+          public Long reduce(final Long result, final Row row) {
+            final IndexedSeq<Long> arraySeq;
+            try {
+              arraySeq = row.getAs(column);
+            } catch (final ClassCastException e) {
+              throw new UnsupportedOperationException("Column did not contain only arrays", e);
+            }
+            return result + arraySeq.length();
+          }
+
+          @Override
+          public Long zero() {
+            return 0L;
+          }
+        };
+  }
+
+  public Column toAggregateColumn() {
+    return this.udaf.toColumn().as(this.name);
+  }
+
+  public Column toColumn() {
+    return col(this.name);
+  }
+
+  public String toSqlQuery() {
+    throw new UnsupportedOperationException("TODO");
+  }
+
+  @Override
+  public boolean equals(final Object obj) {
+    if (obj == null) {
+      return false;
+    }
+
+    if (obj.getClass() != this.getClass()) {
+      return false;
+    }
+
+    final SumArrayLength sumArray = (SumArrayLength) obj;
+    return sumArray.name.equals(this.name) && sumArray.column.equals(this.column);
+  }
+
+  @Override
+  public String toString() {
+    return getClass().getSimpleName() + "[" + this.name + " | " + this.column + "]";
+  }
+
+  @Override
+  public int hashCode() {
+    return this.toString().hashCode();
+  }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Multiply.java b/src/main/java/io/atoti/spark/operation/Multiply.java
new file mode 100644
index 0000000..309bcb0
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Multiply.java
@@ -0,0 +1,86 @@
+package io.atoti.spark.operation;
+
+import static io.atoti.spark.Utils.convertScalaArrayToArray;
+import static io.atoti.spark.Utils.convertToArrayListToScalaArraySeq;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.collection.Seq;
+
+public final class Multiply extends Operation {
+
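+  // UDF multiplying every element of the array column by the scalar value; the constructor
+  // overloads accept plain columns, AggregatedValues or Operations for either operand and
+  // register them as dependencies.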
+  private static UserDefinedFunction udf =
+      udf(
+          (Long x, Seq<Long> s) -> {
+            ArrayList<Long> list = convertScalaArrayToArray(s);
+            return convertToArrayListToScalaArraySeq(
+                list.stream().map((Long value) -> value * x).collect(Collectors.toList()));
+          },
+          DataTypes.createArrayType(DataTypes.LongType));
+
+  public Multiply(String name, String scalarColumn, String arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(col(scalarColumn), col(arrayColumn)).alias(name);
+  }
+
+  public Multiply(String name, String scalarColumn, AggregatedValue arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(col(scalarColumn), arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(arrayColumn);
+  }
+
+  public Multiply(String name, String scalarColumn, Operation arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(col(scalarColumn), arrayColumn.toColumn()).alias(name);
+    this.neededOperations = Arrays.asList(arrayColumn);
+  }
+
+  public Multiply(String name, AggregatedValue scalarColumn, String arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), col(arrayColumn)).alias(name);
+    this.neededAggregations = Arrays.asList(scalarColumn);
+  }
+
+  public Multiply(String name, AggregatedValue scalarColumn, AggregatedValue arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(scalarColumn, arrayColumn);
+  }
+
+  public Multiply(String name, AggregatedValue scalarColumn, Operation arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(scalarColumn);
+    this.neededOperations = Arrays.asList(arrayColumn);
+  }
+
+  public Multiply(String name, Operation scalarColumn, String arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), col(arrayColumn)).alias(name);
+    this.neededOperations = Arrays.asList(scalarColumn);
+  }
+
+  public Multiply(String name, Operation scalarColumn, AggregatedValue arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(arrayColumn);
+    this.neededOperations = Arrays.asList(scalarColumn);
+  }
+
+  public Multiply(String name, Operation scalarColumn, Operation arrayColumn) {
+    super(name);
+    this.column = Multiply.udf.apply(scalarColumn.toColumn(), arrayColumn.toColumn()).alias(name);
+    this.neededOperations = Arrays.asList(scalarColumn, arrayColumn);
+  }
+
+  public Column toAggregateColumn() {
+    return this.column;
+  }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Operation.java b/src/main/java/io/atoti/spark/operation/Operation.java
new file mode 100644
index 0000000..566d347
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Operation.java
@@ -0,0 +1,75 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.spark.sql.Column;
+
+public abstract class Operation implements Serializable {
+
+  private static final long serialVersionUID = 8932076027291294986L;
+  protected String name;
+  protected Column column;
+  protected List<AggregatedValue> neededAggregations;
+  protected List<Operation> neededOperations;
+
+  public Operation(String name) {
+    this.name = name;
+    this.neededAggregations = Arrays.asList();
+    this.neededOperations = Arrays.asList();
+  }
+
+  public Column toAggregateColumn() {
+    return this.column;
+  }
+
+  public Column toColumn() {
+    return col(name);
+  }
+
+  public String getName() {
+    return this.name;
+  }
+
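+  // Recursively collects the aggregations this operation and its nested operations rely on;
+  // getAllOperations below returns the operation tree bottom-up so that dependencies are
+  // materialised before the operations that use them.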
+  public Stream<AggregatedValue> getNeededAggregations() {
+    return Stream.concat(
+        this.neededAggregations.stream(),
+        this.neededOperations.stream().flatMap(Operation::getNeededAggregations));
+  }
+
+  public Stream<Operation> getAllOperations() {
+    return Stream.concat(
+        this.neededOperations.stream().flatMap(Operation::getAllOperations), Stream.of(this));
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj == null) {
+      return false;
+    }
+
+    if (obj.getClass() != this.getClass()) {
+      return false;
+    }
+
+    final Operation op = (Operation) obj;
+    return op.name.equals(this.name)
+        && op.neededAggregations.equals(this.neededAggregations)
+        && op.neededOperations.equals(this.neededOperations);
+  }
+
+  @Override
+  public String toString() {
+    return name + " | " + neededAggregations + " | " + neededOperations;
+  }
+
+  @Override
+  public int hashCode() {
+    return this.toString().hashCode();
+  }
+}
diff --git a/src/main/java/io/atoti/spark/operation/Quantile.java b/src/main/java/io/atoti/spark/operation/Quantile.java
new file mode 100644
index 0000000..0eb2cae
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/Quantile.java
@@ -0,0 +1,43 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.Utils;
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public final class Quantile extends Operation {
+
+  static UserDefinedFunction quantileUdf(float percent) {
+    return udf(
+        (Seq<Long> arr) -> {
+          ArrayList<Long> javaArr = Utils.convertScalaArrayToArray(arr);
+          return Utils.quantile(javaArr, percent);
+        },
+        DataTypes.LongType);
+  }
+
+  public Quantile(String name, String arrayColumn, float percent) {
+    super(name);
+    this.column = quantileUdf(percent).apply(col(arrayColumn)).alias(name);
+  }
+
+  public Quantile(String name, AggregatedValue arrayColumn, float percent) {
+    super(name);
+    this.column = quantileUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(arrayColumn);
+  }
+
+  public Quantile(String name, Operation arrayColumn, float percent) {
+    super(name);
+    this.column = quantileUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+    this.neededOperations = Arrays.asList(arrayColumn);
+  }
+}
diff --git a/src/main/java/io/atoti/spark/operation/QuantileIndex.java b/src/main/java/io/atoti/spark/operation/QuantileIndex.java
new file mode 100644
index 0000000..0507453
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/QuantileIndex.java
@@ -0,0 +1,57 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.udf;
+
+import io.atoti.spark.Utils;
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import org.apache.spark.sql.api.java.UDF1;
+import org.apache.spark.sql.expressions.UserDefinedFunction;
+import org.apache.spark.sql.types.DataTypes;
+import scala.Serializable;
+import scala.collection.Seq;
+
+public final class QuantileIndex extends Operation {
+
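+  // QuantileUdf returns the 1-based position of the quantile element within the array,
+  // matching the 1-based indexing used by Spark's element_at (see VectorAt).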
+  static UserDefinedFunction quantileIndexUdf(float percent) {
+    return udf(
+        new QuantileUdf(percent),
+        DataTypes.IntegerType);
+  }
+
+  public QuantileIndex(String name, String arrayColumn, float percent) {
+    super(name);
+    this.column = quantileIndexUdf(percent).apply(col(arrayColumn)).alias(name);
+  }
+
+  public QuantileIndex(String name, AggregatedValue arrayColumn, float percent) {
+    super(name);
+    this.column = quantileIndexUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+    this.neededAggregations = Arrays.asList(arrayColumn);
+  }
+
+  public QuantileIndex(String name, Operation arrayColumn, float percent) {
+    super(name);
+    this.column = quantileIndexUdf(percent).apply(arrayColumn.toColumn()).alias(name);
+    this.neededOperations = Arrays.asList(arrayColumn);
+  }
+
+  private static class QuantileUdf implements UDF1<Seq<Long>, Integer>, Serializable {
+    private static final long serialVersionUID = 20220330_1217L;
+
+    private final float percent;
+
+    public QuantileUdf(float percent) {
+      this.percent = percent;
+    }
+
+    @Override
+    public Integer call(Seq<Long> arr) {
+      ArrayList<Long> javaArr = Utils.convertScalaArrayToArray(arr);
+      return Utils.quantileIndex(javaArr, percent) + 1;
+    }
+  }
+}
diff --git a/src/main/java/io/atoti/spark/operation/VectorAt.java b/src/main/java/io/atoti/spark/operation/VectorAt.java
new file mode 100644
index 0000000..37e6f04
--- /dev/null
+++ b/src/main/java/io/atoti/spark/operation/VectorAt.java
@@ -0,0 +1,29 @@
+package io.atoti.spark.operation;
+
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.element_at;
+
+import io.atoti.spark.aggregation.AggregatedValue;
+
+import java.util.Arrays;
+import java.util.List;
+
+public final class VectorAt extends Operation {
+
+  public VectorAt(String name, String arrayColumn, int position) {
+    super(name);
+    this.column = element_at(col(arrayColumn), position).alias(name);
+  }
+
+  public VectorAt(String name, AggregatedValue arrayColumn, int position) {
+    super(name);
+    this.column = element_at(arrayColumn.toColumn(), position).alias(name);
+    this.neededAggregations = Arrays.asList(arrayColumn);
+  }
+
+  public VectorAt(String name, Operation arrayColumn, int position) {
+    super(name);
+    this.column = element_at(arrayColumn.toColumn(), position).alias(name);
+    this.neededOperations = Arrays.asList(arrayColumn);
+  }
+}
diff --git a/src/main/resources/csv/array.csv b/src/main/resources/csv/array.csv
new file mode 100644
index 0000000..7deb9b3
--- /dev/null
+++ b/src/main/resources/csv/array.csv
@@ -0,0 +1,3 @@
+id;quantity;price;price_simulations
+1;10;4.98;5,7,1
+2;20;1.95;2,3,1
diff --git a/src/test/java/io/atoti/spark/TestDiscovery.java b/src/test/java/io/atoti/spark/TestDiscovery.java
index 53b8fe7..4490419 100644
--- a/src/test/java/io/atoti/spark/TestDiscovery.java
+++ b/src/test/java/io/atoti/spark/TestDiscovery.java
@@ -112,4 +112,15 @@ void testDiscoveryCalculate() {
     assertTrue(dTypes.containsKey("val1_equal_val2"));
     assertThat(dTypes.get("val1_equal_val2")).isEqualTo(DataTypes.BooleanType);
   }
+
+  @Test
+  void testDiscoveryArray() {
+    Dataset<Row> dataframe = spark.read().table("array");
+    final Map<String, DataType> dTypes = Discovery.discoverDataframe(dataframe);
+
+    assertThat(dataframe).isNotNull();
+    assertThat(dTypes).isNotNull();
+    assertThat(dTypes.get("price_simulations"))
+        .isEqualTo(DataTypes.createArrayType(DataTypes.LongType));
+  }
 }
diff --git a/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java b/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java
new file mode 100644
index 0000000..696b0bd
--- /dev/null
+++ b/src/test/java/io/atoti/spark/TestLocalVectorAggregation.java
@@ -0,0 +1,66 @@
+/*
+ * (C) ActiveViam 2022
+ * ALL RIGHTS RESERVED. This material is the CONFIDENTIAL and PROPRIETARY
+ * property of ActiveViam. Any unauthorized use,
+ * reproduction or transfer of this material is strictly prohibited
+ */
+package io.atoti.spark;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.atoti.spark.aggregation.SumArray;
+import io.atoti.spark.aggregation.SumArrayLength;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.Test;
+
+public class TestLocalVectorAggregation {
+
+  SparkSession spark =
+      SparkSession.builder().appName("Spark Atoti").config("spark.master", "local").getOrCreate();
+
+  public TestLocalVectorAggregation() {
+    spark.sparkContext().setLogLevel("ERROR");
+    spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+  }
+
+  @Test
+  void sumVector() {
+    final Dataset<Row> dataframe = CsvReader.read("csv/array.csv", spark);
+    var price_simulations =
+        new SumArray(
+            "price_simulations_bis", "price_simulations");
+    var rows =
+        AggregateQuery.aggregate(dataframe, List.of("id"), List.of(price_simulations), List.of())
+            .collectAsList();
+
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+
+    final var rowsById =
+        rows.stream().collect(Collectors.toUnmodifiableMap(row -> (row.getAs("id")), row -> (row)));
+
+    assertThat(Utils.convertScalaArrayToArray(rowsById.get(1).getAs("price_simulations_bis")))
+        .containsExactly(3L, 7L, 5L);
+
+    assertThat(Utils.convertScalaArrayToArray(rowsById.get(2).getAs("price_simulations_bis")))
+        .containsExactly(1L, 3L, 2L);
+  }
+
+  @Test
+  void testSumArrayLength() {
+    final Dataset<Row> dataframe = CsvReader.read("csv/array.csv", spark);
+    var price_simulations =
+        new SumArrayLength(
+            "price_simulations_sum", "price_simulations");
+    var rows =
+        AggregateQuery.aggregate(
+                dataframe, List.of(), List.of(price_simulations), List.of())
+            .collectAsList();
+    System.out.println(rows);
+    assertThat(rows).hasSize(1)
+        .extracting(row -> row.getLong(0))
+        .containsExactlyInAnyOrder(6L);
+  }
+}
diff --git a/src/test/java/io/atoti/spark/TestVectorAggregation.java b/src/test/java/io/atoti/spark/TestVectorAggregation.java
new file mode 100644
index 0000000..5e37b4d
--- /dev/null
+++ b/src/test/java/io/atoti/spark/TestVectorAggregation.java
@@ -0,0 +1,323 @@
+/*
+ * (C) ActiveViam 2022
+ * ALL RIGHTS RESERVED. This material is the CONFIDENTIAL and PROPRIETARY
+ * property of ActiveViam. Any unauthorized use,
+ * reproduction or transfer of this material is strictly prohibited
+ */
+package io.atoti.spark;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import io.atoti.spark.aggregation.Sum;
+import io.atoti.spark.aggregation.SumArray;
+import io.atoti.spark.aggregation.SumArrayLength;
+import io.atoti.spark.operation.Multiply;
+import io.atoti.spark.operation.Quantile;
+import io.atoti.spark.operation.QuantileIndex;
+import io.atoti.spark.operation.VectorAt;
+import io.github.cdimascio.dotenv.Dotenv;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.Test;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public class TestVectorAggregation {
+
+  static Dotenv dotenv = Dotenv.load();
+  static SparkSession spark =
+      SparkSession.builder()
+          .appName("Spark Atoti")
+          .config("spark.master", "local")
+          .config("spark.databricks.service.clusterId", dotenv.get("clusterId", "local-id"))
+          .getOrCreate();
+
+  public TestVectorAggregation() {
+    spark.sparkContext().setLogLevel("ERROR");
+    spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+  }
+
+  private static List<Long> convertScalaArrayToArray(Seq<Long> arr) {
+    return new ArrayList<>(JavaConverters.asJavaCollectionConverter(arr).asJavaCollection());
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  void quantile() {
+
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var price_simulations =
+        new SumArray(
+            "price_simulations_sum", "price_simulations");
+    var quantile = new Quantile("quantile", price_simulations, 95f);
+    var rows =
+        AggregateQuery.aggregate(
+                dataframe, List.of("id"), List.of(price_simulations), List.of(quantile))
+            .collectAsList();
+
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+
+    final var rowsById =
+        rows.stream().collect(Collectors.toUnmodifiableMap(
+            row -> row.<Number>getAs("id").intValue(),
+            row -> (row)));
+
+    assertThat((long) rowsById.get(1).getAs("quantile")).isEqualTo(7L);
+    assertThat((long) rowsById.get(2).getAs("quantile")).isEqualTo(3L);
+    for (int i = 0; i < 3; i++) {
+      assertThat(
+              convertScalaArrayToArray(
+                      rowsById.get(1).<Seq<Long>>getAs("price_simulations_sum"))
+                  .get(i))
+          .isEqualTo(List.of(3L, 7L, 5L).get(i));
+      assertThat(
+              convertScalaArrayToArray(
+                      rowsById.get(2).<Seq<Long>>getAs("price_simulations_sum"))
+                  .get(i))
+          .isEqualTo(List.of(1L, 3L, 2L).get(i));
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  void quantileIndex() {
+
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var price_simulations =
+        new SumArray(
+            "price_simulations_sum", "price_simulations");
+    var quantile = new QuantileIndex("quantile", price_simulations, 95f);
+    var rows =
+        AggregateQuery.aggregate(
+                dataframe, List.of("id"), List.of(price_simulations), List.of(quantile))
+            .collectAsList();
+
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+    System.out.println(rows);
+  }
+
+  //
+  @Test
+  void vectorAt() {
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var price_simulations =
+        new SumArray(
+            "price_simulations_bis", "price_simulations");
+    var vectorAt = new VectorAt("vector-at", price_simulations, 1);
+    var rows =
+        AggregateQuery.aggregate(dataframe, List.of("id"), List.of(), List.of(vectorAt))
+            .collectAsList();
+
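+    // element_at (used by VectorAt) is 1-based, so position 1 selects the first element of the
+    // summed vector for each id.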
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+
+    final var rowsById =
+        rows.stream()
+            .collect(Collectors.toUnmodifiableMap(row -> row.<Number>getAs("id").intValue(), row -> (row)));
+
+    assertThat((long) rowsById.get(1).getAs("vector-at")).isEqualTo(3);
+
+    assertThat((long) rowsById.get(2).getAs("vector-at")).isEqualTo(1);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  void simpleAggregation() {
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var sumVector =
+        new SumArray("sum(vector)", "price_simulations");
+    var rows =
+        AggregateQuery.aggregate(dataframe, List.of("id"), List.of(sumVector)).collectAsList();
+
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+
+    final var rowsById =
+        rows.stream().collect(Collectors.toUnmodifiableMap(
+            row -> row.<Number>getAs("id").intValue(), row -> (row)));
+
+    for (int i = 0; i < 3; i++) {
+      assertThat(
+              convertScalaArrayToArray(rowsById.get(1).<Seq<Long>>getAs("sum(vector)"))
+                  .get(i))
+          .isEqualTo(List.of(3L, 7L, 5L).get(i));
+      assertThat(
+              convertScalaArrayToArray((Seq<Long>) rowsById.get(2).getAs("sum(vector)"))
+                  .get(i))
+          .isEqualTo(List.of(1L, 3L, 2L).get(i));
+    }
+  }
+
+  //
+  @SuppressWarnings("unchecked")
+  @Test
+  void vectorScaling() {
+    final Dataset<Row> dataframe = spark.read().table("array");
+    final var f_vector =
+        new Multiply(
+            "f * vector",
+            new Sum("f", "price"),
+            new SumArray(
+                "sum(vector)", "price_simulations"));
+    var rows =
+        AggregateQuery.aggregate(dataframe, List.of("id"), List.of(), List.of(f_vector))
+            .collectAsList();
+
+    // result must have 2 values
+    assertThat(rows).hasSize(2);
+
+    final var rowsById =
+        rows.stream().collect(Collectors.toUnmodifiableMap(row -> (row.<Number>getAs("id").intValue()), row -> (row)));
+
+    for (int i = 0; i < 3; i++) {
+      assertThat(
+              convertScalaArrayToArray((Seq<Long>) rowsById.get(1).getAs("f * vector")).get(i))
+          .isEqualTo(List.of(15L, 35L, 25L).get(i));
+      assertThat(
+              convertScalaArrayToArray((Seq<Long>) rowsById.get(2).getAs("f * vector")).get(i))
+          .isEqualTo(List.of(2L, 6L, 4L).get(i));
+    }
+  }
+
+  //
+  @Test
+  void vectorQuantile() {
+    final Dataset<Row> dataframe = spark.read().table("simulations");
+    var rows =
+        AggregateQuery.aggregate(
+                dataframe,
+                List.of("simulation"),
+                List.of(),
+                List.of(
+                    new QuantileIndex(
+                        "i95%",
+                        new Multiply(
+                            "f * vector",
+                            new Sum("f", "factor-field"),
+                            new SumArray(
+                                "sum(vector)",
+                                "vector-field"
+                            )),
+                        95f)))
+            .collectAsList();
+
+    // result must have 3 values
+    assertThat(rows).hasSize(3);
+
+    final var rowsById =
+        rows.stream()
+            .collect(Collectors.toUnmodifiableMap(row -> (row.<Number>getAs("simulation").intValue()), row -> (row)));
+
+    assertThat((int) rowsById.get(1).getAs("i95%")).isEqualTo(7);
+
+    assertThat((int) rowsById.get(2).getAs("i95%")).isEqualTo(7);
+
+    assertThat((int) rowsById.get(3).getAs("i95%")).isEqualTo(7);
+  }
+
+  //
+  @Test
+  void simulationExplorationAtQuantile() {
+    final Dataset<Row> dataframe = spark.read().table("simulations");
+    final var revenues =
+        new Multiply(
+            "f * vector",
+            new Sum("f", "factor-field"),
+            new SumArray("sum(vector)", "vector-field"));
+    final List<Row> rows =
+        AggregateQuery.aggregate(
+                dataframe,
+                List.of("simulation"),
+                List.of(),
+                List.of(
+                    new Quantile("v95%", revenues, 95f), new QuantileIndex("i95%", revenues, 95f)))
+            .collectAsList();
+    final var compareByV95 = Comparator.<Row>comparingLong(row -> (Long) (row.getAs("v95%")));
+    final int bestSimulation =
+        rows.stream()
+            .sorted(compareByV95.reversed())
+            .mapToInt(row -> (int) (row.getAs("i95%")))
+            .findFirst()
+            .orElseThrow(() -> new IllegalStateException("No data to look at"));
+    final int worstSimulation =
+        rows.stream()
+            .sorted(compareByV95)
+            .mapToInt(row -> (int) (row.getAs("i95%")))
+            .findFirst()
+            .orElseThrow(() -> new IllegalStateException("No data to look at"));
+    // Study the chosen simulations on the prices per category
+    var result =
+        AggregateQuery.aggregate(
+                dataframe,
+                List.of("simulation"),
+                List.of(),
+                List.of(
+                    new VectorAt("revenue-at-best", revenues, bestSimulation),
+                    new VectorAt("revenue-at-worst", revenues, worstSimulation)))
+            .collectAsList();
+
+    // result must have 3 values
+    assertThat(result).hasSize(3);
+
+    final var rowsById =
+        result.stream()
+            .collect(Collectors.toUnmodifiableMap(row -> (row.<Number>getAs("simulation").intValue()), row -> (row)));
+
+    assertThat((long) rowsById.get(1).getAs("revenue-at-best")).isEqualTo(840L);
+    assertThat((long) rowsById.get(1).getAs("revenue-at-worst")).isEqualTo(840L);
+
+    assertThat((long) rowsById.get(2).getAs("revenue-at-best")).isEqualTo(560L);
+    assertThat((long) rowsById.get(2).getAs("revenue-at-worst")).isEqualTo(560L);
+
+    assertThat((long) rowsById.get(3).getAs("revenue-at-best")).isEqualTo(690L);
+    assertThat((long) rowsById.get(3).getAs("revenue-at-worst")).isEqualTo(690L);
+  }
+
+  @Test
+  void testSumArrayLength() {
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var price_simulations =
+        new SumArrayLength(
+            "price_simulations_sum", "price_simulations");
+    {
+      var totalRows =
+          AggregateQuery.aggregate(
+                  dataframe, List.of(), List.of(price_simulations), List.of())
+              .collectAsList();
+      System.out.println(totalRows);
+      assertThat(totalRows).hasSize(1)
+          .extracting(row -> row.getLong(0))
+          .containsExactlyInAnyOrder(6L);
+    }
+    {
+      var rowsById =
+          AggregateQuery.aggregate(
+                  dataframe, List.of("id"), List.of(price_simulations), List.of())
+              .collectAsList();
+      System.out.println(rowsById);
+      assertThat(rowsById).hasSize(2)
+          .extracting(row -> row.getLong(1))
+          .containsExactlyInAnyOrder(3L, 3L);
+    }
+  }
+
+  @Test
+  void sumVector() {
+    final Dataset<Row> dataframe = spark.read().table("array");
+    var price_simulations =
+        new SumArray(
+            "price_simulations_bis", "price_simulations");
+    var rows =
+        AggregateQuery.aggregate(dataframe, List.of("id"), List.of(price_simulations), List.of())
+            .collectAsList();
+
+    System.out.println(rows);
+  }
+}
diff --git a/src/test/resources/csv/array.csv b/src/test/resources/csv/array.csv
new file mode 100644
index 0000000..5eed92c
--- /dev/null
+++ b/src/test/resources/csv/array.csv
@@ -0,0 +1,3 @@
+id;quantity;price;price_simulations
+1;10;5;3,7,5
+2;20;2;1,3,2
diff --git a/src/test/resources/csv/simulations.csv b/src/test/resources/csv/simulations.csv
new file mode 100644
index 0000000..f4f702e
--- /dev/null
+++ b/src/test/resources/csv/simulations.csv
@@ -0,0 +1,4 @@
+simulation;factor-field;vector-field
+1;10;5,7,3,5,3,6,84
+2;20;2,3,2,4,2,3,28
+3;30;2,3,1,2,1,6,23