Skip to content

Commit

Permalink
Rust: adapt to recent runtime library changes
Browse files Browse the repository at this point in the history
  • Loading branch information
generalmimon committed Sep 1, 2024
1 parent f360132 commit d6793db
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
56 changes: 35 additions & 21 deletions shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def runReadCalc(): Unit = {
out.puts(s"if *${privateMemberName(EndianIdentifier)} == 0 {")
out.inc
out.puts(s"""return Err(KError::UndecidedEndiannessError("${typeProvider.nowClass.path.mkString("/", "/", "")}".to_string()));""")
out.puts(s"""return Err(${ksErrorName(UndecidedEndiannessError)} { src_path: "${typeProvider.nowClass.path.mkString("/", "/", "")}".to_string() });""")
out.dec
out.puts("}")
}
Expand Down Expand Up @@ -371,19 +371,19 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
case ProcessXor(xorValue) =>
translator.detectType(xorValue) match {
case _: IntType =>
s"BytesReader::process_xor_one(&$srcExpr, ${expression(xorValue)})"
s"process_xor_one(&$srcExpr, ${expression(xorValue)})"
case _: BytesType =>
s"BytesReader::process_xor_many(&$srcExpr, &${translator.remove_deref(expression(xorValue))})"
s"process_xor_many(&$srcExpr, &${translator.remove_deref(expression(xorValue))})"
}
case ProcessZlib =>
s"BytesReader::process_zlib(&$srcExpr)"
s"process_zlib(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?"
case ProcessRotate(isLeft, rotValue) =>
val expr = if (isLeft) {
expression(rotValue)
} else {
s"8 - (${expression(rotValue)})"
}
s"BytesReader::process_rotate_left(&$srcExpr, $expr)"
s"process_rotate_left(&$srcExpr, $expr)"
case ProcessCustom(name, args) =>
val procClass = name.map(x => type2class(x)).mkString("::")
val procName = s"_process_${idToStr(varSrc)}"
Expand All @@ -394,7 +394,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
val argList = translate_args(args, into = false)
val argListInParens = s"($argList)"
out.puts(s"let $procName = $procClass::new$argListInParens;")
s"$procName.decode(&$srcExpr)"
s"$procName.decode(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?"
}
handleAssignment(varDest, expr, rep, isRaw = false)
}
Expand Down Expand Up @@ -836,13 +836,13 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Seq[Byte]], include: Boolean): String = {
val ioId = privateMemberName(IoIdentifier)
val expr1 = padRight match {
case Some(padByte) => s"BytesReader::bytes_strip_right(&$expr0, $padByte).into()"
case Some(padByte) => s"bytes_strip_right(&$expr0, $padByte).into()"
case None => expr0
}
val expr2 = terminator match {
case Some(term) =>
val t = term.head & 0xff
s"BytesReader::bytes_terminate(&$expr1, $t, $include).into()"
s"bytes_terminate(&$expr1, $t, $include).into()"
case None => expr1
}
expr2
Expand Down Expand Up @@ -1186,24 +1186,19 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
case t: ArrayType => s"Arr${switchVariantName(id, t.elType)}"
}

override def ksErrorName(err: KSError): String = err match {
case EndOfStreamError => "KError::EncounteredEOF"
case UndecidedEndiannessError => "KError::UndecidedEndiannessError"
case ConversionError => "KError::CastError"
case _: ValidationError => s"KError::ValidationNotEqual"
}

override def ksErrorName(err: KSError): String = RustCompiler.ksErrorName(err)

override def attrValidateExpr(
attr: AttrLikeSpec,
checkExpr: Ast.expr,
err: KSError,
errArgs: List[Ast.expr]
): Unit = {
val errArgsStr = errArgs.map(translator.translate).mkString(", ")
val srcPathStr = translator.translate(Ast.expr.Str(attr.path.mkString("/", "/", "")))
val validationKind = RustCompiler.validationErrorKind(err.asInstanceOf[ValidationError])
out.puts(s"if !(${expression(checkExpr)}) {")
out.inc
out.puts(s"""return Err(KError::ValidationNotEqual(r#"$errArgsStr"#.to_string()));""")
out.puts(s"""return Err(${ksErrorName(err)}(ValidationFailedError { kind: $validationKind, src_path: $srcPathStr.to_string() }));""")
out.dec
out.puts("}")
}
Expand Down Expand Up @@ -1249,13 +1244,12 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
object RustCompiler
extends LanguageCompilerStatic
with StreamStructNames
with UpperCamelCaseClasses {
with UpperCamelCaseClasses
with ExceptionNames {
override def getCompiler(tp: ClassTypeProvider,
config: RuntimeConfig): LanguageCompiler =
new RustCompiler(tp, config)

override def kstreamName = "KStream"

var in_reader = false

def self_name(): String = {
Expand Down Expand Up @@ -1293,10 +1287,30 @@ object RustCompiler
rootClassTypeName(c.upClass.get, isRecurse = true)
}

override def kstructName = s"KStruct"
override def kstreamName = "KStream"
override def kstructName = "KStruct"

def kstructUnitName = "KStructUnit"

override def ksErrorName(err: KSError): String = err match {
case EndOfStreamError => "KError::Eof"
case UndecidedEndiannessError => "KError::UndecidedEndianness"
case ConversionError => "KError::CastError"
case _: ValidationError => "KError::ValidationFailed"
}

def validationErrorKind(err: ValidationError): String = {
val kind = err match {
case _: ValidationNotEqualError => "NotEqual"
case _: ValidationLessThanError => "LessThan"
case _: ValidationGreaterThanError => "GreaterThan"
case _: ValidationNotAnyOfError => "NotAnyOf"
case _: ValidationNotInEnumError => "NotInEnum"
case _: ValidationExprError => "Expr"
}
s"ValidationKind::$kind"
}

def classTypeName(c: ClassSpec): String =
s"${types2class(c.name)}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig)
s"${remove_deref(translate(i))}.to_string()"

override def bytesToStr(bytesExpr: String, encoding: String): String =
s"""decode_string(&$bytesExpr, &"$encoding")?"""
s"""bytes_to_str(&$bytesExpr, "$encoding")?"""

override def bytesLength(b: Ast.expr): String =
s"${remove_deref(translate(b))}.len()"
Expand Down

0 comments on commit d6793db

Please sign in to comment.