Skip to content

Commit

Permalink
update org.junit.Assert to org.junit.jupiter.api.Assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
llama90 committed Jun 16, 2024
1 parent 801de2f commit 8377b66
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.jupiter.api.Assertions.assertTrue;

import org.apache.arrow.vector.types.pojo.ArrowType;
import org.junit.Assert;
import org.junit.Test;

public class DecimalTypeUtilTest {
Expand All @@ -29,35 +30,35 @@ public void testOutputTypesForAdd() {
ArrowType.Decimal resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, operand1, operand2);
Assert.assertTrue(getDecimal(31, 10).equals(resultType));
assertTrue(getDecimal(31, 10).equals(resultType));

operand1 = getDecimal(30, 6);
operand2 = getDecimal(30, 5);
resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, operand1, operand2);
Assert.assertTrue(getDecimal(32, 6).equals(resultType));
assertTrue(getDecimal(32, 6).equals(resultType));

operand1 = getDecimal(30, 10);
operand2 = getDecimal(38, 10);
resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, operand1, operand2);
Assert.assertTrue(getDecimal(38, 9).equals(resultType));
assertTrue(getDecimal(38, 9).equals(resultType));

operand1 = getDecimal(38, 10);
operand2 = getDecimal(38, 38);
resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, operand1, operand2);
Assert.assertTrue(getDecimal(38, 9).equals(resultType));
assertTrue(getDecimal(38, 9).equals(resultType));

operand1 = getDecimal(38, 10);
operand2 = getDecimal(38, 2);
resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, operand1, operand2);
Assert.assertTrue(getDecimal(38, 6).equals(resultType));
assertTrue(getDecimal(38, 6).equals(resultType));
}

@Test
Expand All @@ -67,14 +68,14 @@ public void testOutputTypesForMultiply() {
ArrowType.Decimal resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, operand1, operand2);
Assert.assertTrue(getDecimal(38, 6).equals(resultType));
assertTrue(getDecimal(38, 6).equals(resultType));

operand1 = getDecimal(38, 10);
operand2 = getDecimal(9, 2);
resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, operand1, operand2);
Assert.assertTrue(getDecimal(38, 6).equals(resultType));
assertTrue(getDecimal(38, 6).equals(resultType));
}

@Test
Expand All @@ -84,7 +85,7 @@ public void testOutputTypesForMod() {
ArrowType.Decimal resultType =
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MOD, operand1, operand2);
Assert.assertTrue(getDecimal(30, 10).equals(resultType));
assertTrue(getDecimal(30, 10).equals(resultType));
}

private ArrowType.Decimal getDecimal(int precision, int scale) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.Lists;
import java.util.Set;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.junit.Assert;
import org.junit.Test;

public class ExpressionRegistryTest {
Expand All @@ -29,7 +30,7 @@ public class ExpressionRegistryTest {
public void testTypes() throws GandivaException {
Set<ArrowType> types = ExpressionRegistry.getInstance().getSupportedTypes();
ArrowType.Int uint8 = new ArrowType.Int(8, false);
Assert.assertTrue(types.contains(uint8));
assertTrue(types.contains(uint8));
}

@Test
Expand All @@ -38,7 +39,7 @@ public void testFunctions() throws GandivaException {
FunctionSignature signature =
new FunctionSignature("add", uint8, Lists.newArrayList(uint8, uint8));
Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions();
Assert.assertTrue(functions.contains(signature));
assertTrue(functions.contains(signature));
}

@Test
Expand All @@ -47,7 +48,7 @@ public void testFunctionAliases() throws GandivaException {
FunctionSignature signature =
new FunctionSignature("modulo", int64, Lists.newArrayList(int64, int64));
Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions();
Assert.assertTrue(functions.contains(signature));
assertTrue(functions.contains(signature));
}

@Test
Expand All @@ -57,6 +58,6 @@ public void testCaseInsensitiveFunctionName() throws GandivaException {
FunctionSignature signature =
new FunctionSignature("castvarchar", utf8, Lists.newArrayList(utf8, int64));
Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions();
Assert.assertTrue(functions.contains(signature));
assertTrue(functions.contains(signature));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

import com.google.common.collect.Lists;
import java.util.ArrayList;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.nio.charset.Charset;
Expand All @@ -32,7 +34,6 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Test;

public class FilterTest extends BaseEvaluatorTest {
Expand Down Expand Up @@ -130,7 +131,7 @@ public void testSimpleInString() throws GandivaException, Exception {
releaseRecordBatch(batch);
selectionBuffer.close();
filter.close();
Assert.assertArrayEquals(expected, actual);
assertArrayEquals(expected, actual);
}

@Test
Expand Down Expand Up @@ -173,7 +174,7 @@ public void testSimpleInInt() throws GandivaException, Exception {
releaseRecordBatch(batch);
selectionBuffer.close();
filter.close();
Assert.assertArrayEquals(expected, actual);
assertArrayEquals(expected, actual);
}

@Test
Expand Down Expand Up @@ -323,6 +324,6 @@ private void verifyTestCase(
selectionBuffer.close();
filter.close();

Assert.assertArrayEquals(expected, actual);
assertArrayEquals(expected, actual);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.Lists;
import java.util.List;
import org.apache.arrow.gandiva.expression.Condition;
Expand All @@ -24,7 +26,6 @@
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

Expand Down Expand Up @@ -60,7 +61,7 @@ public void testAdd3() throws Exception {
16 * THOUSAND,
4);
System.out.println("Time taken for projecting 1m records of add3 is " + timeTaken + "ms");
Assert.assertTrue(timeTaken <= 13 * toleranceRatio);
assertTrue(timeTaken <= 13 * toleranceRatio);
}

@Test
Expand Down Expand Up @@ -124,7 +125,7 @@ public void testIf() throws Exception {
16 * THOUSAND,
4);
System.out.println("Time taken for projecting 10m records of nestedIf is " + timeTaken + "ms");
Assert.assertTrue(timeTaken <= 15 * toleranceRatio);
assertTrue(timeTaken <= 15 * toleranceRatio);
}

@Test
Expand Down Expand Up @@ -154,6 +155,6 @@ public void testFilterAdd2() throws Exception {
16 * THOUSAND,
4);
System.out.println("Time taken for filtering 10m records of a+b<c is " + timeTaken + "ms");
Assert.assertTrue(timeTaken <= 12 * toleranceRatio);
assertTrue(timeTaken <= 12 * toleranceRatio);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/
package org.apache.arrow.gandiva.evaluator;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.Lists;
import java.math.BigDecimal;
Expand Down Expand Up @@ -106,7 +106,7 @@ public void test_add() throws GandivaException {
for (int i = 0; i < 4; i++) {
assertFalse(outVector.isNull(i));
assertTrue(
"index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i)) == 0);
expOutput[i].compareTo(outVector.getObject(i)) == 0, "index : " + i + " failed compare");
}

// free buffers
Expand Down Expand Up @@ -236,7 +236,7 @@ public void test_multiply() throws GandivaException {
for (int i = 0; i < 4; i++) {
assertFalse(outVector.isNull(i));
assertTrue(
"index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i)) == 0);
expOutput[i].compareTo(outVector.getObject(i)) == 0, "index : " + i + " failed compare");
}

// free buffers
Expand Down Expand Up @@ -322,9 +322,9 @@ public void testCompare() throws GandivaException {
for (int i = 0; i < numRows; i++) {
assertFalse(resultVector.isNull(i));
assertEquals(
"mismatch in result for expr at idx " + idx + " for row " + i,
expectedArray[i],
resultVector.getObject(i).booleanValue());
resultVector.getObject(i).booleanValue(),
"mismatch in result for expr at idx " + idx + " for row " + i);
}
}
} finally {
Expand Down Expand Up @@ -457,6 +457,7 @@ public void testRound() throws GandivaException {
for (int i = 0; i < numRows; i++) {
assertFalse(resultVector.isNull(i));
assertTrue(
expectedArray[i].compareTo(resultVector.getObject(i)) == 0,
"mismatch in result for "
+ "field "
+ resultVector.getField().getName()
Expand All @@ -465,8 +466,7 @@ public void testRound() throws GandivaException {
+ " expected "
+ expectedArray[i]
+ ", got "
+ resultVector.getObject(i),
expectedArray[i].compareTo(resultVector.getObject(i)) == 0);
+ resultVector.getObject(i));
}
}
} finally {
Expand Down Expand Up @@ -583,6 +583,7 @@ public void testCastToDecimal() throws GandivaException {
for (int i = 0; i < numRows; i++) {
assertFalse(resultVector.isNull(i));
assertTrue(
expectedArray[i].compareTo(resultVector.getObject(i)) == 0,
"mismatch in result for "
+ "field "
+ resultVector.getField().getName()
Expand All @@ -591,8 +592,7 @@ public void testCastToDecimal() throws GandivaException {
+ " expected "
+ expectedArray[i]
+ ", got "
+ resultVector.getObject(i),
expectedArray[i].compareTo(resultVector.getObject(i)) == 0);
+ resultVector.getObject(i));
}
}
} finally {
Expand Down Expand Up @@ -818,6 +818,7 @@ public void testCastStringToDecimal() throws GandivaException {
// compare the outputs.
for (int i = 0; i < numRows; i++) {
assertTrue(
expected[i].compareTo(resultVector.getObject(i)) == 0,
"mismatch in result for "
+ "field "
+ resultVector.getField().getName()
Expand All @@ -826,8 +827,7 @@ public void testCastStringToDecimal() throws GandivaException {
+ " expected "
+ expected[i]
+ ", got "
+ resultVector.getObject(i),
expected[i].compareTo(resultVector.getObject(i)) == 0);
+ resultVector.getObject(i));
}
} finally {
// free buffers
Expand Down
Loading

0 comments on commit 8377b66

Please sign in to comment.