From 420d877dc1a8bc9b100946338f695f9002b2c270 Mon Sep 17 00:00:00 2001 From: Mingun Date: Fri, 4 Oct 2024 22:22:22 +0500 Subject: [PATCH] Use Ast.EnumRef everywhere instead of Ast.typeId + Ast.identifier Because of new types type system enforses fixing https://github.com/kaitai-io/kaitai_struct/issues/857 This test need to be updated: [info] - expr_compare_enum2 *** FAILED *** [info] [expr_compare_enum2.ksy: /seq/1/if: [info] error: can't compare EnumType(EnumRef(false,List(),animal),Int1Type(false)) and Int1Type(true) [info] ] [info] did not equal [info] [expr_compare_enum2.ksy: /seq/1/if: [info] error: can't compare EnumType(List(animal),Int1Type(false)) and Int1Type(true) [info] ] (SimpleMatchers.scala:34) --- .../struct/ClassTypeProvider$Test.scala | 99 ++++++++++--------- .../struct/exprlang/ExpressionsSpec.scala | 13 +-- .../translators/TestTypeProviders.scala | 2 +- .../io/kaitai/struct/ClassTypeProvider.scala | 14 +-- .../kaitai/struct/GraphvizClassCompiler.scala | 4 +- .../scala/io/kaitai/struct/exprlang/Ast.scala | 6 +- .../kaitai/struct/exprlang/Expressions.scala | 11 +-- .../kaitai/struct/format/InstanceSpec.scala | 2 +- .../struct/precompile/ResolveTypes.scala | 2 +- .../struct/translators/BaseTranslator.scala | 10 +- .../translators/ExpressionValidator.scala | 12 +-- .../struct/translators/GoTranslator.scala | 10 +- .../struct/translators/RustTranslator.scala | 6 +- .../struct/translators/TypeDetector.scala | 13 ++- .../struct/translators/TypeProvider.scala | 2 +- 15 files changed, 101 insertions(+), 105 deletions(-) diff --git a/jvm/src/test/scala/io/kaitai/struct/ClassTypeProvider$Test.scala b/jvm/src/test/scala/io/kaitai/struct/ClassTypeProvider$Test.scala index 2a97388e8..2e3ff390c 100644 --- a/jvm/src/test/scala/io/kaitai/struct/ClassTypeProvider$Test.scala +++ b/jvm/src/test/scala/io/kaitai/struct/ClassTypeProvider$Test.scala @@ -624,35 +624,36 @@ class ClassTypeProvider$Test extends AnyFunSpec { val e_12 = child_12.enums.get("e").getOrElse(throw new NoSuchElementException("'e_12' not found")) val e_2 = child_2.enums.get("e").getOrElse(throw new NoSuchElementException("'e_2' not found")) - val none = Ast.typeId(false, Seq()) - val one = Ast.typeId(false, Seq("one")) - val one_two = Ast.typeId(false, Seq("one", "two")) - val unknown = Ast.typeId(false, Seq("unknown")) + val none = Ast.EnumRef(false, Seq(), "e") + val one_e = Ast.EnumRef(false, Seq("one"), "e") + val one_unk = Ast.EnumRef(false, Seq("one"), "unknown") + val one_two = Ast.EnumRef(false, Seq("one", "two"), "e") + val unknown = Ast.EnumRef(false, Seq("unknown"), "e") describe("in 'root' context") { val resolver = new ClassTypeProvider(specs, root) it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_root) + resolver.resolveEnum(none) should be(e_root) } it("doesn't resolve 'one::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find type 'one', searching from 'root'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'one', searching from 'root'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find type 'one', searching from 'root'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root'") } } @@ -662,25 +663,25 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_1 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_root) + resolver.resolveEnum(none) should be(e_root) } it("resolves 'one::e'") { - resolver.resolveEnum(one, "e") should be(e_11) + resolver.resolveEnum(one_e) should be(e_11) } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_1::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_1::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_1'") } } @@ -690,26 +691,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_2 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_2) + resolver.resolveEnum(none) should be(e_2) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_2::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_2::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_2::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_2'") } } @@ -719,25 +720,25 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_11 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_11) + resolver.resolveEnum(none) should be(e_11) } it("resolves 'one::e'") { - resolver.resolveEnum(one, "e") should be(e_11) + resolver.resolveEnum(one_e) should be(e_11) } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_1::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_1::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_1::one'") } } @@ -747,26 +748,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_12 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_12) + resolver.resolveEnum(none) should be(e_12) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_1::two::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_1::two::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_1::two::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_1::two'") } } @@ -776,26 +777,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_21 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_2) + resolver.resolveEnum(none) should be(e_2) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_2::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_2::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_2::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_2::one'") } } @@ -805,26 +806,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_22 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_2) + resolver.resolveEnum(none) should be(e_2) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_2::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_2::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_2::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_2::two'") } } @@ -834,26 +835,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_121 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_12) + resolver.resolveEnum(none) should be(e_12) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_1::two::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_1::two::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_1::two::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_1::two::one'") } } @@ -863,26 +864,26 @@ class ClassTypeProvider$Test extends AnyFunSpec { resolver.nowClass = child_122 it("resolves 'e'") { - resolver.resolveEnum(none, "e") should be(e_12) + resolver.resolveEnum(none) should be(e_12) } it("doesn't resolve 'one::e'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "e") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_e) thrown.getMessage should be("unable to find enum 'e' in 'root::child_1::two::one'") } it("doesn't resolve 'one::two::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(one_two) thrown.getMessage should be("unable to find type 'two' in 'root::child_1::two::one'") } it("doesn't resolve 'one::unknown'") { - val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one, "unknown") + val thrown = the[EnumNotFoundError] thrownBy resolver.resolveEnum(one_unk) thrown.getMessage should be("unable to find enum 'unknown' in 'root::child_1::two::one'") } it("doesn't resolve 'unknown::e'") { - val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown, "e") + val thrown = the[TypeNotFoundError] thrownBy resolver.resolveEnum(unknown) thrown.getMessage should be("unable to find type 'unknown', searching from 'root::child_1::two::two'") } } diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala index 309630fac..16bc7bee0 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala @@ -133,15 +133,14 @@ class ExpressionsSpec extends AnyFunSpec { // Enums it("parses port::http") { - Expressions.parse("port::http") should be (EnumByLabel(identifier("port"), identifier("http"))) + Expressions.parse("port::http") should be (EnumByLabel(EnumRef(false, Seq(), "port"), identifier("http"))) } it("parses some_type::port::http") { Expressions.parse("some_type::port::http") should be ( EnumByLabel( - identifier("port"), + EnumRef(false, Seq("some_type"), "port"), identifier("http"), - typeId(absolute = false, Seq("some_type")) ) ) } @@ -149,9 +148,8 @@ class ExpressionsSpec extends AnyFunSpec { it("parses parent_type::child_type::port::http") { Expressions.parse("parent_type::child_type::port::http") should be ( EnumByLabel( - identifier("port"), + EnumRef(false, Seq("parent_type", "child_type"), "port"), identifier("http"), - typeId(absolute = false, Seq("parent_type", "child_type")) ) ) } @@ -159,9 +157,8 @@ class ExpressionsSpec extends AnyFunSpec { it("parses ::parent_type::child_type::port::http") { Expressions.parse("::parent_type::child_type::port::http") should be ( EnumByLabel( - identifier("port"), + EnumRef(true, Seq("parent_type", "child_type"), "port"), identifier("http"), - typeId(absolute = true, Seq("parent_type", "child_type")) ) ) } @@ -171,7 +168,7 @@ class ExpressionsSpec extends AnyFunSpec { Compare( BinOp( Attribute( - EnumByLabel(identifier("port"),identifier("http")), + EnumByLabel(EnumRef(false, Seq(), "port"), identifier("http")), identifier("to_i") ), Add, diff --git a/jvm/src/test/scala/io/kaitai/struct/translators/TestTypeProviders.scala b/jvm/src/test/scala/io/kaitai/struct/translators/TestTypeProviders.scala index 051ee216d..981ed227d 100644 --- a/jvm/src/test/scala/io/kaitai/struct/translators/TestTypeProviders.scala +++ b/jvm/src/test/scala/io/kaitai/struct/translators/TestTypeProviders.scala @@ -15,7 +15,7 @@ object TestTypeProviders { abstract class FakeTypeProvider extends TypeProvider { val nowClass = ClassSpec.opaquePlaceholder(List("top_class")) - override def resolveEnum(inType: Ast.typeId, enumName: String) = + override def resolveEnum(ref: Ast.EnumRef) = throw new NotImplementedError override def resolveType(typeName: Ast.typeId): DataType = { diff --git a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala index fd74cf236..9875efaa8 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala @@ -90,18 +90,18 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends throw new FieldNotFoundError(attrName, inClass) } - override def resolveEnum(inType: Ast.typeId, enumName: String): EnumSpec = { - val inClass = if (inType.absolute) topClass else nowClass + override def resolveEnum(ref: Ast.EnumRef): EnumSpec = { + val inClass = if (ref.absolute) topClass else nowClass // When concrete type is not defined, search enum definition in all enclosing types - if (inType.names.isEmpty) { - resolveEnumName(inClass, enumName) + if (ref.typePath.isEmpty) { + resolveEnumName(inClass, ref.name) } else { - val ty = resolveTypePath(inClass, inType.names) - ty.enums.get(enumName) match { + val ty = resolveTypePath(inClass, ref.typePath) + ty.enums.get(ref.name) match { case Some(spec) => spec case None => - throw new EnumNotFoundInTypeError(enumName, ty) + throw new EnumNotFoundInTypeError(ref.name, ty) } } } diff --git a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala index bbf1944c9..caf54dd53 100644 --- a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala @@ -347,8 +347,8 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends List() case _: Ast.expr.EnumByLabel => List() - case Ast.expr.EnumById(_, id, _) => - affectedVars(id) + case Ast.expr.EnumById(_, expr) => + affectedVars(expr) case Ast.expr.Attribute(value, attr) => if (attr.name == Identifier.SIZEOF) { val vars = value match { diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala index d2a7a5c2d..9523ea5de 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala @@ -76,8 +76,10 @@ object Ast { case class FloatNum(n: BigDecimal) extends expr case class Str(s: String) extends expr case class Bool(n: Boolean) extends expr - case class EnumByLabel(enumName: identifier, label: identifier, inType: typeId = EmptyTypeId) extends expr - case class EnumById(enumName: identifier, id: expr, inType: typeId = EmptyTypeId) extends expr + /** Take named enumeration constant from the specified enumeration. */ + case class EnumByLabel(ref: EnumRef, label: identifier) extends expr + /** Cast specified expression to the enumerated type. Used only by value instances with `enum` key. */ + case class EnumById(ref: EnumRef, expr: expr) extends expr case class Attribute(value: expr, attr: identifier) extends expr case class CastToType(value: expr, typeName: typeId) extends expr diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala index 3f0e41a7d..02c5994be 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala @@ -171,14 +171,11 @@ object Expressions { case (first, names: Seq[Ast.identifier]) => val isAbsolute = first.nonEmpty val (enumName, enumLabel) = names.takeRight(2) match { - case Seq(a, b) => (a, b) - } - val typePath = names.dropRight(2) - if (typePath.isEmpty) { - Ast.expr.EnumByLabel(enumName, enumLabel, Ast.EmptyTypeId) - } else { - Ast.expr.EnumByLabel(enumName, enumLabel, Ast.typeId(isAbsolute, typePath.map(_.name))) + case Seq(a, b) => (a.name, b) } + val typePath = names.dropRight(2).map(n => n.name) + val ref = Ast.EnumRef(isAbsolute, typePath, enumName) + Ast.expr.EnumByLabel(ref, enumLabel) } def byteSizeOfType[$: P]: P[Ast.expr.ByteSizeOfType] = diff --git a/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala index 4a66307ed..8011bdb84 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala @@ -61,7 +61,7 @@ object InstanceSpec { case None => value case Some(enumName) => - Ast.expr.EnumById(Ast.identifier(enumName), value) + Ast.expr.EnumById(Expressions.parseEnumRef(enumName), value) } val ifExpr = ParseUtils.getOptValueExpression(srcMap, "if", path) diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala b/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala index 0e50dda83..fe003195f 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala @@ -70,7 +70,7 @@ class ResolveTypes(specs: ClassSpecs, topClass: ClassSpec, opaqueTypes: Boolean) case et: EnumType => try { val resolver = new ClassTypeProvider(specs, curClass) - val ty = resolver.resolveEnum(Ast.typeId(et.name.absolute, et.name.typePath), et.name.name) + val ty = resolver.resolveEnum(et.name) Log.enumResolve.info(() => s" => ${ty.nameAsStr}") et.enumSpec = Some(ty) None diff --git a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala index 1dd550666..a36452bd9 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala @@ -59,11 +59,11 @@ abstract class BaseTranslator(val provider: TypeProvider) doInterpolatedStringLiteral(s) case Ast.expr.Bool(n) => doBoolLiteral(n) - case Ast.expr.EnumById(enumType, id, inType) => - val enumSpec = provider.resolveEnum(inType, enumType.name) - doEnumById(enumSpec, translate(id)) - case Ast.expr.EnumByLabel(enumType, label, inType) => - val enumSpec = provider.resolveEnum(inType, enumType.name) + case Ast.expr.EnumById(ref, expr) => + val enumSpec = provider.resolveEnum(ref) + doEnumById(enumSpec, translate(expr)) + case Ast.expr.EnumByLabel(ref, label) => + val enumSpec = provider.resolveEnum(ref) doEnumByLabel(enumSpec, label.name) case Ast.expr.Name(name: Ast.identifier) => if (name.name == Identifier.SIZEOF) { diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala index 8d2741282..5b214d467 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/ExpressionValidator.scala @@ -31,13 +31,13 @@ class ExpressionValidator(val provider: TypeProvider) _: Ast.expr.FloatNum | _: Ast.expr.Str | _: Ast.expr.Bool => // all simple literals are good and valid - case Ast.expr.EnumById(enumType, id, inType) => - provider.resolveEnum(inType, enumType.name) - validate(id) - case Ast.expr.EnumByLabel(enumType, label, inType) => - val enumSpec = provider.resolveEnum(inType, enumType.name) + case Ast.expr.EnumById(ref, expr) => + provider.resolveEnum(ref) + validate(expr) + case Ast.expr.EnumByLabel(ref, label) => + val enumSpec = provider.resolveEnum(ref) if (!enumSpec.map.values.exists(_.name == label.name)) { - throw new EnumMemberNotFoundError(label.name, enumType.name, enumSpec.path.mkString("/")) + throw new EnumMemberNotFoundError(label.name, ref.name, enumSpec.path.mkString("/")) } case Ast.expr.Name(name: Ast.identifier) => if (name.name == Identifier.SIZEOF) { diff --git a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala index 28647aec8..f9d600b5b 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala @@ -47,11 +47,11 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo trInterpolatedStringLiteral(s) case Ast.expr.Bool(n) => trBoolLiteral(n) - case Ast.expr.EnumById(enumType, id, inType) => - val enumSpec = provider.resolveEnum(inType, enumType.name) - trEnumById(enumSpec.name, translate(id)) - case Ast.expr.EnumByLabel(enumType, label, inType) => - val enumSpec = provider.resolveEnum(inType, enumType.name) + case Ast.expr.EnumById(ref, expr) => + val enumSpec = provider.resolveEnum(ref) + trEnumById(enumSpec.name, translate(expr)) + case Ast.expr.EnumByLabel(ref, label) => + val enumSpec = provider.resolveEnum(ref) trEnumByLabel(enumSpec.name, label.name) case Ast.expr.Name(name: Ast.identifier) => if (name.name == Identifier.SIZEOF) { diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala index 788d49182..c56d73b2d 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala @@ -396,10 +396,10 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) override def translate(v: Ast.expr): String = { v match { - case Ast.expr.EnumById(enumType, id, inType) => - id match { + case Ast.expr.EnumById(ref, expr) => + expr match { case ifExp: Ast.expr.IfExp => - val enumSpec = provider.resolveEnum(inType, enumType.name) + val enumSpec = provider.resolveEnum(ref) val enumName = RustCompiler.types2class(enumSpec.name) def toStr(ex: Ast.expr) = ex match { case Ast.expr.IntNum(n) => s"$enumName::try_from($n)?" diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 0f572f875..7866d764f 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -50,14 +50,13 @@ class TypeDetector(provider: TypeProvider) { case Ast.expr.Str(_) => CalcStrType case Ast.expr.InterpolatedStr(_) => CalcStrType case Ast.expr.Bool(_) => CalcBooleanType - case Ast.expr.EnumByLabel(enumType, _, inType) => - val t = EnumType(Ast.EnumRef(false, inType.names.toList, enumType.name), CalcIntType) - t.enumSpec = Some(provider.resolveEnum(inType, enumType.name)) + case Ast.expr.EnumByLabel(ref, _) => + val t = EnumType(ref, CalcIntType) + t.enumSpec = Some(provider.resolveEnum(ref)) t - case Ast.expr.EnumById(enumType, _, inType) => - // TODO: May be create a type with a name that includes surrounding type? - val t = EnumType(Ast.EnumRef(false, List(), enumType.name), CalcIntType) - t.enumSpec = Some(provider.resolveEnum(inType, enumType.name)) + case Ast.expr.EnumById(ref, _) => + val t = EnumType(ref, CalcIntType) + t.enumSpec = Some(provider.resolveEnum(ref)) t case Ast.expr.Name(name: Ast.identifier) => provider.determineType(name.name).asNonOwning() case Ast.expr.InternalName(id) => provider.determineType(id) diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala index 9cf5548ad..e2d4eb117 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala @@ -16,7 +16,7 @@ trait TypeProvider { def determineType(attrId: Identifier): DataType def determineType(inClass: ClassSpec, attrName: String): DataType def determineType(inClass: ClassSpec, attrId: Identifier): DataType - def resolveEnum(typeName: Ast.typeId, enumName: String): EnumSpec + def resolveEnum(ref: Ast.EnumRef): EnumSpec def resolveType(typeName: Ast.typeId): DataType def isLazy(attrName: String): Boolean def isLazy(inClass: ClassSpec, attrName: String): Boolean