diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala index 28a69b522..e945090a7 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala @@ -17,6 +17,10 @@ class Ast$Test extends FunSpec { Expressions.parse("42 - 2").evaluateIntConst should be(Some(40)) } + it ("considers `(42 - 2)` constant") { + Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40)) + } + it ("considers `(-3 + 7) * 8 / 2` constant") { Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16)) } @@ -29,6 +33,10 @@ class Ast$Test extends FunSpec { Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1)) } + it ("considers `((4) > 2) ? ((1)) : 5` constant") { + Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1)) + } + it ("considers `x` variable") { Expressions.parse("x").evaluateIntConst should be(None) } diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala index b8382f47d..d160e6628 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala @@ -86,12 +86,17 @@ class ExpressionsSpec extends FunSpec { ) } + it("parses group") { + Expressions.parse("(123)") should be (Group(IntNum(123))) + Expressions.parse("(foo)") should be (Group(Name(identifier("foo")))) + } + it("parses (1 + 2) / (7 * 8)") { Expressions.parse("(1 + 2) / (7 * 8)") should be ( BinOp( - BinOp(IntNum(1), Add, IntNum(2)), + Group(BinOp(IntNum(1), Add, IntNum(2))), Div, - BinOp(IntNum(7), Mult, IntNum(8)) + Group(BinOp(IntNum(7), Mult, IntNum(8))) ) ) } @@ -128,7 +133,7 @@ class ExpressionsSpec extends FunSpec { } it("parses ~(7+3)") { - Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3)))) + Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3))))) } // Enums @@ -270,7 +275,7 @@ class ExpressionsSpec extends FunSpec { it("parses (123).as") { Expressions.parse("(123).as") should be ( - CastToType(IntNum(123), typeId(false, Seq("u4"))) + CastToType(Group(IntNum(123)), typeId(false, Seq("u4"))) ) } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala index 3f13b3105..f3eaf1d61 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala @@ -79,17 +79,18 @@ object Ast { case operator.BitXor => Some(leftValue ^ rightValue) case operator.BitAnd => Some(leftValue & rightValue) } + case expr.Group(expr) => expr.evaluateIntConst case _ => None } } } - object expr{ + object expr { case class BoolOp(op: boolop, values: Seq[expr]) extends expr case class BinOp(left: expr, op: operator, right: expr) extends expr case class UnaryOp(op: unaryop, operand: expr) extends expr case class IfExp(condition: expr, ifTrue: expr, ifFalse: expr) extends expr -// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr + // case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr case class Compare(left: expr, ops: cmpop, right: expr) extends expr case class Call(func: expr, args: Seq[expr]) extends expr case class IntNum(n: BigInt) extends expr @@ -106,10 +107,13 @@ object Ast { case class Subscript(value: expr, idx: expr) extends expr case class Name(id: identifier) extends expr case class List(elts: Seq[expr]) extends expr + + /** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */ + case class Group(expr: expr) extends expr } sealed trait boolop - object boolop{ + object boolop { case object And extends boolop case object Or extends boolop } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala index 0298731b2..4de18d44d 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala @@ -108,13 +108,14 @@ object Expressions { case (lhs, trailers) => trailers.foldLeft(lhs)((l, t) => t(l)) } + val group: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr)); val atom: P[Ast.expr] = { val empty_list = ("[" ~ "]").map(_ => Ast.expr.List(Nil)) // val empty_dict = ("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil)) P( empty_list | // empty_dict | - "(" ~ test ~ ")" | + group | "[" ~ list ~ "]" | // "{" ~ dictorsetmaker ~ "}" | enumByName | diff --git a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala index d7f9ba26b..38b9e47ee 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala @@ -141,6 +141,12 @@ abstract class BaseTranslator(val provider: TypeProvider) doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + } } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala index e767bf361..e8d2e733a 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala @@ -71,6 +71,8 @@ class ExpressionValidator(val provider: TypeProvider) CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName)) case Ast.expr.BitSizeOfType(typeName) => CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName)) + case Ast.expr.Group(nested) => + validate(nested) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala index 3a24b2a53..87ae32f13 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala @@ -103,6 +103,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + ResultString(nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + }) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 4baf0591a..06ad2c314 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -126,6 +126,8 @@ class TypeDetector(provider: TypeProvider) { detectCastType(typeName) case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) => CalcIntType + case Ast.expr.Group(expr) => + detectType(expr) } } @@ -222,6 +224,12 @@ class TypeDetector(provider: TypeProvider) { */ def detectCallType(call: Ast.expr.Call): DataType = { call.func match { + case Ast.expr.Group(nested) => detectCallTypeImpl(nested) + case func => detectCallTypeImpl(func) + } + } + def detectCallTypeImpl(func: Ast.expr): DataType = { + func match { case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => val objType = detectType(obj) // TODO: check number and type of arguments in `call.args`