Skip to content

Commit

Permalink
fix min/max aggregation on reversed types
Browse files Browse the repository at this point in the history
  • Loading branch information
smiklosovic committed Feb 5, 2025
1 parent 69dc5d0 commit 5a6291b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 58 deletions.
126 changes: 68 additions & 58 deletions src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,64 @@ public void addInput(Arguments arguments)
}
};

private static abstract class AbstractNativeAggregate implements AggregateFunction.Aggregate
{
private final AbstractType<?> returnType;

public AbstractNativeAggregate(AbstractType<?> returnType)
{
this.returnType = returnType;
}

private ByteBuffer result;

public void reset()
{
result = null;
}

public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return result;
}

@Override
public void addInput(Arguments arguments)
{
ByteBuffer value = arguments.get(0);

if (value == null)
return;

if (returnType.isReversed())
{
if (result == null || comparision(returnType, result, value))
result = value;
}
else
{
if (result == null || !comparision(returnType, result, value))
result = value;
}
}

public abstract boolean comparision(AbstractType<?> returnType, ByteBuffer result, ByteBuffer value);
}

private static abstract class AbstractNativeAggregateFunction extends NativeAggregateFunction
{
protected AbstractNativeAggregateFunction(String name, AbstractType<?> returnType, AbstractType<?>... argTypes)
{
super(name, returnType, argTypes);
}

@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}
}

/**
* Creates a MAX function for the specified type.
*
Expand All @@ -844,41 +902,17 @@ public void addInput(Arguments arguments)
*/
public static NativeAggregateFunction makeMaxFunction(final AbstractType<?> inputType)
{
return new NativeAggregateFunction("max", inputType, inputType)
return new AbstractNativeAggregateFunction("max", inputType, inputType)
{
@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}

@Override
public Aggregate newAggregate()
public Aggregate newAggregate() throws InvalidRequestException
{
return new Aggregate()
return new AbstractNativeAggregate(returnType())
{
private ByteBuffer max;

public void reset()
{
max = null;
}

public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return max;
}

@Override
public void addInput(Arguments arguments)
public boolean comparision(AbstractType<?> returnType, ByteBuffer result, ByteBuffer value)
{
ByteBuffer value = arguments.get(0);

if (value == null)
return;

if (max == null || returnType().compare(max, value) < 0)
max = value;
return returnType().compare(result, value) > 0;
}
};
}
Expand All @@ -893,41 +927,17 @@ public void addInput(Arguments arguments)
*/
public static NativeAggregateFunction makeMinFunction(final AbstractType<?> inputType)
{
return new NativeAggregateFunction("min", inputType, inputType)
return new AbstractNativeAggregateFunction("min", inputType, inputType)
{
@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}

@Override
public Aggregate newAggregate()
public Aggregate newAggregate() throws InvalidRequestException
{
return new Aggregate()
return new AbstractNativeAggregate(returnType())
{
private ByteBuffer min;

public void reset()
{
min = null;
}

public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return min;
}

@Override
public void addInput(Arguments arguments)
public boolean comparision(AbstractType<?> returnType, ByteBuffer result, ByteBuffer value)
{
ByteBuffer value = arguments.get(0);

if (value == null)
return;

if (min == null || returnType().compare(min, value) > 0)
min = value;
return returnType().compare(result, value) < 0;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,58 @@ public void testCountStarFunction() throws Throwable
assertRows(execute("SELECT max(b), COUNT(1), b FROM %s WHERE a = 1 LIMIT 2"), row(5, 4L, 1));
}

@Test
public void testMaxAggregationDescending()
{
createTable("CREATE TABLE %s (a int, b int, primary key (a, b)) WITH CLUSTERING ORDER BY (b DESC)");

execute("INSERT INTO %s (a, b) VALUES (1, 1000)");
execute("INSERT INTO %s (a, b) VALUES (2, 100)");
execute("INSERT INTO %s (a, b) VALUES (4, 1)");

assertRows(execute("SELECT count(b), max(b) as max FROM %s"),
row(3L, 1000));
}

@Test
public void testMinAggregationDescending()
{
createTable("CREATE TABLE %s (a int, b int, primary key (a, b)) WITH CLUSTERING ORDER BY (b DESC)");

execute("INSERT INTO %s (a, b) VALUES (1, 1000)");
execute("INSERT INTO %s (a, b) VALUES (2, 100)");
execute("INSERT INTO %s (a, b) VALUES (4, 1)");

assertRows(execute("SELECT count(b), min(b) as max FROM %s"),
row(3L, 1));
}

@Test
public void testMaxAggregationAscending()
{
createTable("CREATE TABLE %s (a int, b int, primary key (a, b)) WITH CLUSTERING ORDER BY (b ASC)");

execute("INSERT INTO %s (a, b) VALUES (1, 1000)");
execute("INSERT INTO %s (a, b) VALUES (2, 100)");
execute("INSERT INTO %s (a, b) VALUES (4, 1)");

assertRows(execute("SELECT count(b), max(b) as max FROM %s"),
row(3L, 1000));
}

@Test
public void testMinAggregationAscending()
{
createTable("CREATE TABLE %s (a int, b int, primary key (a, b)) WITH CLUSTERING ORDER BY (b ASC)");

execute("INSERT INTO %s (a, b) VALUES (1, 1000)");
execute("INSERT INTO %s (a, b) VALUES (2, 100)");
execute("INSERT INTO %s (a, b) VALUES (4, 1)");

assertRows(execute("SELECT count(b), min(b) as max FROM %s"),
row(3L, 1));
}

@Test
public void testAggregateWithColumns() throws Throwable
{
Expand Down

0 comments on commit 5a6291b

Please sign in to comment.