Skip to content

Commit

Permalink
Add support of absolute paths for enums on parse level
Browse files Browse the repository at this point in the history
  • Loading branch information
Mingun committed Oct 4, 2024
1 parent 3e68dcc commit 6ac1562
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 38 deletions.
56 changes: 56 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/exprlang/EnumRefSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.kaitai.struct.exprlang

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers._

class EnumRefSpec extends AnyFunSpec {
describe("Expressions.parseEnumRef") {
describe("parses local enum refs") {
it("some_enum") {
Expressions.parseEnumRef("some_enum") should be(Ast.EnumRef(
false, Seq(), "some_enum"
))
}
it("with spaces: ' some_enum '") {
Expressions.parseEnumRef(" some_enum ") should be(Ast.EnumRef(
false, Seq(), "some_enum"
))
}

it("::some_enum") {
Expressions.parseEnumRef("::some_enum") should be(Ast.EnumRef(
true, Seq(), "some_enum"
))
}
it("with spaces: ' :: some_enum '") {
Expressions.parseEnumRef(" :: some_enum ") should be(Ast.EnumRef(
true, Seq(), "some_enum"
))
}
}

describe("parses path enum refs") {
it("some::enum") {
Expressions.parseEnumRef("some::enum") should be(Ast.EnumRef(
false, Seq("some"), "enum"
))
}
it("with spaces: ' some :: enum '") {
Expressions.parseEnumRef(" some :: enum ") should be(Ast.EnumRef(
false, Seq("some"), "enum"
))
}

it("::some::enum") {
Expressions.parseEnumRef("::some::enum") should be(Ast.EnumRef(
true, Seq("some"), "enum"
))
}
it("with spaces: ' :: some :: enum '") {
Expressions.parseEnumRef(" :: some :: enum ") should be(Ast.EnumRef(
true, Seq("some"), "enum"
))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
): LanguageCompiler = ???

def type2class(name: List[String]) = name.last
def type2display(name: List[String]) = name.map(Utils.upperCamelCase).mkString("::")
def type2display(name: Seq[String]) = name.map(Utils.upperCamelCase).mkString("::")

def dataTypeName(dataType: DataType, valid: Option[ValidationSpec]): String = {
dataType match {
Expand All @@ -508,7 +508,7 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
val comma = if (bytesStr.isEmpty) "" else ", "
s"str($bytesStr$comma$encoding)"
case EnumType(name, basedOn) =>
s"${dataTypeName(basedOn, valid)}${type2display(name)}"
s"${dataTypeName(basedOn, valid)}${type2display(name.fullName)}"
case BitsType(width, bitEndian) => s"b$width${bitEndian.toSuffix}"
case BitsType1(bitEndian) => s"b1${bitEndian.toSuffix}→bool"
case _ => dataType.toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ object DataType {
def isOwning = false
}

case class EnumType(name: List[String], basedOn: IntType) extends DataType {
case class EnumType(name: Ast.EnumRef, basedOn: IntType) extends DataType {
var enumSpec: Option[EnumSpec] = None

/**
Expand Down Expand Up @@ -487,7 +487,7 @@ object DataType {
enumRef match {
case Some(enumName) =>
r match {
case numType: IntType => EnumType(classNameToList(enumName), numType)
case numType: IntType => EnumType(Expressions.parseEnumRef(enumName), numType)
case _ =>
throw KSYParseError(s"tried to resolve non-integer $r to enum", path).toException
}
Expand Down
13 changes: 13 additions & 0 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,18 @@ object Ast {
case object GtE extends cmpop
}

/**
* Reference to an enum in scope. Scope is defined by the `absolute` flag and
* a path to a type (which can be empty) in which enum is defined.
*/
case class EnumRef(absolute: Boolean, typePath: Seq[String], name: String) {
/** @return Type path and name of enum in one list. */
def fullName: Seq[String] = typePath :+ name
/**
* @return Enum designation name as human-readable string, to be used in compiler
* error messages.
*/
def asStr: String = fullName.mkString(if (absolute) "::" else "", "::", "")
}
case class TypeWithArguments(typeName: typeId, arguments: expr.List)
}
15 changes: 15 additions & 0 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ object Expressions {
case (path, Some(args)) => Ast.TypeWithArguments(path, args)
}

def enumRef[$: P]: P[Ast.EnumRef] = P(Start ~ "::".!.? ~ NAME.rep(1, "::") ~ End).map {
case (absolute, names) =>
// List have at least one element, so we always can split it into head and the last element
val typePath :+ enumName = names
Ast.EnumRef(absolute.nonEmpty, typePath.map(i => i.name), enumName.name)
}

class ParseException(val src: String, val failure: Parsed.Failure)
extends RuntimeException(failure.msg)

Expand All @@ -211,6 +218,14 @@ object Expressions {
*/
def parseTypeRef(src: String): Ast.TypeWithArguments = realParse(src, typeRef(_))

/**
* Parse string with reference to enumeration definition, optionally in full path format.
*
* @param src Enum reference as string, like `::path::to::enum`
* @return Object that represents path to enum
*/
def parseEnumRef(src: String): Ast.EnumRef = realParse(src, enumRef(_))

private def realParse[T](src: String, parser: P[_] => P[T]): T = {
val r = fastparse.parse(src.trim, parser)
r match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ object CSharpCompiler extends LanguageCompilerStatic
case KaitaiStreamType | OwnedKaitaiStreamType => kstreamName

case t: UserType => types2class(t.name)
case EnumType(name, _) => types2class(name)
case EnumType(ref, _) => types2class(ref.fullName)

case at: ArrayType => {
importList.add("System.Collections.Generic")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ object CppCompiler extends LanguageCompilerStatic
types2class(if (absolute) {
t.enumSpec.get.name
} else {
t.name
t.name.fullName
})

case at: ArrayType => {
Expand Down Expand Up @@ -1210,7 +1210,7 @@ object CppCompiler extends LanguageCompilerStatic
)
}

def types2class(components: List[String]) =
def types2class(components: Seq[String]) =
components.map(type2class).mkString("::")

def type2class(name: String) = Utils.lowerUnderscoreCase(name) + "_t"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ object JavaCompiler extends LanguageCompilerStatic
case KaitaiStructType | CalcKaitaiStructType(_) => kstructName

case t: UserType => types2class(t.name)
case EnumType(name, _) => types2class(name)
case EnumType(ref, _) => types2class(ref.fullName)

case _: ArrayType => kaitaiType2JavaTypeBoxed(attrType, importList)

Expand Down Expand Up @@ -918,7 +918,7 @@ object JavaCompiler extends LanguageCompilerStatic
case KaitaiStructType | CalcKaitaiStructType(_) => kstructName

case t: UserType => types2class(t.name)
case EnumType(name, _) => types2class(name)
case EnumType(ref, _) => types2class(ref.fullName)

case at: ArrayType => {
importList.add("java.util.ArrayList")
Expand All @@ -929,7 +929,7 @@ object JavaCompiler extends LanguageCompilerStatic
}
}

def types2class(names: List[String]) = names.map(x => type2class(x)).mkString(".")
def types2class(names: Seq[String]) = names.map(x => type2class(x)).mkString(".")

override def kstreamName: String = "KaitaiStream"
override def kstructName: String = "KaitaiStruct"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ object RustCompiler
def classTypeName(c: ClassSpec): String =
s"${types2class(c.name)}"

def types2class(names: List[String]): String =
def types2class(names: Seq[String]): String =
// TODO: Use `mod` to scope types instead of weird names
names.map(x => type2class(x)).mkString("_")

Expand All @@ -1339,7 +1339,7 @@ object RustCompiler
case t: EnumType =>
val baseName = t.enumSpec match {
case Some(spec) => s"${types2class(spec.name)}"
case None => s"${types2class(t.name)}"
case None => s"${types2class(t.name.fullName)}"
}
baseName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,19 @@ class ResolveTypes(specs: ClassSpecs, topClass: ClassSpec, opaqueTypes: Boolean)
}
}
case et: EnumType =>
et.name match {
case typePath :+ name =>
try {
val resolver = new ClassTypeProvider(specs, curClass)
val ty = resolver.resolveEnum(Ast.typeId(false, typePath), name)
Log.enumResolve.info(() => s" => ${ty.nameAsStr}")
et.enumSpec = Some(ty)
None
} catch {
case ex: TypeNotFoundError =>
Log.typeResolve.info(() => s" => ??? (while resolving enum '${et.name}'): $ex")
Log.enumResolve.info(() => s" => ??? (enclosing type not found, enum '${et.name}'): $ex")
Some(TypeNotFoundErr(typePath, curClass, path :+ "enum"))
case ex: EnumNotFoundError =>
Log.enumResolve.info(() => s" => ??? (enum '${et.name}'): $ex")
Some(EnumNotFoundErr(et.name, curClass, path :+ "enum"))
}
case _ =>
Log.enumResolve.info(() => s" => ??? (enum '${et.name}' without name)")
// TODO: Maybe more specific error about empty name?
try {
val resolver = new ClassTypeProvider(specs, curClass)
val ty = resolver.resolveEnum(Ast.typeId(et.name.absolute, et.name.typePath), et.name.name)
Log.enumResolve.info(() => s" => ${ty.nameAsStr}")
et.enumSpec = Some(ty)
None
} catch {
case ex: TypeNotFoundError =>
Log.typeResolve.info(() => s" => ??? (while resolving enum '${et.name}'): $ex")
Log.enumResolve.info(() => s" => ??? (enclosing type not found, enum '${et.name}'): $ex")
Some(TypeNotFoundErr(et.name.typePath, curClass, path :+ "enum"))
case ex: EnumNotFoundError =>
Log.enumResolve.info(() => s" => ??? (enum '${et.name}'): $ex")
Some(EnumNotFoundErr(et.name, curClass, path :+ "enum"))
}
case st: SwitchType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.kaitai.struct.problems

import io.kaitai.struct.{JSON, Jsonable, Utils, problems}
import io.kaitai.struct.datatype.DataType
import io.kaitai.struct.exprlang.Expressions
import io.kaitai.struct.exprlang.{Ast, Expressions}
import io.kaitai.struct.format.{ClassSpec, Identifier, KSVersion}
import fastparse.Parsed.Failure

Expand Down Expand Up @@ -170,7 +170,7 @@ case class ParamMismatchError(idx: Int, argType: DataType, paramName: String, pa
override def severity: ProblemSeverity = ProblemSeverity.Error
}

case class TypeNotFoundErr(name: List[String], curClass: ClassSpec, path: List[String], fileName: Option[String] = None)
case class TypeNotFoundErr(name: Seq[String], curClass: ClassSpec, path: List[String], fileName: Option[String] = None)
extends CompilationProblem {

override def text = s"unable to find type '${name.mkString("::")}', searching from '${curClass.nameAsStr}'"
Expand All @@ -180,10 +180,10 @@ case class TypeNotFoundErr(name: List[String], curClass: ClassSpec, path: List[S
override def severity: ProblemSeverity = ProblemSeverity.Error
}

case class EnumNotFoundErr(name: List[String], curClass: ClassSpec, path: List[String], fileName: Option[String] = None)
case class EnumNotFoundErr(ref: Ast.EnumRef, curClass: ClassSpec, path: List[String], fileName: Option[String] = None)
extends CompilationProblem {

override def text = s"unable to find enum '${name.mkString("::")}', searching from '${curClass.nameAsStr}'"
override def text = s"unable to find enum '${ref.asStr}', searching from '${curClass.nameAsStr}'"
override val coords: ProblemCoords = ProblemCoords(fileName, Some(path))
override def localizedInFile(fileName: String): CompilationProblem =
copy(fileName = Some(fileName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class TypeDetector(provider: TypeProvider) {
case Ast.expr.InterpolatedStr(_) => CalcStrType
case Ast.expr.Bool(_) => CalcBooleanType
case Ast.expr.EnumByLabel(enumType, _, inType) =>
val t = EnumType(inType.names.toList :+ enumType.name, CalcIntType)
val t = EnumType(Ast.EnumRef(false, inType.names.toList, enumType.name), CalcIntType)
t.enumSpec = Some(provider.resolveEnum(inType, enumType.name))
t
case Ast.expr.EnumById(enumType, _, inType) =>
val t = EnumType(List(enumType.name), CalcIntType)
// TODO: May be create a type with a name that includes surrounding type?
val t = EnumType(Ast.EnumRef(false, List(), enumType.name), CalcIntType)
t.enumSpec = Some(provider.resolveEnum(inType, enumType.name))
t
case Ast.expr.Name(name: Ast.identifier) => provider.determineType(name.name).asNonOwning()
Expand Down

0 comments on commit 6ac1562

Please sign in to comment.