diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala index bebe464e7..8a68c1512 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala @@ -17,6 +17,10 @@ class Ast$Test extends AnyFunSpec { Expressions.parse("42 - 2").evaluateIntConst should be(Some(40)) } + it ("considers `(42 - 2)` constant") { + Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40)) + } + it ("considers `(-3 + 7) * 8 / 2` constant") { Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16)) } @@ -33,6 +37,10 @@ class Ast$Test extends AnyFunSpec { Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1)) } + it ("considers `((4) > 2) ? ((1)) : 5` constant") { + Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1)) + } + it ("considers `x` variable") { Expressions.parse("x").evaluateIntConst should be(None) } diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala index 488c81861..39b3584d6 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala @@ -86,12 +86,17 @@ class ExpressionsSpec extends AnyFunSpec { ) } + it("parses group") { + Expressions.parse("(123)") should be (Group(IntNum(123))) + Expressions.parse("(foo)") should be (Group(Name(identifier("foo")))) + } + it("parses (1 + 2) / (7 * 8)") { Expressions.parse("(1 + 2) / (7 * 8)") should be ( BinOp( - BinOp(IntNum(1), Add, IntNum(2)), + Group(BinOp(IntNum(1), Add, IntNum(2))), Div, - BinOp(IntNum(7), Mult, IntNum(8)) + Group(BinOp(IntNum(7), Mult, IntNum(8))) ) ) } @@ -128,7 +133,7 @@ class ExpressionsSpec extends AnyFunSpec { } it("parses ~(7+3)") { - Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3)))) + Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3))))) } // Enums @@ -285,7 +290,7 @@ class ExpressionsSpec extends AnyFunSpec { it("parses (123).as") { Expressions.parse("(123).as") should be ( - CastToType(IntNum(123), typeId(false, Seq("u4"))) + CastToType(Group(IntNum(123)), typeId(false, Seq("u4"))) ) } diff --git a/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala b/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala index 392053e24..ed5f5734b 100644 --- a/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala @@ -112,7 +112,7 @@ class AttrSpec$Test extends AnyFunSpec { spec.id should be(NamedIdentifier("foo")) val dataType = spec.dataType.asInstanceOf[UserType] dataType.name should be(List("bar")) - dataType.args should be(Seq(Ast.expr.IntNum(5))) + dataType.args should be(Seq(Ast.expr.Group(Ast.expr.IntNum(5)))) } } } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala index 2acf1943b..9204a3e18 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala @@ -63,7 +63,7 @@ object Ast { } } - object expr{ + object expr { case class BoolOp(op: boolop, values: Seq[expr]) extends expr case class BinOp(left: expr, op: operator, right: expr) extends expr case class UnaryOp(op: unaryop, operand: expr) extends expr @@ -91,6 +91,9 @@ object Ast { case class List(elts: Seq[expr]) extends expr case class InterpolatedStr(elts: Seq[expr]) extends expr + /** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */ + case class Group(expr: expr) extends expr + /** * Implicit declaration of ordering, so expressions can be used for ordering operations, e.g. * for `SortedMap.from(...)` diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala index d146b3b50..0ffda45b9 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala @@ -130,6 +130,7 @@ object ConstEvaluator { case _ => value.NonConst } + case expr.Group(expr) => evaluate(expr) case _ => value.NonConst } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala index 0ef98ce71..d30806e90 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala @@ -117,12 +117,13 @@ object Expressions { case (lhs, trailers) => trailers.foldLeft(lhs)((l, t) => t(l)) } + def group[$: P]: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr)) def empty_list[$: P] = P("[" ~ "]").map(_ => Ast.expr.List(Nil)) // def empty_dict[_: P] = P("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil)) def atom[$: P]: P[Ast.expr] = P( empty_list | // empty_dict | - "(" ~ test ~ ")" | + group | "[" ~ list ~ "]" | // "{" ~ dictorsetmaker ~ "}" | enumByName | diff --git a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala index e59db2048..1745a6062 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala @@ -148,6 +148,12 @@ abstract class BaseTranslator(val provider: TypeProvider) doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + } } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala index 8d2741282..f35aa8574 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala @@ -87,6 +87,8 @@ class ExpressionValidator(val provider: TypeProvider) CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName)) case Ast.expr.InterpolatedStr(elts: Seq[Ast.expr]) => elts.foreach(validate) + case Ast.expr.Group(nested) => + validate(nested) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala index 7e85ab670..a8d96a0d9 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala @@ -115,6 +115,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + ResultString(nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + }) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 27a0c0243..6ceb18a1e 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -138,6 +138,8 @@ class TypeDetector(provider: TypeProvider) { detectCastType(typeName) case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) => CalcIntType + case Ast.expr.Group(expr) => + detectType(expr) } } @@ -244,6 +246,12 @@ class TypeDetector(provider: TypeProvider) { */ def detectCallType(call: Ast.expr.Call): DataType = { call.func match { + case Ast.expr.Group(nested) => detectCallTypeImpl(nested) + case func => detectCallTypeImpl(func) + } + } + def detectCallTypeImpl(func: Ast.expr): DataType = { + func match { case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => val objType = detectType(obj) // TODO: check number and type of arguments in `call.args`