Skip to content

Commit

Permalink
Implement DistinctSelect conversion (#3494)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsrnhld authored Jul 16, 2024
1 parent edc3bf1 commit 86daceb
Show file tree
Hide file tree
Showing 15 changed files with 344 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import com.bakdata.conquery.models.query.queryplan.aggregators.Aggregator;
import com.bakdata.conquery.models.query.queryplan.aggregators.specific.value.AllValuesAggregator;
import com.bakdata.conquery.models.types.ResultType;
import com.bakdata.conquery.sql.conversion.model.select.DistinctSelectConverter;
import com.bakdata.conquery.sql.conversion.model.select.SelectConverter;
import com.fasterxml.jackson.annotation.JsonCreator;

@CPSType(id = "DISTINCT", base = Select.class)
Expand All @@ -29,4 +31,9 @@ public Aggregator<?> createAggregator() {
public ResultType getResultType() {
return new ResultType.ListT(super.getResultType());
}

/**
 * Returns a new {@link DistinctSelectConverter} for this select.
 *
 * @return a fresh converter instance; a new one is created on every call
 */
@Override
public SelectConverter<DistinctSelect> createConverter() {
return new DistinctSelectConverter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ protected String print(PrintSettings cfg, @NonNull Object f) {

/**
 * Reads a single value of this result type from the given column of a {@link ResultSet}.
 *
 * @param resultSet          the result set to read from
 * @param columnIndex        index of the column to read (NOTE(review): presumably 1-based as in JDBC - confirm against the {@link ResultSetProcessor} implementations)
 * @param resultSetProcessor processor the concrete result types delegate the extraction to
 * @return the extracted value
 * @throws SQLException if reading from the result set fails
 */
public abstract T getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException;

/**
 * Reads a list of values of this result type from the given column of a {@link ResultSet}.
 * <p>
 * Declared abstract so that every result type states explicitly how its list representation is extracted (or rejects it itself),
 * instead of inheriting a default that only fails at runtime.
 *
 * @param resultSet          the result set to read from
 * @param columnIndex        index of the column to read
 * @param resultSetProcessor processor the concrete result types delegate the extraction to
 * @return the extracted list of values
 * @throws SQLException if reading from the result set fails
 */
protected abstract List<T> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException;

public static ResultType<?> resolveResultType(MajorTypeId majorTypeId) {
return switch (majorTypeId) {
Expand Down Expand Up @@ -99,6 +97,11 @@ public String print(PrintSettings cfg, Object f) {
public Boolean getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getBoolean(resultSet, columnIndex);
}

/**
 * Reads a list of booleans from the given column by delegating to {@code resultSetProcessor.getBooleanList(...)}.
 */
@Override
protected List<Boolean> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getBooleanList(resultSet, columnIndex);
}
}


Expand All @@ -120,6 +123,11 @@ public String print(PrintSettings cfg, Object f) {
public Integer getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getInteger(resultSet, columnIndex);
}

/**
 * Reads a list of integers from the given column by delegating to {@code resultSetProcessor.getIntegerList(...)}.
 */
@Override
protected List<Integer> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getIntegerList(resultSet, columnIndex);
}
}

@CPSType(id = "NUMERIC", base = ResultType.class)
Expand All @@ -140,6 +148,11 @@ public String print(PrintSettings cfg, Object f) {
public Double getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getDouble(resultSet, columnIndex);
}

/**
 * Reads a list of doubles from the given column by delegating to {@code resultSetProcessor.getDoubleList(...)}.
 */
@Override
protected List<Double> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getDoubleList(resultSet, columnIndex);
}
}

@CPSType(id = "DATE", base = ResultType.class)
Expand All @@ -162,6 +175,11 @@ public Number getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetPr
return resultSetProcessor.getDate(resultSet, columnIndex);
}

/**
 * Reads a list of dates from the given column by delegating to {@code resultSetProcessor.getDateList(...)}.
 * The dates are returned as {@link Number}s, matching the single-value representation of this type.
 */
@Override
protected List<Number> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getDateList(resultSet, columnIndex);
}

/**
 * Formats a date given as a number of days since 1970-01-01.
 *
 * @param num       the epoch day; only its {@code int} value is used, so fractions are truncated
 * @param formatter the formatter that renders the resulting {@link LocalDate}
 * @return the formatted date
 */
public static String print(Number num, DateTimeFormatter formatter) {
	final int epochDay = num.intValue();
	final LocalDate date = LocalDate.ofEpochDay(epochDay);
	return formatter.format(date);
}
Expand Down Expand Up @@ -276,6 +294,11 @@ public BigDecimal getFromResultSet(ResultSet resultSet, int columnIndex, ResultS
public BigDecimal readIntermediateValue(PrintSettings cfg, Number f) {
return new BigDecimal(f.longValue()).movePointLeft(cfg.getCurrency().getDefaultFractionDigits());
}

/**
 * Reads a list of money values from the given column by delegating to {@code resultSetProcessor.getMoneyList(...)}.
 */
@Override
protected List<BigDecimal> getFromResultSetAsList(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
return resultSetProcessor.getMoneyList(resultSet, columnIndex);
}
}

@CPSType(id = "LIST", base = ResultType.class)
Expand Down Expand Up @@ -313,11 +336,12 @@ public String typeInfo() {

/**
 * Reads a list column by delegating to the list extraction of the element type.
 *
 * @param resultSet          the result set to read from
 * @param columnIndex        index of the column holding the list
 * @param resultSetProcessor processor handed on to the element type
 * @return the list of element values
 * @throws SQLException if reading from the result set fails
 */
@Override
public List<T> getFromResultSet(ResultSet resultSet, int columnIndex, ResultSetProcessor resultSetProcessor) throws SQLException {
	// Every element type has to implement getFromResultSetAsList itself, so no per-type guard is needed here;
	// element types that cannot be read as a list reject the call on their own.
	return elementType.getFromResultSetAsList(resultSet, columnIndex, resultSetProcessor);
}

/**
 * Lists of lists cannot be read from a result set, so this always fails.
 *
 * @throws UnsupportedOperationException always
 */
@Override
protected List<List<T>> getFromResultSetAsList(final ResultSet resultSet, final int columnIndex, final ResultSetProcessor resultSetProcessor) {
throw new UnsupportedOperationException("Nested lists not supported in SQL mode");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,6 @@ public QueryStep unnestValidityDate(QueryStep predecessor, String cteName) {
return predecessor;
}

/**
 * Builds a plain-SQL {@code STRING_AGG} expression over the given field.
 *
 * @param stringField   the field whose values are aggregated
 * @param delimiter     the delimiter placed between the aggregated values
 * @param orderByFields the fields that determine the order of the aggregated values
 * @return the aggregation as a string field
 */
@Override
public Field<String> stringAggregation(Field<String> stringField, Field<String> delimiter, List<Field<?>> orderByFields) {
	// placeholders: {0} function keyword, {1} aggregated field, {2} delimiter, {3} ordering
	final String template = "{0}({1}, {2} {3})";
	return DSL.field(template, String.class, DSL.keyword("STRING_AGG"), stringField, delimiter, DSL.orderBy(orderByFields));
}

@Override
public Field<String> daterangeStringAggregation(ColumnDateRange columnDateRange) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,6 @@ public QueryStep unnestValidityDate(QueryStep predecessor, String cteName) {
.build();
}

/**
 * Builds a plain-SQL {@code string_agg} expression over the given field.
 *
 * @param stringField   the field whose values are aggregated
 * @param delimiter     the delimiter placed between the aggregated values
 * @param orderByFields the fields that determine the order of the aggregated values
 * @return the aggregation as a string field
 */
@Override
public Field<String> stringAggregation(Field<String> stringField, Field<String> delimiter, List<Field<?>> orderByFields) {
	// placeholders: {0} function keyword, {1} aggregated field, {2} delimiter, {3} ordering
	final String template = "{0}({1}, {2} {3})";
	return DSL.field(template, String.class, DSL.keyword("string_agg"), stringField, delimiter, DSL.orderBy(orderByFields));
}

@Override
public Field<String> daterangeStringAggregation(ColumnDateRange columnDateRange) {
Field<Object> asMultirange = rangeAgg(columnDateRange);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ public interface SqlFunctionProvider {
*/
QueryStep unnestValidityDate(QueryStep predecessor, String cteName);

Field<String> stringAggregation(Field<String> stringField, Field<String> delimiter, List<Field<?>> orderByFields);

/**
* Aggregates the start and end columns of the validity date of entries into one compound string expression.
* <p>
Expand Down Expand Up @@ -140,6 +138,17 @@ public interface SqlFunctionProvider {
*/
Field<String> yearQuarter(Field<Date> dateField);

/**
 * Aggregates the values of a string field into one delimited string using a plain-SQL {@code string_agg} expression.
 *
 * @param stringField   the field whose values are aggregated
 * @param delimiter     the delimiter placed between the aggregated values
 * @param orderByFields the fields that determine the order of the aggregated values
 * @return the aggregation as a string field
 */
default Field<String> stringAggregation(Field<String> stringField, Field<String> delimiter, List<Field<?>> orderByFields) {
	// placeholders: {0} function keyword, {1} aggregated field, {2} delimiter, {3} ordering
	final String template = "{0}({1}, {2} {3})";
	return DSL.field(template, String.class, DSL.keyword("string_agg"), stringField, delimiter, DSL.orderBy(orderByFields));
}

default Field<String> concat(List<Field<String>> fields) {
String concatenated = fields.stream()
// if a field is null, the whole concatenation would be null - but we just want to skip this field in this case,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public class QueryStep {
*/
@Builder.Default
boolean unionAll = true;
/**
* Determines if the select should be distinct.
*/
boolean selectDistinct;
/**
* All {@link QueryStep}'s that shall be converted before this {@link QueryStep}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.jooq.Select;
import org.jooq.SelectConditionStep;
import org.jooq.SelectHavingStep;
import org.jooq.SelectSelectStep;
import org.jooq.impl.DSL;

/**
Expand Down Expand Up @@ -70,10 +71,15 @@ private CommonTableExpression<Record> toCte(QueryStep queryStep) {

private Select<Record> toSelectStep(QueryStep queryStep) {

Select<Record> selectStep = this.dslContext
.select(queryStep.getSelects().all())
.from(queryStep.getFromTables())
.where(queryStep.getConditions());
SelectSelectStep<Record> selectClause;
if (queryStep.isSelectDistinct()) {
selectClause = dslContext.selectDistinct(queryStep.getSelects().all());
}
else {
selectClause = dslContext.select(queryStep.getSelects().all());
}

Select<Record> selectStep = selectClause.from(queryStep.getFromTables()).where(queryStep.getConditions());

if (queryStep.isGroupBy()) {
selectStep = ((SelectConditionStep<Record>) selectStep).groupBy(queryStep.getGroupBy());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package com.bakdata.conquery.sql.conversion.model.select;

import static org.jooq.impl.DSL.field;
import static org.jooq.impl.DSL.name;

import java.util.List;
import java.util.Optional;

import com.bakdata.conquery.models.datasets.concepts.Connector;
import com.bakdata.conquery.models.datasets.concepts.select.connector.DistinctSelect;
import com.bakdata.conquery.sql.conversion.cqelement.concept.ConceptCteStep;
import com.bakdata.conquery.sql.conversion.cqelement.concept.ConnectorSqlTables;
import com.bakdata.conquery.sql.conversion.dialect.SqlFunctionProvider;
import com.bakdata.conquery.sql.conversion.model.CteStep;
import com.bakdata.conquery.sql.conversion.model.QueryStep;
import com.bakdata.conquery.sql.conversion.model.Selects;
import com.bakdata.conquery.sql.conversion.model.SqlIdColumns;
import com.bakdata.conquery.sql.execution.ResultSetProcessor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.jooq.Field;
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;

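/**
 * Converts a {@link DistinctSelect} by creating two additional CTEs (the SQL below is illustrative, CTE names are shortened):
 * <ol>
 * <li>
 * Select the distinct values of the column from the event-filtered table.
 * <pre>{@code
 * "distinct" as (
 *   select distinct "pid", "column"
 *   from "event_filter"
 * )
 * }</pre>
 * </li>
 * <li>
 * String-aggregate all distinct values of the column per id. The values are cast to varchar, ordered by value and
 * delimited by {@link ResultSetProcessor#UNIT_SEPARATOR}.
 * <pre>{@code
 * "aggregated" as (
 *   select
 *     "distinct"."pid",
 *     string_agg(cast("column" as varchar), <unit separator> order by cast("column" as varchar)) as "select-1"
 *   from "distinct"
 *   group by "pid"
 * )
 * }</pre>
 * </li>
 * </ol>
 */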
public class DistinctSelectConverter implements SelectConverter<DistinctSelect> {

	/**
	 * Steps of the additional CTEs, each with its name suffix and its predecessor step ({@code null} for the first step).
	 */
	@Getter
	@RequiredArgsConstructor
	private enum DistinctSelectCteStep implements CteStep {

		DISTINCT_SELECT("distinct", null),
		STRING_AGG("aggregated", DISTINCT_SELECT);

		private final String suffix;
		private final DistinctSelectCteStep predecessor;
	}

	/**
	 * Builds the SQL selects for a {@link DistinctSelect}: the column is selected under the select's alias, and the
	 * distinct-then-aggregate CTE chain is attached as an additional predecessor.
	 *
	 * @param distinctSelect the select to convert
	 * @param selectContext  supplies the name generator, the connector tables, the ids and the SQL function provider
	 * @return the selects, with the aggregation CTE as additional predecessor
	 */
	@Override
	public ConnectorSqlSelects connectorSelect(DistinctSelect distinctSelect, SelectContext<Connector, ConnectorSqlTables> selectContext) {

		final String alias = selectContext.getNameGenerator().selectName(distinctSelect);
		final ConnectorSqlTables tables = selectContext.getTables();

		// the column of the root table, exposed under the select's alias
		final Field<Object> rootColumn = field(name(tables.getRootTable(), distinctSelect.getColumn().getName()));
		final FieldWrapper<Object> preprocessingSelect = new FieldWrapper<>(rootColumn.as(alias));

		final QueryStep distinctCte = createDistinctSelectCte(preprocessingSelect, alias, selectContext);
		final QueryStep aggregatedCte = createAggregationCte(selectContext, preprocessingSelect, distinctCte, alias);

		final String aggregationFilterCte = tables.cteName(ConceptCteStep.AGGREGATION_FILTER);
		final ExtractingSqlSelect<Object> finalSelect = preprocessingSelect.qualify(aggregationFilterCte);

		return ConnectorSqlSelects.builder()
								  .preprocessingSelect(preprocessingSelect)
								  .additionalPredecessor(Optional.of(aggregatedCte))
								  .finalSelect(finalSelect)
								  .build();
	}

	/**
	 * Creates the CTE that selects the distinct (ids, column) combinations from the event-filter CTE.
	 */
	private static QueryStep createDistinctSelectCte(
			FieldWrapper<Object> preprocessingSelect,
			String alias,
			SelectContext<Connector, ConnectorSqlTables> selectContext
	) {
		// values to aggregate must be event-filtered first
		final String eventFilterCte = selectContext.getTables().cteName(ConceptCteStep.EVENT_FILTER);
		final ExtractingSqlSelect<Object> qualifiedColumn = preprocessingSelect.qualify(eventFilterCte);
		final SqlIdColumns qualifiedIds = selectContext.getIds().qualify(eventFilterCte);

		final Selects distinctSelects = Selects.builder()
											   .ids(qualifiedIds)
											   .sqlSelect(qualifiedColumn)
											   .build();

		return QueryStep.builder()
						.cteName(selectContext.getNameGenerator().cteStepName(DistinctSelectCteStep.DISTINCT_SELECT, alias))
						.selectDistinct(true)
						.selects(distinctSelects)
						.fromTable(QueryStep.toTableLike(eventFilterCte))
						.build();
	}

	/**
	 * Creates the CTE that string-aggregates the distinct values per id, reading from the distinct-select CTE.
	 */
	private static QueryStep createAggregationCte(
			SelectContext<Connector, ConnectorSqlTables> selectContext,
			FieldWrapper<Object> preprocessingSelect,
			QueryStep distinctSelectCte,
			String alias
	) {
		final SqlFunctionProvider functionProvider = selectContext.getFunctionProvider();

		// the distinct values are cast to varchar so they can be string-aggregated
		final ExtractingSqlSelect<Object> distinctColumn = preprocessingSelect.qualify(distinctSelectCte.getCteName());
		final Field<String> castedColumn = functionProvider.cast(distinctColumn.select(), SQLDataType.VARCHAR);

		// aggregate ordered by the value itself, delimited by the unit separator
		final Field<String> delimiter = DSL.toChar(ResultSetProcessor.UNIT_SEPARATOR);
		final Field<String> aggregatedColumn = functionProvider.stringAggregation(castedColumn, delimiter, List.of(castedColumn)).as(alias);

		final SqlIdColumns ids = distinctSelectCte.getQualifiedSelects().getIds();

		final Selects aggregatedSelects = Selects.builder()
												 .ids(ids)
												 .sqlSelect(new FieldWrapper<>(aggregatedColumn))
												 .build();

		return QueryStep.builder()
						.cteName(selectContext.getNameGenerator().cteStepName(DistinctSelectCteStep.STRING_AGG, alias))
						.selects(aggregatedSelects)
						.fromTable(QueryStep.toTableLike(distinctSelectCte.getCteName()))
						.groupBy(ids.toFields())
						.predecessor(distinctSelectCte)
						.build();
	}
}
Loading

0 comments on commit 86daceb

Please sign in to comment.