From e5ad786fdf66a94a3f53a92dc29be106a1791e2c Mon Sep 17 00:00:00 2001 From: aditya0811 Date: Mon, 29 Apr 2024 09:09:16 +0530 Subject: [PATCH] Issue #12367 --- .../function/CaseTransformFunctionTest.java | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java index 79a415e5a92f..315b53e9e08d 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java @@ -36,6 +36,7 @@ import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; public class CaseTransformFunctionTest extends BaseTransformFunctionTest { @@ -106,7 +107,8 @@ public void testCaseTransformFunctionWithIntResults() { testCaseQueries(String.format("%s(%s, %s)", functionType.getName(), LONG_SV_COLUMN, String.format("%d", _longSVValues[INDEX_TO_COMPARE])), getPredicateResults(LONG_SV_COLUMN, functionType)); testCaseQueries(String.format("%s(%s, %s)", functionType.getName(), FLOAT_SV_COLUMN, - String.format("%f", _floatSVValues[INDEX_TO_COMPARE])), getPredicateResults(FLOAT_SV_COLUMN, functionType)); + "CAST(" + String.format("%f", _floatSVValues[INDEX_TO_COMPARE]) + " AS FLOAT)"), + getPredicateResults(FLOAT_SV_COLUMN, functionType)); testCaseQueries(String.format("%s(%s, %s)", functionType.getName(), DOUBLE_SV_COLUMN, String.format("%.20f", _doubleSVValues[INDEX_TO_COMPARE])), getPredicateResults(DOUBLE_SV_COLUMN, functionType)); @@ -116,6 +118,33 @@ public void testCaseTransformFunctionWithIntResults() { } } + @Test + public void testCaseTransformFunctionWithoutCastForFloatValues() { + boolean[] predicateResults = new boolean[1]; + Arrays.fill(predicateResults, true); + int[] expectedValues = new int[1]; + int index = -1; + for (int i = 0; i < NUM_ROWS; i++) { + if (Double.compare(_floatSVValues[i], Double.parseDouble(String.format("%f", _floatSVValues[i]))) != 0) { + index = i; + expectedValues[0] = predicateResults[0] ? _intSVValues[i] : 10; + break; + } + } + + if (index != -1) { + String predicate = String.format("%s(%s, %s)", TransformFunctionType.EQUALS, FLOAT_SV_COLUMN, + String.format("%f", _floatSVValues[index])); + String expression = String.format("CASE WHEN %s THEN %s ELSE 10 END", predicate, INT_SV_COLUMN); + ExpressionContext expressionContext = RequestContextUtils.getExpression(expression); + TransformFunction transformFunction = TransformFunctionFactory.get(expressionContext, _dataSourceMap); + Assert.assertTrue(transformFunction instanceof CaseTransformFunction); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + int[] intValues = transformFunction.transformToIntValuesSV(_projectionBlock); + assertNotEquals(intValues[index], expectedValues[0]); + } + } + @DataProvider public static String[] illegalExpressions() { //@formatter:off