This repository has been archived by the owner on May 18, 2022. It is now read-only.

Merge pull request #23 from activeviam/databricks-cluster-vector
Databricks cluster vector
arnaudframmery authored Mar 31, 2022
2 parents 65de8e5 + 5ceddab commit bc4ff47
Showing 19 changed files with 1,154 additions and 17 deletions.
7 changes: 5 additions & 2 deletions pom.xml
@@ -209,8 +209,9 @@
       <!-- Only use this version for creating the jar with command ./mvnw -P jar-creation clean package clean -->
       <dependency>
         <groupId>org.apache.spark</groupId>
-        <artifactId>spark-sql_2.13</artifactId>
-        <version>3.2.0</version>
+        <artifactId>spark-sql_2.12</artifactId>
+        <version>3.1.2</version>
+        <scope>provided</scope>
       </dependency>
     </dependencies>
     <build>
@@ -227,6 +228,7 @@
               <include>io/atoti/spark/condition/**/*.java</include>
               <include>io/atoti/spark/aggregation/**/*.java</include>
               <include>io/atoti/spark/join/**/*.java</include>
+              <include>io/atoti/spark/operation/**/*.java</include>
               <include>io/atoti/spark/ListQuery.java</include>
             </includes>
           </configuration>
@@ -257,6 +259,7 @@
               <include>io/atoti/spark/condition/**/*.java</include>
               <include>io/atoti/spark/aggregation/**/*.java</include>
               <include>io/atoti/spark/join/**/*.java</include>
+              <include>io/atoti/spark/operation/**/*.java</include>
               <include>io/atoti/spark/ListQuery.java</include>
             </includes>
             <excludes>
87 changes: 72 additions & 15 deletions src/main/java/io/atoti/spark/AggregateQuery.java
@@ -3,8 +3,11 @@
 import io.atoti.spark.aggregation.AggregatedValue;
 import io.atoti.spark.condition.QueryCondition;
 import io.atoti.spark.condition.TrueCondition;
+import io.atoti.spark.operation.Operation;
 import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -29,27 +32,81 @@ public static Dataset<Row> aggregate(
       List<String> groupByColumns,
       List<AggregatedValue> aggregations,
       QueryCondition condition) {
-    if (aggregations.isEmpty()) {
+    return aggregate(dataframe, groupByColumns, aggregations, Arrays.asList(), condition);
+  }
+
+  public static Dataset<Row> aggregate(
+      Dataset<Row> dataframe,
+      List<String> groupByColumns,
+      List<AggregatedValue> aggregations,
+      List<Operation> operations) {
+    return aggregate(dataframe, groupByColumns, aggregations, operations, TrueCondition.value());
+  }
+
+  public static Dataset<Row> aggregate(
+      Dataset<Row> dataframe,
+      List<String> groupByColumns,
+      List<AggregatedValue> aggregations,
+      List<Operation> operations,
+      QueryCondition condition) {
+    if (aggregations.isEmpty() && operations.isEmpty()) {
       throw new IllegalArgumentException(
-          "#aggregate can only be called with at least one AggregatedValue");
+          "#aggregate can only be called with at least one AggregatedValue or Operation");
     }

     final Column[] columns = groupByColumns.stream().map(functions::col).toArray(Column[]::new);
-    final Column[] createdColumns =
+    final Column[] createdAggregatedColumns =
         aggregations.stream().map(AggregatedValue::toColumn).toArray(Column[]::new);
-    final Column[] columnsToSelect = Arrays.copyOf(columns, columns.length + createdColumns.length);
-    System.arraycopy(createdColumns, 0, columnsToSelect, columns.length, createdColumns.length);
+    final Column[] createdOperationColumns =
+        operations.stream().map(Operation::toColumn).toArray(Column[]::new);
+    final Column[] columnsToSelect =
+        Arrays.copyOf(
+            columns,
+            columns.length + createdAggregatedColumns.length + createdOperationColumns.length);
+    System.arraycopy(
+        createdAggregatedColumns,
+        0,
+        columnsToSelect,
+        columns.length,
+        createdAggregatedColumns.length);
+    System.arraycopy(
+        createdOperationColumns,
+        0,
+        columnsToSelect,
+        columns.length + createdAggregatedColumns.length,
+        createdOperationColumns.length);

+    // Add needed aggregations for operations to the `aggregations` list
+    final List<AggregatedValue> neededAggregations =
+        Stream.concat(
+                operations.stream().flatMap(Operation::getNeededAggregations),
+                aggregations.stream())
+            .distinct()
+            .collect(Collectors.toList());
+
+    // Aggregations
+    if (!neededAggregations.isEmpty()) {
+      final Column firstAggColumn = neededAggregations.get(0).toAggregateColumn();
+      final Column[] nextAggColumns =
+          neededAggregations.subList(1, neededAggregations.size()).stream()
+              .map(AggregatedValue::toAggregateColumn)
+              .toArray(Column[]::new);
+      dataframe =
+          dataframe
+              .filter(condition.getCondition())
+              .groupBy(columns)
+              .agg(firstAggColumn, nextAggColumns);
+    }
+
+    // Operations
+    if (!operations.isEmpty()) {
+      for (Operation op :
+          operations.stream().flatMap(Operation::getAllOperations).distinct().collect(Collectors.toList())) {
+        dataframe = dataframe.withColumn(op.getName(), op.toAggregateColumn());
+      }
+    }
+
-    final Column firstAggColumn = aggregations.get(0).toAggregateColumn();
-    final Column[] nextAggColumns =
-        aggregations.subList(1, aggregations.size()).stream()
-            .map(AggregatedValue::toAggregateColumn)
-            .toArray(Column[]::new);
-    return dataframe
-        .filter(condition.getCondition())
-        .groupBy(columns)
-        .agg(firstAggColumn, nextAggColumns)
-        .select(columnsToSelect);
+    return dataframe.select(columnsToSelect);
   }

   public static Dataset<Row> aggregateSql(
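To make the new flow concrete, here is a minimal sketch (not part of the commit) of the plan that aggregate builds for one group-by column, one AggregatedValue and one Operation that depends on it. It assumes the same imports as AggregateQuery.java; the variable names and the "id" column are hypothetical, while getCondition, toColumn, toAggregateColumn and getName are the methods already used in the diff above.

// Sketch: what aggregate() effectively runs for one aggregation `agg` and one operation `op`.
static Dataset<Row> equivalentPlan(
    Dataset<Row> dataframe, QueryCondition condition, AggregatedValue agg, Operation op) {
  Dataset<Row> grouped =
      dataframe
          .filter(condition.getCondition())   // apply the QueryCondition
          .groupBy(functions.col("id"))       // group-by columns (hypothetical "id")
          .agg(agg.toAggregateColumn());      // needed aggregations, including those required by op
  grouped = grouped.withColumn(op.getName(), op.toAggregateColumn()); // operations become extra columns
  return grouped.select(functions.col("id"), agg.toColumn(), op.toColumn()); // final projection
}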
17 changes: 17 additions & 0 deletions src/main/java/io/atoti/spark/CsvReader.java
@@ -6,8 +6,11 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.types.DataTypes;

 public class CsvReader {
+
   public static Dataset<Row> read(String path, SparkSession session) {
     return CsvReader.read(path, session, ";");
   }
@@ -27,6 +30,20 @@ public static Dataset<Row> read(String path, SparkSession session, String separa
           .option("timestampFormat", "dd/MM/yyyy")
           .option("inferSchema", true)
           .load(url.toURI().getPath());
+      for (int i = 0; i < dataframe.columns().length; i++) {
+        if (dataframe.dtypes()[i]._2.equals(DataTypes.StringType.toString())) {
+          String col = dataframe.columns()[i];
+          if (dataframe
+                  .filter(functions.col(col).$eq$eq$eq("").$bar$bar(functions.col(col).rlike(",")))
+                  .count()
+              == dataframe.count()) {
+            // This prototype only supports arrays of integers
+            dataframe =
+                dataframe.withColumn(
+                    col, functions.split(functions.col(col), ",").cast("array<long>"));
+          }
+        }
+      }
     } catch (URISyntaxException e) {
       throw new IllegalArgumentException("Failed to read csv " + path, e);
     }
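The loop added above is a heuristic: a string column whose rows are all either empty or contain a comma is treated as a serialized vector and converted with split + cast. Below is a standalone sketch of the same conversion, assuming an existing SparkSession named spark; the file and column names are made up.

// Sketch: a string column holding "1,2,3" becomes an array<long> column, as in CsvReader above.
Dataset<Row> df =
    spark
        .read()
        .format("csv")
        .option("header", true)
        .option("delimiter", ";")
        .load("prices.csv");              // hypothetical file with a price_simulations column
df =
    df.withColumn(
        "price_simulations",
        functions.split(functions.col("price_simulations"), ",").cast("array<long>"));
df.printSchema();                         // price_simulations is now array<long>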
1 change: 1 addition & 0 deletions src/main/java/io/atoti/spark/Discovery.java
@@ -30,6 +30,7 @@ private Discovery() {}
           DataTypes.ShortType,
           DataTypes.StringType,
           DataTypes.TimestampType,
+          DataTypes.createArrayType(DataTypes.LongType),
         })
         .collect(Collectors.toMap(DataType::toString, Function.identity()));

36 changes: 36 additions & 0 deletions src/main/java/io/atoti/spark/Main.java
@@ -0,0 +1,36 @@
+package io.atoti.spark;
+
+import io.atoti.spark.aggregation.SumArray;
+
+import java.util.Arrays;
+import java.util.List;
+
+import io.atoti.spark.operation.Quantile;
+import io.github.cdimascio.dotenv.Dotenv;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+
+public class Main {
+
+  static Dotenv dotenv = Dotenv.load();
+  static SparkSession spark =
+      SparkSession.builder()
+          .appName("Spark Atoti")
+          .config("spark.master", "local")
+          .config("spark.databricks.service.clusterId", dotenv.get("clusterId"))
+          .getOrCreate();
+
+  public static void main(String[] args) {
+    spark.sparkContext().addJar("./target/spark-lib-0.0.1-SNAPSHOT.jar");
+    final Dataset<Row> dataframe = spark.read().table("array");
+    SumArray price_simulations = new SumArray(
+        "price_simulations_sum", "price_simulations"
+    );
+    Quantile quantile = new Quantile("quantile", price_simulations, 95f);
+    List<Row> rows = AggregateQuery.aggregate(
+            dataframe, Arrays.asList("id"), Arrays.asList(price_simulations), Arrays.asList(quantile))
+        .collectAsList();
+    System.out.println(rows);
+  }
+}
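Main relies on Databricks Connect: the cluster id is read from a .env file through dotenv-java and the locally built jar is shipped to the cluster before the array table is queried. A hypothetical .env at the project root could look like the line below; clusterId is the key the code reads, the value is made up.

clusterId=0331-123456-abc123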
83 changes: 83 additions & 0 deletions src/main/java/io/atoti/spark/Utils.java
@@ -0,0 +1,83 @@
+package io.atoti.spark;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+class ArrayElement {
+  int index;
+  long value;
+
+  public ArrayElement(int index, long value) {
+    this.index = index;
+    this.value = value;
+  }
+}
+
+public class Utils {
+  public static <T> ArrayList<T> convertScalaArrayToArray(Seq<T> arr) {
+    return new ArrayList<T>(JavaConverters.asJavaCollectionConverter(arr).asJavaCollection());
+  }
+
+  public static Seq<Long> convertToArrayListToScalaArraySeq(List<Long> arr) {
+    return JavaConverters.asScalaBuffer(arr).iterator().toSeq();
+  }
+
+  public static long t(ArrayElement a, ArrayElement b) {
+    return b.value - a.value;
+  }
+
+  public static PriorityQueue<ArrayElement> constructMaxHeap(ArrayList<Long> arr) {
+    PriorityQueue<ArrayElement> pq =
+        new PriorityQueue<ArrayElement>(
+            (ArrayElement a, ArrayElement b) -> Long.compare(a.value, b.value));
+    pq.addAll(
+        IntStream.range(0, arr.size())
+            .mapToObj((int k) -> new ArrayElement(k, arr.get(k)))
+            .collect(Collectors.toList()));
+    return pq;
+  }
+
+  public static long quantile(ArrayList<Long> arr, float percent) {
+    PriorityQueue<ArrayElement> pq = constructMaxHeap(arr);
+    int index = (int) Math.floor(arr.size() * (100 - percent) / 100);
+
+    for (int i = arr.size() - 1; i > index; i--) {
+      pq.poll();
+    }
+
+    return pq.poll().value;
+  }
+
+  public static int quantileIndex(ArrayList<Long> arr, float percent) {
+    PriorityQueue<ArrayElement> pq = constructMaxHeap(arr);
+    int index = (int) Math.floor(arr.size() * (100 - percent) / 100);
+
+    for (int i = arr.size() - 1; i > index; i--) {
+      pq.poll();
+    }
+
+    return pq.poll().index;
+  }
+
+  public static int findKthLargestElement(ArrayList<Integer> arr, int k) {
+    if (k >= arr.size()) {
+      throw new ArrayIndexOutOfBoundsException();
+    }
+
+    PriorityQueue<Integer> pq = new PriorityQueue<Integer>(Comparator.reverseOrder());
+
+    pq.addAll(arr);
+
+    for (int i = 0; i < k; i++) {
+      pq.poll();
+    }
+
+    return pq.peek();
+  }
+}
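A quick worked example of the quantile helper (not part of the commit): despite its name, constructMaxHeap orders elements smallest-first, so the loop discards the lowest values and the next poll is returned as the requested percentile. Assuming the usual java.util imports:

// Sketch: exercising Utils.quantile on the shuffled values 1..10.
ArrayList<Long> values = new ArrayList<>(List.of(3L, 9L, 1L, 7L, 5L, 2L, 8L, 4L, 6L, 10L));
long p50 = Utils.quantile(values, 50f); // index = floor(10 * 50 / 100) = 5 -> fifth-smallest value = 5
long p95 = Utils.quantile(values, 95f); // index = floor(10 * 5 / 100) = 0  -> largest value = 10
System.out.println(p50 + ", " + p95);   // prints: 5, 10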