Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create new AST node Group in the expression language, that represents value in parenthesis #214

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("42 - 2").evaluateIntConst should be(Some(40))
}

it ("considers `(42 - 2)` constant") {
Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40))
}

it ("considers `(-3 + 7) * 8 / 2` constant") {
Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16))
}
Expand All @@ -33,6 +37,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1))
}

it ("considers `((4) > 2) ? ((1)) : 5` constant") {
Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1))
}

it ("considers `x` variable") {
Expressions.parse("x").evaluateIntConst should be(None)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ class ExpressionsSpec extends AnyFunSpec {
)
}

it("parses group") {
Expressions.parse("(123)") should be (Group(IntNum(123)))
Expressions.parse("(foo)") should be (Group(Name(identifier("foo"))))
}

it("parses (1 + 2) / (7 * 8)") {
Expressions.parse("(1 + 2) / (7 * 8)") should be (
BinOp(
BinOp(IntNum(1), Add, IntNum(2)),
Group(BinOp(IntNum(1), Add, IntNum(2))),
Div,
BinOp(IntNum(7), Mult, IntNum(8))
Group(BinOp(IntNum(7), Mult, IntNum(8)))
)
)
}
Expand Down Expand Up @@ -128,7 +133,7 @@ class ExpressionsSpec extends AnyFunSpec {
}

it("parses ~(7+3)") {
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3))))
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3)))))
}

// Enums
Expand Down Expand Up @@ -285,7 +290,7 @@ class ExpressionsSpec extends AnyFunSpec {

it("parses (123).as<u4>") {
Expressions.parse("(123).as<u4>") should be (
CastToType(IntNum(123), typeId(false, Seq("u4")))
CastToType(Group(IntNum(123)), typeId(false, Seq("u4")))
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AttrSpec$Test extends AnyFunSpec {
spec.id should be(NamedIdentifier("foo"))
val dataType = spec.dataType.asInstanceOf[UserType]
dataType.name should be(List("bar"))
dataType.args should be(Seq(Ast.expr.IntNum(5)))
dataType.args should be(Seq(Ast.expr.Group(Ast.expr.IntNum(5))))
}
}
}
5 changes: 4 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object Ast {
}
}

object expr{
object expr {
case class BoolOp(op: boolop, values: Seq[expr]) extends expr
case class BinOp(left: expr, op: operator, right: expr) extends expr
case class UnaryOp(op: unaryop, operand: expr) extends expr
Expand Down Expand Up @@ -91,6 +91,9 @@ object Ast {
case class List(elts: Seq[expr]) extends expr
case class InterpolatedStr(elts: Seq[expr]) extends expr

/** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */
case class Group(expr: expr) extends expr

/**
* Implicit declaration of ordering, so expressions can be used for ordering operations, e.g.
* for `SortedMap.from(...)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ object ConstEvaluator {
case _ => value.NonConst
}

case expr.Group(expr) => evaluate(expr)
case _ => value.NonConst
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ object Expressions {
case (lhs, trailers) =>
trailers.foldLeft(lhs)((l, t) => t(l))
}
def group[$: P]: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr))
def empty_list[$: P] = P("[" ~ "]").map(_ => Ast.expr.List(Nil))
// def empty_dict[_: P] = P("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil))
def atom[$: P]: P[Ast.expr] = P(
empty_list |
// empty_dict |
"(" ~ test ~ ")" |
group |
"[" ~ list ~ "]" |
// "{" ~ dictorsetmaker ~ "}" |
enumByName |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ abstract class BaseTranslator(val provider: TypeProvider)
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kaitai.struct.exprlang.Ast

trait CommonOps extends AbstractTranslator {
def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = {
s"(${translate(left)} ${binOp(op)} ${translate(right)})"
s"${translate(left)} ${binOp(op)} ${translate(right)}"
}

def binOp(op: Ast.operator): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class ExpressionValidator(val provider: TypeProvider)
CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName))
case Ast.expr.InterpolatedStr(elts: Seq[Ast.expr]) =>
elts.foreach(validate)
case Ast.expr.Group(nested) =>
validate(nested)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
ResultString(nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class JavaScriptTranslator(provider: TypeProvider) extends BaseTranslator(provid
case (_: IntType, _: IntType, Ast.operator.Mod) =>
s"${JavaScriptCompiler.kstreamName}.mod(${translate(left)}, ${translate(right)})"
case (_: IntType, _: IntType, Ast.operator.RShift) =>
s"(${translate(left)} >>> ${translate(right)})"
s"${translate(left)} >>> ${translate(right)}"
case _ =>
super.numericBinOp(left, op, right)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class TypeDetector(provider: TypeProvider) {
detectCastType(typeName)
case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) =>
CalcIntType
case Ast.expr.Group(expr) =>
detectType(expr)
}
}

Expand Down Expand Up @@ -244,6 +246,12 @@ class TypeDetector(provider: TypeProvider) {
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Group(nested) => detectCallTypeImpl(nested)
case func => detectCallTypeImpl(func)
}
}
def detectCallTypeImpl(func: Ast.expr): DataType = {
func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
Expand Down
Loading