Skip to content

Commit

Permalink
Create new AST node Group in the expression language, that represents…
Browse files Browse the repository at this point in the history
… value in parenthesis

This change fixes https://github.com/kaitai-io/kaitai_struct_tests/blob/master/formats/expr_ops_parens.ksy test
for Java and, I'm sure, for all other targets
  • Loading branch information
Mingun committed Nov 19, 2020
1 parent f19304d commit 5b557ac
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ class ExpressionsSpec extends FunSpec {
)
}

it("parses group") {
Expressions.parse("(123)") should be (Group(IntNum(123)))
Expressions.parse("(foo)") should be (Group(Name(identifier("foo"))))
}

it("parses (1 + 2) / (7 * 8)") {
Expressions.parse("(1 + 2) / (7 * 8)") should be (
BinOp(
BinOp(IntNum(1), Add, IntNum(2)),
BinOp(Group(IntNum(1), Add, IntNum(2))),
Div,
BinOp(IntNum(7), Mult, IntNum(8))
BinOp(Group(IntNum(7), Mult, IntNum(8)))
)
)
}
Expand Down Expand Up @@ -128,7 +133,7 @@ class ExpressionsSpec extends FunSpec {
}

it("parses ~(7+3)") {
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3))))
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(Group(IntNum(7), Add, IntNum(3)))))
}

// Enums
Expand Down Expand Up @@ -270,7 +275,7 @@ class ExpressionsSpec extends FunSpec {

it("parses (123).as<u4>") {
Expressions.parse("(123).as<u4>") should be (
CastToType(IntNum(123), typeId(false, Seq("u4")))
CastToType(Group(IntNum(123)), typeId(false, Seq("u4")))
)
}

Expand Down
10 changes: 7 additions & 3 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,18 @@ object Ast {
case operator.BitXor => Some(leftValue ^ rightValue)
case operator.BitAnd => Some(leftValue & rightValue)
}
case expr.Group(expr) => expr.evaluateIntConst
case _ => None
}
}
}

object expr{
object expr {
case class BoolOp(op: boolop, values: Seq[expr]) extends expr
case class BinOp(left: expr, op: operator, right: expr) extends expr
case class UnaryOp(op: unaryop, operand: expr) extends expr
case class IfExp(condition: expr, ifTrue: expr, ifFalse: expr) extends expr
// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr
// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr
case class Compare(left: expr, ops: cmpop, right: expr) extends expr
case class Call(func: expr, args: Seq[expr]) extends expr
case class IntNum(n: BigInt) extends expr
Expand All @@ -106,10 +107,13 @@ object Ast {
case class Subscript(value: expr, idx: expr) extends expr
case class Name(id: identifier) extends expr
case class List(elts: Seq[expr]) extends expr

/** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */
case class Group(expr: expr) extends expr
}

sealed trait boolop
object boolop{
object boolop {
case object And extends boolop
case object Or extends boolop
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ object Expressions {
case (lhs, trailers) =>
trailers.foldLeft(lhs)((l, t) => t(l))
}
val group: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr));
val atom: P[Ast.expr] = {
val empty_list = ("[" ~ "]").map(_ => Ast.expr.List(Nil))
// val empty_dict = ("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil))
P(
empty_list |
// empty_dict |
"(" ~ test ~ ")" |
group |
"[" ~ list ~ "]" |
// "{" ~ dictorsetmaker ~ "}" |
enumByName |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ abstract class BaseTranslator(val provider: TypeProvider)
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class ExpressionValidator(val provider: TypeProvider)
CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName))
case Ast.expr.BitSizeOfType(typeName) =>
CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName))
case Ast.expr.Group(nested) =>
validate(nested)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
ResultString(nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class TypeDetector(provider: TypeProvider) {
detectCastType(typeName)
case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) =>
CalcIntType
case Ast.expr.Group(expr) =>
detectType(expr)
}
}

Expand Down Expand Up @@ -222,6 +224,12 @@ class TypeDetector(provider: TypeProvider) {
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Group(nested) => detectCallTypeImpl(nested)
case func => detectCallTypeImpl(func)
}
}
def detectCallTypeImpl(func: Ast.expr): DataType = {
func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
Expand Down

0 comments on commit 5b557ac

Please sign in to comment.