From d6793db5f171658ffeeda8b08a1b6519e2566edf Mon Sep 17 00:00:00 2001 From: Petr Pucil Date: Sun, 1 Sep 2024 20:00:00 +0200 Subject: [PATCH] Rust: adapt to recent runtime library changes --- .../struct/languages/RustCompiler.scala | 56 ++++++++++++------- .../struct/translators/RustTranslator.scala | 2 +- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala index d8b2e7867..d3f5d5cf8 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala @@ -152,7 +152,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def runReadCalc(): Unit = { out.puts(s"if *${privateMemberName(EndianIdentifier)} == 0 {") out.inc - out.puts(s"""return Err(KError::UndecidedEndiannessError("${typeProvider.nowClass.path.mkString("/", "/", "")}".to_string()));""") + out.puts(s"""return Err(${ksErrorName(UndecidedEndiannessError)} { src_path: "${typeProvider.nowClass.path.mkString("/", "/", "")}".to_string() });""") out.dec out.puts("}") } @@ -371,19 +371,19 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case ProcessXor(xorValue) => translator.detectType(xorValue) match { case _: IntType => - s"BytesReader::process_xor_one(&$srcExpr, ${expression(xorValue)})" + s"process_xor_one(&$srcExpr, ${expression(xorValue)})" case _: BytesType => - s"BytesReader::process_xor_many(&$srcExpr, &${translator.remove_deref(expression(xorValue))})" + s"process_xor_many(&$srcExpr, &${translator.remove_deref(expression(xorValue))})" } case ProcessZlib => - s"BytesReader::process_zlib(&$srcExpr)" + s"process_zlib(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?" case ProcessRotate(isLeft, rotValue) => val expr = if (isLeft) { expression(rotValue) } else { s"8 - (${expression(rotValue)})" } - s"BytesReader::process_rotate_left(&$srcExpr, $expr)" + s"process_rotate_left(&$srcExpr, $expr)" case ProcessCustom(name, args) => val procClass = name.map(x => type2class(x)).mkString("::") val procName = s"_process_${idToStr(varSrc)}" @@ -394,7 +394,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) val argList = translate_args(args, into = false) val argListInParens = s"($argList)" out.puts(s"let $procName = $procClass::new$argListInParens;") - s"$procName.decode(&$srcExpr)" + s"$procName.decode(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?" } handleAssignment(varDest, expr, rep, isRaw = false) } @@ -836,13 +836,13 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Seq[Byte]], include: Boolean): String = { val ioId = privateMemberName(IoIdentifier) val expr1 = padRight match { - case Some(padByte) => s"BytesReader::bytes_strip_right(&$expr0, $padByte).into()" + case Some(padByte) => s"bytes_strip_right(&$expr0, $padByte).into()" case None => expr0 } val expr2 = terminator match { case Some(term) => val t = term.head & 0xff - s"BytesReader::bytes_terminate(&$expr1, $t, $include).into()" + s"bytes_terminate(&$expr1, $t, $include).into()" case None => expr1 } expr2 @@ -1186,13 +1186,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case t: ArrayType => s"Arr${switchVariantName(id, t.elType)}" } - override def ksErrorName(err: KSError): String = err match { - case EndOfStreamError => "KError::EncounteredEOF" - case UndecidedEndiannessError => "KError::UndecidedEndiannessError" - case ConversionError => "KError::CastError" - case _: ValidationError => s"KError::ValidationNotEqual" - } - + override def ksErrorName(err: KSError): String = RustCompiler.ksErrorName(err) override def attrValidateExpr( attr: AttrLikeSpec, @@ -1200,10 +1194,11 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) err: KSError, errArgs: List[Ast.expr] ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val srcPathStr = translator.translate(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + val validationKind = RustCompiler.validationErrorKind(err.asInstanceOf[ValidationError]) out.puts(s"if !(${expression(checkExpr)}) {") out.inc - out.puts(s"""return Err(KError::ValidationNotEqual(r#"$errArgsStr"#.to_string()));""") + out.puts(s"""return Err(${ksErrorName(err)}(ValidationFailedError { kind: $validationKind, src_path: $srcPathStr.to_string() }));""") out.dec out.puts("}") } @@ -1249,13 +1244,12 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) object RustCompiler extends LanguageCompilerStatic with StreamStructNames - with UpperCamelCaseClasses { + with UpperCamelCaseClasses + with ExceptionNames { override def getCompiler(tp: ClassTypeProvider, config: RuntimeConfig): LanguageCompiler = new RustCompiler(tp, config) - override def kstreamName = "KStream" - var in_reader = false def self_name(): String = { @@ -1293,10 +1287,30 @@ object RustCompiler rootClassTypeName(c.upClass.get, isRecurse = true) } - override def kstructName = s"KStruct" + override def kstreamName = "KStream" + override def kstructName = "KStruct" def kstructUnitName = "KStructUnit" + override def ksErrorName(err: KSError): String = err match { + case EndOfStreamError => "KError::Eof" + case UndecidedEndiannessError => "KError::UndecidedEndianness" + case ConversionError => "KError::CastError" + case _: ValidationError => "KError::ValidationFailed" + } + + def validationErrorKind(err: ValidationError): String = { + val kind = err match { + case _: ValidationNotEqualError => "NotEqual" + case _: ValidationLessThanError => "LessThan" + case _: ValidationGreaterThanError => "GreaterThan" + case _: ValidationNotAnyOfError => "NotAnyOf" + case _: ValidationNotInEnumError => "NotInEnum" + case _: ValidationExprError => "Expr" + } + s"ValidationKind::$kind" + } + def classTypeName(c: ClassSpec): String = s"${types2class(c.name)}" diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala index f30fe27af..788d49182 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala @@ -452,7 +452,7 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) s"${remove_deref(translate(i))}.to_string()" override def bytesToStr(bytesExpr: String, encoding: String): String = - s"""decode_string(&$bytesExpr, &"$encoding")?""" + s"""bytes_to_str(&$bytesExpr, "$encoding")?""" override def bytesLength(b: Ast.expr): String = s"${remove_deref(translate(b))}.len()"