diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala index bebe464e7..8a68c1512 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala @@ -17,6 +17,10 @@ class Ast$Test extends AnyFunSpec { Expressions.parse("42 - 2").evaluateIntConst should be(Some(40)) } + it ("considers `(42 - 2)` constant") { + Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40)) + } + it ("considers `(-3 + 7) * 8 / 2` constant") { Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16)) } @@ -33,6 +37,10 @@ class Ast$Test extends AnyFunSpec { Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1)) } + it ("considers `((4) > 2) ? ((1)) : 5` constant") { + Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1)) + } + it ("considers `x` variable") { Expressions.parse("x").evaluateIntConst should be(None) } diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala index 488c81861..39b3584d6 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala @@ -86,12 +86,17 @@ class ExpressionsSpec extends AnyFunSpec { ) } + it("parses group") { + Expressions.parse("(123)") should be (Group(IntNum(123))) + Expressions.parse("(foo)") should be (Group(Name(identifier("foo")))) + } + it("parses (1 + 2) / (7 * 8)") { Expressions.parse("(1 + 2) / (7 * 8)") should be ( BinOp( - BinOp(IntNum(1), Add, IntNum(2)), + Group(BinOp(IntNum(1), Add, IntNum(2))), Div, - BinOp(IntNum(7), Mult, IntNum(8)) + Group(BinOp(IntNum(7), Mult, IntNum(8))) ) ) } @@ -128,7 +133,7 @@ class ExpressionsSpec extends AnyFunSpec { } it("parses ~(7+3)") { - Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3)))) + Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3))))) } // Enums @@ -285,7 +290,7 @@ class ExpressionsSpec extends AnyFunSpec { it("parses (123).as") { Expressions.parse("(123).as") should be ( - CastToType(IntNum(123), typeId(false, Seq("u4"))) + CastToType(Group(IntNum(123)), typeId(false, Seq("u4"))) ) } diff --git a/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala b/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala index 392053e24..ed5f5734b 100644 --- a/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/format/AttrSpec$Test.scala @@ -112,7 +112,7 @@ class AttrSpec$Test extends AnyFunSpec { spec.id should be(NamedIdentifier("foo")) val dataType = spec.dataType.asInstanceOf[UserType] dataType.name should be(List("bar")) - dataType.args should be(Seq(Ast.expr.IntNum(5))) + dataType.args should be(Seq(Ast.expr.Group(Ast.expr.IntNum(5)))) } } } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala index 2acf1943b..9204a3e18 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala @@ -63,7 +63,7 @@ object Ast { } } - object expr{ + object expr { case class BoolOp(op: boolop, values: Seq[expr]) extends expr case class BinOp(left: expr, op: operator, right: expr) extends expr case class UnaryOp(op: unaryop, operand: expr) extends expr @@ -91,6 +91,9 @@ object Ast { case class List(elts: Seq[expr]) extends expr case class InterpolatedStr(elts: Seq[expr]) extends expr + /** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */ + case class Group(expr: expr) extends expr + /** * Implicit declaration of ordering, so expressions can be used for ordering operations, e.g. * for `SortedMap.from(...)` diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala index d146b3b50..0ffda45b9 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/ConstEvaluator.scala @@ -130,6 +130,7 @@ object ConstEvaluator { case _ => value.NonConst } + case expr.Group(expr) => evaluate(expr) case _ => value.NonConst } diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala index 0ef98ce71..dc81bc692 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala @@ -45,9 +45,9 @@ object Expressions { def formatExpr[$: P]: P[Ast.expr] = P("{" ~/ test ~ "}") def test[$: P]: P[Ast.expr] = P( or_test ~ ("?" ~ test ~ ":" ~ test).? ).map { - case (x, None) => x - case (condition, Some((ifTrue, ifFalse))) => Ast.expr.IfExp(condition, ifTrue, ifFalse) - } + case (x, None) => x + case (condition, Some((ifTrue, ifFalse))) => Ast.expr.IfExp(condition, ifTrue, ifFalse) + } def or_test[$: P] = P( and_test.rep(1, kw("or")) ).map { case Seq(x) => x case xs => Ast.expr.BoolOp(Ast.boolop.Or, xs) @@ -106,46 +106,47 @@ object Expressions { power ) // def power[_: P]: P[Ast.expr] = P( atom ~ trailer.rep ~ (Pow ~ factor).? ).map { - // case (lhs, trailers, rhs) => - // val left = trailers.foldLeft(lhs)((l, t) => t(l)) - // rhs match{ - // case None => left - // case Some((op, right)) => Ast.expr.BinOp(left, op, right) - // } - // } +// case (lhs, trailers, rhs) => +// val left = trailers.foldLeft(lhs)((l, t) => t(l)) +// rhs match{ +// case None => left +// case Some((op, right)) => Ast.expr.BinOp(left, op, right) +// } +// } def power[$: P]: P[Ast.expr] = P( atom ~ trailer.rep ).map { case (lhs, trailers) => trailers.foldLeft(lhs)((l, t) => t(l)) } + def group[$: P]: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr)) def empty_list[$: P] = P("[" ~ "]").map(_ => Ast.expr.List(Nil)) // def empty_dict[_: P] = P("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil)) def atom[$: P]: P[Ast.expr] = P( - empty_list | - // empty_dict | - "(" ~ test ~ ")" | - "[" ~ list ~ "]" | - // "{" ~ dictorsetmaker ~ "}" | - enumByName | - byteSizeOfType | - bitSizeOfType | + empty_list | +// empty_dict | + group | + "[" ~ list ~ "]" | +// "{" ~ dictorsetmaker ~ "}" | + enumByName | + byteSizeOfType | + bitSizeOfType | fstring | - STRING.rep(1).map(_.mkString).map(Ast.expr.Str) | - NAME.map((x) => x.name match { - case "true" => Ast.expr.Bool(true) - case "false" => Ast.expr.Bool(false) - case _ => Ast.expr.Name(x) - }) | - FLOAT_NUMBER.map(Ast.expr.FloatNum) | - INT_NUMBER.map(Ast.expr.IntNum) - ) + STRING.rep(1).map(_.mkString).map(Ast.expr.Str) | + NAME.map((x) => x.name match { + case "true" => Ast.expr.Bool(true) + case "false" => Ast.expr.Bool(false) + case _ => Ast.expr.Name(x) + }) | + FLOAT_NUMBER.map(Ast.expr.FloatNum) | + INT_NUMBER.map(Ast.expr.IntNum) + ) def list_contents[$: P] = P( test.rep(1, ",") ~ ",".? ) def list[$: P] = P( list_contents ).map(Ast.expr.List(_)) def call[$: P] = P("(" ~ arglist ~ ")").map { case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)} def slice[$: P] = P("[" ~ test ~ "]").map { case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)} def cast[$: P] = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map( - typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName) - ) + typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName) + ) def attr[$: P] = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id)) def trailer[$: P]: P[Ast.expr => Ast.expr] = P( call | slice | cast | attr ) @@ -154,11 +155,11 @@ object Expressions { // def dict_item[_: P] = P( test ~ ":" ~ test ) // def dict[_: P]: P[Ast.expr.Dict] = P( - // (dict_item.rep(1, ",") ~ ",".?).map { x => - // val (keys, values) = x.unzip - // Ast.expr.Dict(keys, values) - // } - // ) +// (dict_item.rep(1, ",") ~ ",".?).map{x => +// val (keys, values) = x.unzip +// Ast.expr.Dict(keys, values) +// } +// ) // def dictorsetmaker[_: P]: P[Ast.expr] = P( /*dict_comp |*/ dict /*| set_comp | set*/) def arglist[$: P]: P[Seq[Ast.expr]] = P( (test).rep(0, ",") ) diff --git a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala index e59db2048..1745a6062 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala @@ -148,6 +148,12 @@ abstract class BaseTranslator(val provider: TypeProvider) doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + } } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala index 8d2741282..bf0b1d0da 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala @@ -62,10 +62,10 @@ class ExpressionValidator(val provider: TypeProvider) case Ast.expr.Subscript(container: Ast.expr, idx: Ast.expr) => detectType(container) match { case _: ArrayType | _: BytesType => - validate(container) + validate(container) detectType(idx) match { case _: IntType => - validate(idx) + validate(idx) case indexType => throw new TypeMismatchError(s"subscript operation on arrays require index to be integer, but found $indexType") } @@ -87,6 +87,8 @@ class ExpressionValidator(val provider: TypeProvider) CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName)) case Ast.expr.InterpolatedStr(elts: Seq[Ast.expr]) => elts.foreach(validate) + case Ast.expr.Group(nested) => + validate(nested) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala index 7e85ab670..a8d96a0d9 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala @@ -115,6 +115,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo doByteSizeOfType(typeName) case Ast.expr.BitSizeOfType(typeName) => doBitSizeOfType(typeName) + case Ast.expr.Group(nested) => + ResultString(nested match { + // Unpack nested groups + case Ast.expr.Group(e) => translate(e) + case e => s"(${translate(e)})" + }) } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 27a0c0243..6ceb18a1e 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -138,6 +138,8 @@ class TypeDetector(provider: TypeProvider) { detectCastType(typeName) case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) => CalcIntType + case Ast.expr.Group(expr) => + detectType(expr) } } @@ -244,6 +246,12 @@ class TypeDetector(provider: TypeProvider) { */ def detectCallType(call: Ast.expr.Call): DataType = { call.func match { + case Ast.expr.Group(nested) => detectCallTypeImpl(nested) + case func => detectCallTypeImpl(func) + } + } + def detectCallTypeImpl(func: Ast.expr): DataType = { + func match { case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => val objType = detectType(obj) // TODO: check number and type of arguments in `call.args`