From 948a031fe2a300537385dba38ac6b402d079b20d Mon Sep 17 00:00:00 2001 From: James Petty Date: Mon, 13 Nov 2023 17:30:07 -0500 Subject: [PATCH] Improve Variable#increment() bytecode generation Modifies Variable#increment() to generate IINC instructions where possible instead of multi-bytecode sequences of load, add, store that it previously used. Also fixes handling of incrementing primitive long values to avoid failing as a result of a type mismatch between the variable's long type and the type of constantInt(1). --- .../java/io/airlift/bytecode/Variable.java | 54 ++++++++++++++++++- .../TestSetVariableBytecodeExpression.java | 42 +++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/airlift/bytecode/Variable.java b/src/main/java/io/airlift/bytecode/Variable.java index 4fe63a0..504a599 100644 --- a/src/main/java/io/airlift/bytecode/Variable.java +++ b/src/main/java/io/airlift/bytecode/Variable.java @@ -21,7 +21,7 @@ import static io.airlift.bytecode.ParameterizedType.type; import static io.airlift.bytecode.expression.BytecodeExpressions.add; -import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static java.util.Objects.requireNonNull; public class Variable @@ -47,7 +47,15 @@ public BytecodeExpression set(BytecodeExpression value) public BytecodeExpression increment() { - return new SetVariableBytecodeExpression(this, add(this, constantInt(1))); + if (IntegerIncrementVariableBytecodeExpression.isSupportedType(getType())) { + return new IntegerIncrementVariableBytecodeExpression(this); + } + else if (getType().getPrimitiveType() == long.class) { + return new SetVariableBytecodeExpression(this, add(this, constantLong(1))); + } + else { + throw new UnsupportedOperationException("Variable %s of type %s does not support incrementing".formatted(getName(), getType())); + } } @Override @@ -68,6 +76,48 @@ public List getChildNodes() return ImmutableList.of(); } + private static final class IntegerIncrementVariableBytecodeExpression + extends BytecodeExpression + { + private final Variable variable; + + public IntegerIncrementVariableBytecodeExpression(Variable variable) + { + super(type(void.class)); + this.variable = requireNonNull(variable, "variable is null"); + if (!isSupportedType(variable.getType())) { + throw new IllegalArgumentException("Variable %s of type %s is not supported for integer increment".formatted(variable.getName(), variable.getType())); + } + } + + @Override + public BytecodeNode getBytecode(MethodGenerationContext generationContext) + { + return VariableInstruction.incrementVariable(variable, (byte) 1); + } + + @Override + public List getChildNodes() + { + return ImmutableList.of(); + } + + @Override + protected String formatOneLine() + { + return variable.getName() + "++"; + } + + public static boolean isSupportedType(ParameterizedType type) + { + if (!type.isPrimitive()) { + return false; + } + Class primitiveType = type.getPrimitiveType(); + return primitiveType == byte.class || primitiveType == short.class || primitiveType == int.class; + } + } + private static final class SetVariableBytecodeExpression extends BytecodeExpression { diff --git a/src/test/java/io/airlift/bytecode/expression/TestSetVariableBytecodeExpression.java b/src/test/java/io/airlift/bytecode/expression/TestSetVariableBytecodeExpression.java index 01639f5..44a6cc5 100644 --- a/src/test/java/io/airlift/bytecode/expression/TestSetVariableBytecodeExpression.java +++ b/src/test/java/io/airlift/bytecode/expression/TestSetVariableBytecodeExpression.java @@ -25,11 +25,53 @@ import static io.airlift.bytecode.ParameterizedType.type; import static io.airlift.bytecode.expression.BytecodeExpressionAssertions.assertBytecodeNode; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance; import static org.testng.Assert.assertEquals; public class TestSetVariableBytecodeExpression { + @Test + public void testIncrement() + throws Exception + { + assertBytecodeNode(scope -> { + Variable byteValue = scope.declareVariable(byte.class, "byte"); + assertEquals(byteValue.increment().toString(), "byte++;"); + return new BytecodeBlock() + .append(byteValue.set(constantInt(0))) + .append(byteValue.increment()) + .append(byteValue.ret()); + }, type(byte.class), (byte) 1); + + assertBytecodeNode(scope -> { + Variable shortValue = scope.declareVariable(short.class, "short"); + assertEquals(shortValue.increment().toString(), "short++;"); + return new BytecodeBlock() + .append(shortValue.set(constantInt(0))) + .append(shortValue.increment()) + .append(shortValue.ret()); + }, type(short.class), (short) 1); + + assertBytecodeNode(scope -> { + Variable intValue = scope.declareVariable(int.class, "int"); + assertEquals(intValue.increment().toString(), "int++;"); + return new BytecodeBlock() + .append(intValue.set(constantInt(0))) + .append(intValue.increment()) + .append(intValue.ret()); + }, type(int.class), 1); + + assertBytecodeNode(scope -> { + Variable longValue = scope.declareVariable(long.class, "long"); + assertEquals(longValue.increment().toString(), "long = (long + 1L);"); + return new BytecodeBlock() + .append(longValue.set(constantLong(0))) + .append(longValue.increment()) + .append(longValue.ret()); + }, type(long.class), 1L); + } + @Test public void testGetField() throws Exception