diff --git a/build.sbt b/build.sbt index 71b14cfc3..c9d700e5f 100644 --- a/build.sbt +++ b/build.sbt @@ -8,7 +8,7 @@ resolvers ++= Resolver.sonatypeOssRepos("public") val NAME = "kaitai-struct-compiler" val VERSION = "0.11-SNAPSHOT" -val TARGET_LANGS = "C++/STL, C#, Go, Java, JavaScript, Lua, Nim, Perl, PHP, Python, Ruby" +val TARGET_LANGS = "C++/STL, C#, Go, Java, JavaScript, Lua, Nim, Perl, PHP, Python, Ruby, Rust" val UTF8 = Charset.forName("UTF-8") lazy val root = project.in(file(".")). diff --git a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala index abed472ff..0dee55827 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala @@ -9,6 +9,7 @@ import io.kaitai.struct.translators.TypeProvider class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends TypeProvider { var nowClass = topClass + val allClasses: ClassSpecs = classSpecs var _currentIteratorType: Option[DataType] = None var _currentSwitchType: Option[DataType] = None diff --git a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala index a148abc43..c0f0817dd 100644 --- a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala @@ -1,7 +1,8 @@ package io.kaitai.struct -import io.kaitai.struct.datatype.DataType.{KaitaiStreamType, UserTypeInstream} -import io.kaitai.struct.datatype.{Endianness, FixedEndian, InheritedEndian} +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.datatype._ +import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format._ import io.kaitai.struct.languages.RustCompiler import io.kaitai.struct.languages.components.ExtraAttrs @@ -29,18 +30,19 @@ class RustClassCompiler( // Basic struct declaration lang.classHeader(curClass.name) - + compileAttrDeclarations(curClass.seq ++ extraAttrs) curClass.instances.foreach { case (instName, instSpec) => compileInstanceDeclaration(instName, instSpec) } - + // Constructor = Read() function compileReadFunction(curClass) - + compileInstances(curClass) compileAttrReaders(curClass.seq ++ extraAttrs) + curClass.toStringExpr.foreach(expr => lang.classToString(expr)) lang.classFooter(curClass.name) compileEnums(curClass) @@ -49,7 +51,7 @@ class RustClassCompiler( compileSubclasses(curClass) } - def compileReadFunction(curClass: ClassSpec) = { + def compileReadFunction(curClass: ClassSpec): Unit = { lang.classConstructorHeader( curClass.name, curClass.parentType, @@ -58,23 +60,41 @@ class RustClassCompiler( curClass.params ) - // FIXME val defEndian = curClass.meta.endian match { case Some(fe: FixedEndian) => Some(fe) case _ => None } - - lang.readHeader(defEndian, false) - + + lang.readHeader(defEndian, isEmpty = false) + + curClass.meta.endian match { + case Some(ce: CalcEndian) => compileCalcEndian(ce) + case Some(_) => // Nothing to generate + case None => // Same here + } + compileSeq(curClass.seq, defEndian) lang.classConstructorFooter } - override def compileInstances(curClass: ClassSpec) = { + override def compileCalcEndian(ce: CalcEndian): Unit = { + def renderProc(result: FixedEndian): Unit = { + val v = result match { + case LittleEndian => Ast.expr.IntNum(1) + case BigEndian => Ast.expr.IntNum(2) + } + lang.instanceCalculate(IS_LE_ID, CalcIntType, v) + } + lang.switchCases[FixedEndian](IS_LE_ID, ce.on, ce.cases, renderProc, renderProc) + lang.runReadCalc() + } + + override def compileInstances(curClass: ClassSpec): Unit = { lang.instanceDeclHeader(curClass.name) curClass.instances.foreach { case (instName, instSpec) => compileInstance(curClass.name, instName, instSpec, curClass.meta.endian) } + lang.instanceFooter } override def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, endian: Option[Endianness]): Unit = { @@ -88,16 +108,16 @@ class RustClassCompiler( lang.instanceHeader(className, instName, dataType, instSpec.isNullable) lang.instanceCheckCacheAndReturn(instName, dataType) + lang.instanceSetCalculated(instName) instSpec match { case vi: ValueInstanceSpec => lang.attrParseIfHeader(instName, vi.ifExpr) lang.instanceCalculate(instName, dataType, vi.value) lang.attrParseIfFooter(vi.ifExpr) - case i: ParseInstanceSpec => - lang.attrParse(i, instName, None) // FIXME + case pi: ParseInstanceSpec => + lang.attrParse(pi, instName, None) // FIXME } - lang.instanceSetCalculated(instName) lang.instanceReturn(instName, dataType) lang.instanceFooter } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala index 9fdb64a2c..b0e885e7b 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala @@ -1,238 +1,404 @@ package io.kaitai.struct.languages +import io.kaitai.struct._ +//import io.kaitai.struct.datatype.DataType.{ReadableType, _} +import io.kaitai.struct.datatype._ import io.kaitai.struct.datatype.DataType._ -import io.kaitai.struct.datatype.{DataType, FixedEndian, InheritedEndian, KSError} +//import io.kaitai.struct.datatype.{DataType, FixedEndian, InheritedEndian, KSError} import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.format.{NoRepeat, RepeatEos, RepeatExpr, RepeatSpec, _} +import io.kaitai.struct.format._ import io.kaitai.struct.languages.components._ import io.kaitai.struct.translators.RustTranslator -import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils, ExternalType} + +import scala.annotation.tailrec class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) extends LanguageCompiler(typeProvider, config) + with AllocateIOLocalVar + with EveryReadIsExpression + with FixedContentsUsingArrayByteLiteral with ObjectOrientedLanguage - with UpperCamelCaseClasses with SingleOutputFile - with AllocateIOLocalVar + with UpperCamelCaseClasses with UniversalFooter - with UniversalDoc - with FixedContentsUsingArrayByteLiteral - with EveryReadIsExpression { + with SwitchIfOps + with UniversalDoc { import RustCompiler._ + override val translator: RustTranslator = + new RustTranslator(typeProvider, config) + override def innerClasses = false override def innerEnums = false - override val translator: RustTranslator = new RustTranslator(typeProvider, config) - - override def universalFooter: Unit = { - out.dec - out.puts("}") - } - - override def outImports(topClass: ClassSpec) = - importList.toList.map((x) => s"use $x;").mkString("", "\n", "\n") - override def indent: String = " " + override def outFileName(topClassName: String): String = s"$topClassName.rs" + override def outImports(topClass: ClassSpec): String = + importList.toList + .mkString("", "\n", "\n") + override def fileHeader(topClassName: String): Unit = { outHeader.puts(s"// $headerComment") outHeader.puts - importList.add("std::option::Option") - importList.add("std::boxed::Box") - importList.add("std::io::Result") - importList.add("std::io::Cursor") - importList.add("std::vec::Vec") - importList.add("std::default::Default") - importList.add("kaitai_struct::KaitaiStream") - importList.add("kaitai_struct::KaitaiStruct") + outHeader.puts("#![allow(unused_imports)]") + outHeader.puts("#![allow(non_snake_case)]") + outHeader.puts("#![allow(non_camel_case_types)]") + outHeader.puts("#![allow(irrefutable_let_patterns)]") + outHeader.puts("#![allow(unused_comparisons)]") + outHeader.puts + outHeader.puts("extern crate kaitai;") - out.puts + importList.add("use kaitai::*;") + importList.add("use std::convert::{TryFrom, TryInto};") + importList.add("use std::cell::{Ref, Cell, RefCell};") + importList.add("use std::rc::{Rc, Weak};") } - override def externalTypeDeclaration(extType: ExternalType): Unit = { - val className = type2class(extType.name.last) - val pkg = type2classAbs(extType.name) + override def externalTypeDeclaration(extType: ExternalType): Unit = + importList.add( + s"use super::${extType.name.head}::${types2class(extType.name)};" + ) + + override def classHeader(name: List[String]): Unit = { + out.puts + out.puts("#[derive(Default, Debug, Clone)]") + out.puts(s"pub struct ${classTypeName(typeProvider.nowClass)} {") + out.inc + + val root = types2class(name.slice(0, 1)) + out.puts(s"pub ${privateMemberName(RootIdentifier)}: SharedType<$root>,") + + val parent = if (typeProvider.nowClass.isTopLevel) + root + else { + kaitaiTypeToNativeType(None, typeProvider.nowClass, typeProvider.nowClass.parentType, cleanTypename = true) + } + out.puts(s"pub ${privateMemberName(ParentIdentifier)}: SharedType<$parent>,") + out.puts(s"pub _self: SharedType,") + + typeProvider.nowClass.params.foreach { p => + // Make sure the parameter is imported if necessary + p.dataType match { + case u: UserType => if (u.isExternal(typeProvider.nowClass)) externalTypeDeclaration(ExternalUserType(u.classSpec.get)) + case _ => () + } - importList.add(s"$pkg::$className") + // Declare parameters as if they were attributes + attributeDeclaration(p.id, p.dataType, isNullable = false) + } } - override def classHeader(name: List[String]): Unit = - classHeader(name, Some(kstructName)) + // Intentional no-op; Rust has already ended the struct definition by the time we reach this + override def classFooter(name: List[String]): Unit = {} + + override def classConstructorHeader(name: List[String], + parentType: DataType, + rootClassName: List[String], + isHybrid: Boolean, + params: List[ParamDefSpec]): Unit = { + typeProvider.nowClass.meta.endian match { + case Some(_: CalcEndian) | Some(InheritedEndian) => + attributeDeclaration(EndianIdentifier, IntMultiType(signed = true, Width4, None), isNullable = false) + case _ => + } + + // Unlike other OOP languages, implementing an interface happens outside the struct declaration. + universalFooter - def classHeader(name: List[String], parentClass: Option[String]): Unit = { - out.puts("#[derive(Default)]") - out.puts(s"pub struct ${type2class(name)} {") + // If there are any switch types in the struct definition, create the enums for them + typeProvider.nowClass.seq.foreach( + a => + a.dataType match { + case st: SwitchType => switchTypeEnum(a.id, st) + case _ => () + } + ) + typeProvider.nowClass.instances.foreach( + i => + i._2.dataTypeComposite match { + case st: SwitchType => switchTypeEnum(i._1, st) + case _ => () + } + ) + + out.puts( + s"impl $kstructName for ${classTypeName(typeProvider.nowClass)} {" + ) + out.inc + val root = classTypeName(typeProvider.topClass) + out.puts(s"type Root = $root;") + + val parent = if (typeProvider.nowClass.isTopLevel) + root + else { + kaitaiTypeToNativeType(None, typeProvider.nowClass, typeProvider.nowClass.parentType, cleanTypename = true) + } + + out.puts( + s"type Parent = $parent;" + ) + out.puts } - override def classFooter(name: List[String]): Unit = universalFooter + override def runRead(name: List[String]): Unit = {} - override def classConstructorHeader(name: List[String], parentType: DataType, rootClassName: List[String], isHybrid: Boolean, params: List[ParamDefSpec]): Unit = { + override def runReadCalc(): Unit = { + out.puts(s"if *${privateMemberName(EndianIdentifier)} == 0 {") + out.inc + out.puts(s"""return Err(${ksErrorName(UndecidedEndiannessError)} { src_path: "${typeProvider.nowClass.path.mkString("/", "/", "")}".to_string() });""") + out.dec out.puts("}") - out.puts + } - out.puts(s"impl KaitaiStruct for ${type2class(name)} {") + override def readHeader(endian: Option[FixedEndian], + isEmpty: Boolean): Unit = { + RustCompiler.in_reader = true + val root = privateMemberName(RootIdentifier) + out.puts(s"fn read(") + out.inc + out.puts(s"self_rc: &OptRc,") + out.puts(s"${privateMemberName(IoIdentifier)}: &S,") + out.puts( + s"$root: SharedType," + ) + out.puts( + s"${privateMemberName(ParentIdentifier)}: SharedType," + ) + out.dec + out.puts(s") -> KResult<()> {") out.inc - // Parameter names - val pIo = paramName(IoIdentifier) - val pParent = paramName(ParentIdentifier) - val pRoot = paramName(RootIdentifier) + out.puts(s"*self_rc._io.borrow_mut() = _io.clone();") + out.puts(s"self_rc._root.set(_root.get());") + out.puts(s"self_rc._parent.set(_parent.get());") + out.puts(s"self_rc._self.set(Ok(self_rc.clone()));") - // Types - val tIo = kstreamName - val tParent = kaitaiType2NativeType(parentType) + out.puts(s"let _rrc = self_rc._root.get_value().borrow().upgrade();") + out.puts(s"let _prc = self_rc._parent.get_value().borrow().upgrade();") + out.puts(s"let _r = _rrc.as_ref().unwrap();") - out.puts(s"fn new(stream: &mut S,") - out.puts(s" _parent: &Option>,") - out.puts(s" _root: &Option>)") - out.puts(s" -> Result") - out.inc - out.puts(s"where Self: Sized {") + // If there aren't any attributes to parse, we need to end the read implementation here + if (typeProvider.nowClass.seq.isEmpty) + endRead() + } - out.puts(s"let mut s: Self = Default::default();") - out.puts + override def readFooter(): Unit = out.puts(s"// readFooter()") - out.puts(s"s.stream = stream;") + override def attributeDeclaration(attrName: Identifier, + attrType: DataType, + isNullable: Boolean): Unit = { + val typeName = attrName match { + case RootIdentifier | ParentIdentifier => return + case _ => + kaitaiTypeToNativeType(Some(attrName), typeProvider.nowClass, attrType) + } - out.puts(s"s.read(stream, _parent, _root)?;") - out.puts + out.puts(s"${idToStr(attrName)}: RefCell<$typeName>,") + } + + override def attributeReader(attrName: Identifier, + attrType: DataType, + isNullable: Boolean): Unit = { + var typeName = attrName match { + case RootIdentifier | ParentIdentifier => return + case _ => + kaitaiTypeToNativeType(Some(attrName), typeProvider.nowClass, attrType) + } - out.puts("Ok(s)") + out.puts( + s"impl ${classTypeName(typeProvider.nowClass)} {") + out.inc + + var types : Set[DataType] = Set() + var enum_typename = false + var switch_typename = false + attrType match { + case st: SwitchType => + types = st.cases.values.toSet + switch_typename = true + case _: EnumType => enum_typename = true + case _ => + } + var enum_only_numeric = true + types.foreach { + case _: NumericType => // leave unchanged + case _ => enum_only_numeric = false + } + var fn = idToStr(attrName) + if (switch_typename && enum_only_numeric) { + out.puts(s"pub fn $fn(&self) -> usize {") + out.inc + out.puts(s"self.${idToStr(attrName)}.borrow().as_ref().unwrap().into()") + out.dec + out.puts("}") + fn = s"${fn}_enum" + } + { + out.puts(s"pub fn $fn(&self) -> Ref<$typeName> {") + out.inc + out.puts(s"self.${idToStr(attrName)}.borrow()") + } + out.dec + out.puts("}") out.dec out.puts("}") - out.puts } - override def runRead(name: List[String]): Unit = { + override def attrParse(attr: AttrLikeSpec, + id: Identifier, + defEndian: Option[Endianness]): Unit = { + super.attrParse(attr, id, defEndian) + // Detect if this is the last attribute parse and finish the read method + if (typeProvider.nowClass.seq.nonEmpty && typeProvider.nowClass.seq.last.id == id) + endRead() } - override def runReadCalc(): Unit = { + def endRead(): Unit = { + out.puts("Ok(())") + out.dec + out.puts("}") + RustCompiler.in_reader = false + } + override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = {} + + override def condIfHeader(expr: Ast.expr): Unit = { + out.puts(s"if ${expression(expr)} {") + out.inc } - override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { - out.puts - out.puts(s"fn read(&mut self,") - out.puts(s" stream: &mut S,") - out.puts(s" _parent: &Option>,") - out.puts(s" _root: &Option>)") - out.puts(s" -> Result<()>") + override def condRepeatInitAttr(id: Identifier, dataType: DataType): Unit = { + // this line required for handleAssignmentRepeatUntil + typeProvider._currentIteratorType = Some(dataType) + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = Vec::new();") + } + + override def condRepeatEosHeader(id: Identifier, + io: String, + dataType: DataType): Unit = { + out.puts("{") + out.inc + out.puts(s"let mut _i = 0;") + out.puts(s"while !_io.is_eof() {") out.inc - out.puts(s"where Self: Sized {") } - override def readFooter(): Unit = { - out.puts - out.puts("Ok(())") + override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = { + out.puts(s"${RustCompiler.privateMemberName(id, writeAccess = true)}.push($expr);") + } + + override def condRepeatEosFooter: Unit = { + out.puts("_i += 1;") + out.dec + out.puts("}") out.dec out.puts("}") } - override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { - attrName match { - case ParentIdentifier | RootIdentifier | IoIdentifier => - // just ignore it for now - case IoIdentifier => - out.puts(s" stream: ${kaitaiType2NativeType(attrType)},") - case _ => - out.puts(s" pub ${idToStr(attrName)}: ${kaitaiType2NativeType(attrType)},") - } + override def condRepeatExprHeader(id: Identifier, + io: String, + dataType: DataType, + repeatExpr: Ast.expr): Unit = { + val lenVar = s"l_${idToStr(id)}" + out.puts(s"let $lenVar = ${expression(repeatExpr)};") + out.puts(s"for _i in 0..$lenVar {") + out.inc } - override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { + override def condRepeatUntilHeader(id: Identifier, + io: String, + dataType: DataType, + repeatExpr: Ast.expr): Unit = { + out.puts("{") + out.inc + out.puts("let mut _i = 0;") + out.puts("while {") + out.inc + } + override def createSubstream(id: Identifier, byteType: BytesType, io: String, rep: RepeatSpec, defEndian: Option[FixedEndian]): String = { + createSubstreamBuffered(id, byteType, io, rep, defEndian) } - override def universalDoc(doc: DocSpec): Unit = { - if (doc.summary.isDefined) { - out.puts - out.puts("/*") - doc.summary.foreach((summary) => out.putsLines(" * ", summary)) - out.puts(" */") + override def handleAssignmentRepeatUntil(id: Identifier, + expr: String, + isRaw: Boolean): Unit = { + out.puts(s"${RustCompiler.privateMemberName(id, writeAccess = true)}.push($expr);") + var copy_type = "" + if (typeProvider._currentIteratorType.isDefined && translator.is_copy_type(typeProvider._currentIteratorType.get)) { + copy_type = "*" } + val t = localTemporaryName(id) + out.puts(s"let $t = ${privateMemberName(id)};") + out.puts(s"let ${translator.doLocalName(Identifier.ITERATOR)} = $copy_type$t.last().unwrap();") } - override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = { - out.puts("if ($this->_m__is_le) {") - out.inc - leProc() + override def condRepeatUntilFooter(id: Identifier, + io: String, + dataType: DataType, + repeatExpr: Ast.expr): Unit = { + // this line required by kaitai code + typeProvider._currentIteratorType = Some(dataType) + out.puts("_i += 1;") + out.puts(s"let x = !(${expression(repeatExpr)});") + out.puts("x") out.dec - out.puts("} else {") - out.inc - beProc() + out.puts("} {}") out.dec out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") + def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = { + val memberName = privateMemberName(varName) + rep match { + case NoRepeat => memberName + case _ => s"$memberName[$memberName.len() - 1]" + } + } override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) val expr = proc match { case ProcessXor(xorValue) => - val procName = translator.detectType(xorValue) match { - case _: IntType => "processXorOne" - case _: BytesType => "processXorMany" + translator.detectType(xorValue) match { + case _: IntType => + s"process_xor_one(&$srcExpr, ${expression(xorValue)})" + case _: BytesType => + s"process_xor_many(&$srcExpr, &${translator.remove_deref(expression(xorValue))})" } - s"$kstreamName::$procName($srcExpr, ${expression(xorValue)})" case ProcessZlib => - s"$kstreamName::processZlib($srcExpr);" + s"process_zlib(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?" case ProcessRotate(isLeft, rotValue) => val expr = if (isLeft) { expression(rotValue) } else { s"8 - (${expression(rotValue)})" } - s"$kstreamName::processRotateLeft($srcExpr, $expr, 1)" + s"process_rotate_left(&$srcExpr, $expr)" case ProcessCustom(name, args) => - val procClass = if (name.length == 1) { - val onlyName = name.head - val className = type2class(onlyName) - importList.add(s"$onlyName::$className") - className - } else { - val pkgName = type2classAbs(name.init) - val className = type2class(name.last) - importList.add(s"$pkgName::$className") - s"$pkgName::$className" - } - - out.puts(s"let _process = $procClass::new(${args.map(expression).mkString(", ")});") - s"_process.decode($srcExpr)" - } - handleAssignment(varDest, expr, rep, false) - } + val procClass = name.map(x => type2class(x)).mkString("::") + val procName = s"_process_${idToStr(varSrc)}" - override def allocateIO(id: Identifier, rep: RepeatSpec): String = { - val memberName = privateMemberName(id) + val mod_name = name.last + importList.add(s"use crate::$mod_name::*;") - val args = rep match { - case RepeatUntil(_) => translator.doLocalName(Identifier.ITERATOR2) - case _ => getRawIdExpr(id, rep) - } - - out.puts(s"let mut io = Cursor::new($args);") - "io" - } - - def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = { - val memberName = privateMemberName(varName) - rep match { - case NoRepeat => memberName - case _ => s"$memberName.last()" + val argList = translate_args(args, into = false) + val argListInParens = s"($argList)" + out.puts(s"let $procName = $procClass::new$argListInParens;") + s"$procName.decode(&$srcExpr).map_err(|msg| KError::BytesDecodingError { msg })?" } + handleAssignment(varDest, expr, rep, isRaw = false) } override def useIO(ioEx: Ast.expr): String = { - out.puts(s"let mut io = ${expression(ioEx)};") + out.puts(s"let io = Clone::clone(&*${expression(ioEx)});") "io" } @@ -240,373 +406,984 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"let _pos = $io.pos();") override def seek(io: String, pos: Ast.expr): Unit = - out.puts(s"$io.seek(${expression(pos)});") + out.puts(s"$io.seek(${expression(pos)} as usize)?;") override def popPos(io: String): Unit = - out.puts(s"$io.seek(_pos);") + out.puts(s"$io.seek(_pos)?;") override def alignToByte(io: String): Unit = - out.puts(s"$io.alignToByte();") + out.puts(s"${privateMemberName(IoIdentifier)}.align_to_byte()?;") - override def condIfHeader(expr: Ast.expr): Unit = { - out.puts(s"if ${expression(expr)} {") + override def privateMemberName(id: Identifier): String = + RustCompiler.privateMemberName(id) + + override def instanceDeclHeader(className: List[String]): Unit = { + if (typeProvider.nowClass.params.nonEmpty) { + val paramsArg = Utils.join(typeProvider.nowClass.params.map { p => + val n = paramName(p.id) + val t = kaitaiTypeToNativeType(Some(p.id), typeProvider.nowClass, p.dataType, excludeOptionWrapper = true) + var byref = "" + if (!translator.is_copy_type(p.dataType)) + byref = "&" + // generate param access helper + attributeReader(p.id, p.dataType, isNullable = false) + s"$n: $byref$t" + }, "", ", ", "") + + out.puts(s"impl ${classTypeName(typeProvider.nowClass)} {") + out.inc + out.puts(s"pub fn set_params(&mut self, $paramsArg) {") + out.inc + typeProvider.nowClass.params.foreach(p => handleAssignmentParams(p.id, paramName(p.id))) + out.dec + out.puts("}") + out.dec + out.puts("}") + } + typeProvider.nowClass.meta.endian match { + case Some(_: CalcEndian) | Some(InheritedEndian) => + out.puts(s"impl ${classTypeName(typeProvider.nowClass)} {") + out.inc + val t = kaitaiTypeToNativeType(Some(EndianIdentifier), typeProvider.nowClass, IntMultiType(signed = true, Width4, None), excludeOptionWrapper = true) + out.puts(s"pub fn set_endian(&mut self, ${idToStr(EndianIdentifier)}: $t) {") + out.inc + handleAssignmentSimple(EndianIdentifier, s"${idToStr(EndianIdentifier)}") + out.dec + out.puts("}") + out.dec + out.puts("}") + case _ => + } + out.puts(s"impl ${classTypeName(typeProvider.nowClass)} {") out.inc } - override def condRepeatInitAttr(id: Identifier, dataType: DataType): Unit = - out.puts(s"${privateMemberName(id)} = vec!();") + override def universalFooter: Unit = { + out.dec + out.puts("}") + } - override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType): Unit = { - out.puts(s"while !$io.isEof() {") - out.inc + override def instanceDeclaration(attrName: InstanceIdentifier, + attrType: DataType, + isNullable: Boolean): Unit = { + val typeName = kaitaiTypeToNativeType( + Some(attrName), + typeProvider.nowClass, + attrType + ) + out.puts(s"${calculatedFlagForName(attrName)}: Cell,") + out.puts(s"${idToStr(attrName)}: RefCell<$typeName>,") } - override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = { - out.puts(s"${privateMemberName(id)}.append($expr);") + def calculatedFlagForName(ksName: Identifier) = + s"f_${idToStr(ksName)}" + + override def instanceClear(instName: InstanceIdentifier): Unit = { + var set = false + val ins = translator.get_instance(typeProvider.nowClass, idToStr(instName)) + if (ins.isDefined) { + set = ins.get.dataTypeComposite match { + case _: UserType => true + case _ => false + } + } + if (!set) { + out.puts(s"self.${calculatedFlagForName(instName)}.set(false);") + } } - override def condRepeatEosFooter: Unit = { - super.condRepeatEosFooter + override def instanceSetCalculated(instName: InstanceIdentifier): Unit = { + var set = false + val ins = translator.get_instance(typeProvider.nowClass, idToStr(instName)) + if (ins.isDefined) { + set = ins.get.dataTypeComposite match { + case _: UserType => true + case _ => false + } + } + if (!set) { + out.puts(s"self.${calculatedFlagForName(instName)}.set(true);") + } } - override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, repeatExpr: Ast.expr): Unit = { - out.puts(s"for i in 0..${expression(repeatExpr)} {") + override def idToStr(id: Identifier): String = RustCompiler.idToStr(id) + + override def instanceHeader(className: List[String], + instName: InstanceIdentifier, + dataType: DataType, + isNullable: Boolean): Unit = { + out.puts(s"pub fn ${idToStr(instName)}(") + out.inc + out.puts("&self") + out.dec + val typeName = kaitaiTypeToNativeType( + Some(instName), + typeProvider.nowClass, + dataType + ) + out.puts(s") -> KResult> {") out.inc + out.puts(s"let _io = self._io.borrow();") + out.puts(s"let _rrc = self._root.get_value().borrow().upgrade();") + out.puts(s"let _prc = self._parent.get_value().borrow().upgrade();") + out.puts(s"let _r = _rrc.as_ref().unwrap();") } - override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = - handleAssignmentRepeatEos(id, expr) - - override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = { - out.puts("while {") + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, + dataType: DataType): Unit = { + out.puts(s"if self.${calculatedFlagForName(instName)}.get() {") out.inc + out.puts(s"return Ok(${privateMemberName(instName)});") + out.dec + out.puts("}") } - override def handleAssignmentRepeatUntil(id: Identifier, expr: String, isRaw: Boolean): Unit = { - val tempVar = if (isRaw) { - translator.doLocalName(Identifier.ITERATOR2) - } else { - translator.doLocalName(Identifier.ITERATOR) + override def instanceCalculate(instName: Identifier, dataType: DataType, value: Ast.expr): Unit = { + dataType match { + case _: UserType => + handleAssignmentSimple(instName, s"${translator.remove_deref(expression(value))}.clone()") + case _: StrType => + handleAssignmentSimple(instName, s"${translator.remove_deref(expression(value))}.to_string()") + case _: BytesType => + handleAssignmentSimple(instName, s"${translator.rem_vec_amp(translator.remove_deref(expression(value)))}.to_vec()") + case _: ArrayType => + handleAssignmentSimple(instName, s"${translator.rem_vec_amp(translator.remove_deref(expression(value)))}.to_vec()") + case _: EnumType => + handleAssignmentSimple(instName, s"${translator.remove_deref(expression(value))}") + case _ => + handleAssignmentSimple(instName, s"(${expression(value)}) as ${kaitaiPrimitiveToNativeType(dataType)}") } - out.puts(s"let $tempVar = $expr;") - out.puts(s"${privateMemberName(id)}.append($tempVar);") } - override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit = { - typeProvider._currentIteratorType = Some(dataType) - out.puts(s"!(${expression(untilExpr)})") + override def instanceReturn(instName: InstanceIdentifier, + attrType: DataType): Unit = { + out.puts(s"Ok(${privateMemberName(instName)})") + } + + override def enumDeclaration(curClass: List[String], + enumName: String, + enumColl: Seq[(Long, EnumValueSpec)]): Unit = { + + val enumClass = types2class(curClass ::: List(enumName)) + + // Set up the actual enum definition + out.puts(s"#[derive(Debug, PartialEq, Clone)]") + out.puts(s"pub enum $enumClass {") + out.inc + + enumColl.foreach { + case (_, label) => + if (label.doc.summary.isDefined) + universalDoc(label.doc) + + out.puts(s"${type2class(label.name)},") + } + out.puts("Unknown(i64),") + out.dec - out.puts("} { }") + out.puts("}") + out.puts + + // Set up parsing enums from the underlying value + out.puts(s"impl TryFrom for $enumClass {") + + out.inc + // We typically need the lifetime in KError for returning byte slices from stream; + // because we can only return `UnknownVariant` which contains a Copy type, it's safe + // to declare that the error type is `'static` + out.puts(s"type Error = KError;") + out.puts(s"fn try_from(flag: i64) -> KResult<$enumClass> {") + + out.inc + out.puts(s"match flag {") + + out.inc + enumColl.foreach { + case (value, label) => + out.puts(s"$value => Ok($enumClass::${type2class(label.name)}),") + } + out.puts(s"_ => Ok($enumClass::Unknown(flag)),") + out.dec + + out.puts("}") + out.dec + out.puts("}") + out.dec + out.puts("}") + out.puts + + out.puts(s"impl From<&$enumClass> for i64 {") + out.inc + out.puts(s"fn from(v: &$enumClass) -> Self {") + out.inc + out.puts(s"match *v {") + out.inc + enumColl.foreach { + case (value, label) => + out.puts(s"$enumClass::${type2class(label.name)} => $value,") + } + out.puts(s"$enumClass::Unknown(v) => v") + out.dec + out.puts("}") + out.dec + out.puts("}") + out.dec + out.puts("}") + out.puts + + out.puts(s"impl Default for $enumClass {") + out.inc + out.puts(s"fn default() -> Self { $enumClass::Unknown(0) }") + out.dec + out.puts("}") + out.puts + } + + override def universalDoc(doc: DocSpec): Unit = { + out.puts + out.puts( "/**") + + doc.summary.foreach(docStr => out.putsLines(" * ", docStr)) + + doc.ref.foreach { + case TextRef(text) => + out.putsLines(" * ", s"\\sa $text") + case UrlRef(url, text) => + out.putsLines(" * ", s"\\sa $url $text") + } + + out.puts( " */") + } + + override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = + handleAssignmentRepeatEos(id, expr) + + def handleAssignmentParams(id: Identifier, expr: String): Unit = { + val paramId = typeProvider.nowClass.params.find(s => s.id == id) + var need_clone = false + if (paramId.isDefined) { + need_clone = !translator.is_copy_type(paramId.get.dataType) + } + paramId.get.dataType match { + case _: EnumType => + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr.clone();") + case _ => + if (need_clone) + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr.clone();") + else + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr;") + } } override def handleAssignmentSimple(id: Identifier, expr: String): Unit = { - out.puts(s"${privateMemberName(id)} = $expr;") + val seqId = translator.findMember(idToStr(id)) + var done = false + var refcell = false + if (seqId.isDefined) { + val idType = seqId.get.dataType + idType match { + case t: UserType => + refcell = true + case _: BytesType => refcell = true + case _: ArrayType => refcell = true + case _: StrType => refcell = true + case _: EnumType => + done = true + out.puts( + s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr;" + ) + case _: SwitchType => + done = true + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = Some($expr);") + case _ => + } + if (refcell) { + val typeName = kaitaiTypeToNativeType(Some(id), typeProvider.nowClass, idType) + if (typeName.startsWith("Option<")) { + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = Some($expr);") + } else { + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr;") + } + done = true + } + } + if (!done) { + var inst = false + id match { + case _: InstanceIdentifier => + inst = true + case RawIdentifier(inner) => inner match { + case _: InstanceIdentifier => + inst = true + case _ => + } + case EndianIdentifier => + inst = true + case _ => + } + if (inst) { + done = true + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr;") + } + } + if (!done) + out.puts(s"*${RustCompiler.privateMemberName(id, writeAccess = true)} = $expr;") } - override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { + override def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit = + out.puts(s"let $id = $expr;") + + def translate_args(args: Seq[Ast.expr], into: Boolean): String = { + Utils.join(args.map { a => + val typ = translator.detectType(a) + var byref = "" + val t = kaitaiTypeToNativeType(None, typeProvider.nowClass, typ) + var try_into = "" + typ match { + case _: NumericType => + if (into) { + try_into = s".try_into().map_err(|_| KError::CastError)?" + } + case _ => + if (!translator.is_copy_type(typ)) + byref = "&" + } + var translated = translator.translate(a) + if (translated == "_r") // _root + translated = "OptRc::new(&_rrc)" + if (try_into.nonEmpty) + s"$byref($translated)$try_into" + else + s"$byref$translated" + }, "", ", ", "") + } + + override def parseExpr(dataType: DataType, + assignType: DataType, + io: String, + defEndian: Option[FixedEndian]): String = { + var addParams = "" dataType match { case t: ReadableType => - s"$io.read_${t.apiCall(defEndian)}()?" - case blt: BytesLimitType => - s"$io.read_bytes(${expression(blt.size)})?" - case _: BytesEosType => - s"$io.read_bytes_full()?" + t match { + case IntMultiType(_, _, None) => + s"if *${privateMemberName(EndianIdentifier)} == 1 { $io.read_${t.apiCall(Some(LittleEndian))}()?.into() } else { $io.read_${t.apiCall(Some(BigEndian))}()?.into() }" + case IntMultiType(_, _, Some(e)) => + s"$io.read_${t.apiCall(Some(e))}()?.into()" + case _ => + s"$io.read_${t.apiCall(defEndian)}()?.into()" + } + case _: BytesEosType => s"$io.read_bytes_full()?.into()" case BytesTerminatedType(terminator, include, consume, eosError, _) => val term = terminator.head & 0xff - s"$io.read_bytes_term($term, $include, $consume, $eosError)?" - case BitsType1(bitEndian) => - s"$io.read_bits_int(1)? != 0" - case BitsType(width: Int, bitEndian) => - s"$io.read_bits_int($width)?" + s"$io.read_bytes_term($term, $include, $consume, $eosError)?.into()" + case b: BytesLimitType => s"$io.read_bytes(${expression(b.size)} as usize)?.into()" + case BitsType1(bitEndian) => s"$io.read_bits_int_${bitEndian.toSuffix}(1)? != 0" + case BitsType(width: Int, bitEndian) => s"$io.read_bits_int_${bitEndian.toSuffix}($width)?" case t: UserType => - val addParams = Utils.join(t.args.map((a) => translator.translate(a)), "", ", ", ", ") + addParams = translate_args(t.args, into = true) + val userType = t match { + case t: UserType => + val baseName = t.classSpec match { + case Some(spec) => types2class(spec.name) + case None => types2class(t.name) + } + s"$baseName" + } + val root = s"Some(${self_name()}.${privateMemberName(RootIdentifier)}.clone())" val addArgs = if (t.isExternal(typeProvider.nowClass)) { - "" + ", None, None" } else { - val parent = t.forcedParent match { - case Some(USER_TYPE_NO_PARENT) => "null" - case Some(fp) => translator.translate(fp) - case None => "self" + var parent = t.forcedParent match { + case Some(USER_TYPE_NO_PARENT) => "None" + case Some(fp) => s"Some(SharedType::new(${translator.translate(fp)}.clone()))" + case None => s"Some(${self_name()}._self.clone())" } - val addEndian = t.classSpec.get.meta.endian match { - case Some(InheritedEndian) => s", ${privateMemberName(EndianIdentifier)}" - case _ => "" + t.classSpec.get.parentType match { + case CalcKaitaiStructType(_) => parent = "None" + case _ => } - s", $parent, ${privateMemberName(RootIdentifier)}$addEndian" + s", $root, $parent" } - - s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(self.stream, self, _root)?)" + var io2 = "" + var streamType = "" + if (io == privateMemberName(IoIdentifier)) { + io2 = s"&*$io" + streamType = "_" + } else { + io2 = translator.ensure_amp(io) + streamType = "BytesReader" + } + if (addParams.isEmpty) { + if (t.classSpec.isDefined) t.classSpec.get.meta.endian match { + case Some(InheritedEndian) => + out.puts(s"let f = |t : &mut $userType| Ok(t.set_endian(*${privateMemberName(EndianIdentifier)}));") + out.puts(s"let t = Self::read_into_with_init::<$streamType, $userType>($io2$addArgs, &f)?.into();") + case _ => + out.puts(s"let t = Self::read_into::<$streamType, $userType>($io2$addArgs)?.into();") + } + } else { + out.puts(s"let f = |t : &mut $userType| Ok(t.set_params($addParams));") + out.puts(s"let t = Self::read_into_with_init::<$streamType, $userType>($io2$addArgs, &f)?.into();") + } + return s"t" + case _ => s"// parseExpr($dataType, $assignType, $io, $defEndian)" } } override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Seq[Byte]], include: Boolean): String = { + val ioId = privateMemberName(IoIdentifier) val expr1 = padRight match { - case Some(padByte) => s"$kstreamName::bytesStripRight($expr0, $padByte)" + case Some(padByte) => s"bytes_strip_right(&$expr0, $padByte).into()" case None => expr0 } val expr2 = terminator match { case Some(term) => val t = term.head & 0xff - s"$kstreamName::bytesTerminate($expr1, $t, $include)" + s"bytes_terminate(&$expr1, $t, $include).into()" case None => expr1 } expr2 } - var switchIfs = false - val NAME_SWITCH_ON = Ast.expr.Name(Ast.identifier(Identifier.SWITCH_ON)) + override def attrFixedContentsParse(attrName: Identifier, + contents: String): Unit = + out.puts(s"// attrFixedContentsParse($attrName, $contents)") - override def switchStart(id: Identifier, on: Ast.expr): Unit = { - val onType = translator.detectType(on) + override def publicMemberName(id: Identifier): String = + s"// publicMemberName($id)" - switchIfs = onType match { - case _: ArrayTypeInStream | _: BytesType => true - case _ => false - } + override def localTemporaryName(id: Identifier): String = + s"_t_${idToStr(id)}" - if (!switchIfs) { - out.puts(s"match ${expression(on)} {") - out.inc - } + override def userTypeDebugRead(id: String, dataType: DataType, assignType: DataType): Unit = { + // we already have splitted construction of object and read method } - def switchCmpExpr(condition: Ast.expr): String = - expression( - Ast.expr.Compare( - NAME_SWITCH_ON, - Ast.cmpop.Eq, - condition - ) - ) + override def switchRequiresIfs(onType: DataType): Boolean = onType match { + case _: IntType | _: EnumType => false + case _ => true + } - override def switchCaseFirstStart(condition: Ast.expr): Unit = { - if (switchIfs) { - out.puts(s"if ${switchCmpExpr(condition)} {") - out.inc - } else { - switchCaseStart(condition) - } + override def switchStart(id: Identifier, on: Ast.expr): Unit = { + switch_else_exist = false + out.puts(s"match ${expression(on)} {") + out.inc } override def switchCaseStart(condition: Ast.expr): Unit = { - if (switchIfs) { - out.puts(s"else if ${switchCmpExpr(condition)} {") - out.inc - } else { - out.puts(s"${expression(condition)} => {") - out.inc - } + out.puts(s"${expression(condition)} => {") + out.inc } override def switchCaseEnd(): Unit = { - if (switchIfs) { - out.dec - out.puts("}") - } else { - out.dec - out.puts("},") - } + out.dec + out.puts("}") } + var switch_else_exist = false + override def switchElseStart(): Unit = { - if (switchIfs) { - out.puts("else {") - out.inc - } else { - out.puts("_ => {") - out.inc - } + switch_else_exist = true + out.puts("_ => {") + out.inc } - override def switchElseEnd(): Unit = { + override def switchEnd(): Unit = { + if (!switch_else_exist) { + out.puts("_ => {}") + } out.dec out.puts("}") } - override def switchEnd(): Unit = universalFooter + override def switchIfStart(id: Identifier, on: Ast.expr, onType: DataType): Unit = { + out.puts("{") + out.inc + out.puts(s"let on = ${translator.remove_deref(expression(on))};") + } - override def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = { - out.puts(s" pub ${idToStr(attrName)}: Option<${kaitaiType2NativeType(attrType)}>,") + override def switchIfCaseFirstStart(condition: Ast.expr): Unit = { + out.puts(s"if *on == ${expression(condition)} {") + out.inc } - override def instanceDeclHeader(className: List[String]): Unit = { + override def switchIfCaseStart(condition: Ast.expr): Unit = { + out.puts(s"else if *on == ${expression(condition)} {") + out.inc + } + + override def switchIfCaseEnd(): Unit = { out.dec out.puts("}") - out.puts - - out.puts(s"impl ${type2class(className)} {") - out.inc } - override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { - out.puts(s"fn ${idToStr(instName)}(&mut self) -> ${kaitaiType2NativeType(dataType)} {") + override def switchIfElseStart(): Unit = { + out.puts("else {") out.inc } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { - out.puts(s"if let Some(x) = ${privateMemberName(instName)} {") - out.inc - out.puts("return x;") + override def switchIfEnd(): Unit = { out.dec out.puts("}") - out.puts } - override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { - out.puts(s"return ${privateMemberName(instName)};") - } + override def allocateIO(id: Identifier, rep: RepeatSpec): String = {//= privateMemberName(IoIdentifier) + val memberName = privateMemberName(id) + val ioId = IoStorageIdentifier(id) + + var newStreamRaw = s"$memberName" + val ioName = rep match { + case NoRepeat => + var newStream = newStreamRaw + val localIO = localTemporaryName(ioId) + val ids = idToStr(id) + out.puts(s"let $ids = $newStream;") + newStream = ids + out.puts(s"let $localIO = BytesReader::from($newStream.clone());") + s"&$localIO" + case _ => + val ids = idToStr(id) + val localIO = s"io_$ids" + out.puts(s"let $ids = $newStreamRaw;") + newStreamRaw = ids + out.puts(s"let $localIO = BytesReader::from($newStreamRaw.last().unwrap().clone());") + s"&$localIO" + } - override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { - val enumClass = type2class(curClass ::: List(enumName)) + ioName + } - out.puts(s"enum $enumClass {") + def switchTypeEnum(id: Identifier, st: SwitchType): Unit = { + // Because Rust can't handle `AnyType` in the type hierarchy, + // we generate an enum with all possible variations + val enum_typeName = kaitaiTypeToNativeType( + Some(id), + typeProvider.nowClass, + st, + excludeOptionWrapper = true + ) + out.puts("#[derive(Debug, Clone)]") + out.puts(s"pub enum $enum_typeName {") out.inc - enumColl.foreach { case (id, label) => - universalDoc(label.doc) - out.puts(s"${value2Const(label.name)},") + val types = st.cases.values.toSet + + { + val types_set = scala.collection.mutable.Set[String]() + types.foreach(t => { + // Because this switch type will itself be in an option, we can exclude it from user types + val variantName = switchVariantName(id, t) + val typeName = kaitaiTypeToNativeType( + Some(id), + typeProvider.nowClass, + t, + excludeOptionWrapper = true + ) + val new_typename = types_set.add(typeName) + // same typename could be in case of different endianness + if (new_typename) { + out.puts(s"$variantName($typeName),") + } + }) } out.dec out.puts("}") - } - def value2Const(label: String) = Utils.upperUnderscoreCase(label) + var enum_only_numeric = true + types.foreach { + case _: NumericType => // leave true + case _ => enum_only_numeric = false + } - def idToStr(id: Identifier): String = { - id match { - case SpecialIdentifier(name) => name - case NamedIdentifier(name) => Utils.lowerCamelCase(name) - case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" - case InstanceIdentifier(name) => Utils.lowerCamelCase(name) - case RawIdentifier(innerId) => "_raw_" + idToStr(innerId) + // generate only if switch types are different + { + val types_set = scala.collection.mutable.Set[String]() + // add helper methods From + types.foreach(t => { + // Because this switch type will itself be in an option, we can exclude it from user types + val variantName = switchVariantName(id, t) + var typeName = kaitaiTypeToNativeType( + Some(id), + typeProvider.nowClass, + t, + excludeOptionWrapper = true) + + val new_typename = types_set.add(typeName) + if (new_typename) { + // generate helpers to convert enum into variant (let x : Rc = enum1.into()) + if (!enum_only_numeric) { + val asOption = "^Option<.*".r + val suffix = kaitaiTypeToNativeType(Some(id), typeProvider.nowClass, t) match { + case asOption() => s".as_ref().unwrap()" + case _ => "" + } + out.puts(s"impl From<&$enum_typeName> for $typeName {") + out.inc + out.puts(s"fn from(v: &$enum_typeName) -> Self {") + out.inc + out.puts(s"if let $enum_typeName::$variantName(x) = v {") + out.inc + out.puts(s"return x$suffix.clone();") + out.dec + out.puts("}") + out.puts(s"""panic!("expected $enum_typeName::$variantName, got {:?}", v)""") + out.dec + out.puts("}") + out.dec + out.puts("}") + } + // special case for Bytes(Vec[u8]) (else switch) + t match { + case _ : BytesType => + typeName = s"Vec" + case _ => + } + // generate helpers to create enum from variant (let enum1 = Var1.into()) + out.puts(s"impl From<$typeName> for $enum_typeName {") + out.inc + out.puts(s"fn from(v: $typeName) -> Self {") + out.inc + out.puts(s"Self::$variantName(v)") + out.dec + out.puts("}") + out.dec + out.puts("}") + if (enum_only_numeric) { + out.puts(s"impl From<&$enum_typeName> for $typeName {") + out.inc + out.puts(s"fn from(e: &$enum_typeName) -> Self {") + out.inc + out.puts(s"if let $enum_typeName::$variantName(v) = e {") + out.inc + out.puts(s"return *v") + out.dec + out.puts("}") + out.puts(s"""panic!(\"trying to convert from enum $enum_typeName::$variantName to $typeName, enum value {:?}\", e)""") + out.dec + out.puts("}") + out.dec + out.puts("}") + } + } + }) + } + if (enum_only_numeric) { + out.puts(s"impl From<&$enum_typeName> for usize {") + out.inc + out.puts(s"fn from(e: &$enum_typeName) -> Self {") + out.inc + out.puts(s"match e {") + out.inc + val variants_set = scala.collection.mutable.Set[String]() + types.foreach(t => { + val variantName = switchVariantName(id, t) + val new_typename = variants_set.add(variantName) + if (new_typename) { + out.puts(s"$enum_typeName::$variantName(v) => *v as usize,") + } + }) + out.dec + out.puts("}") + out.dec + out.puts("}") + out.dec + out.puts("}") + out.puts } - } - override def privateMemberName(id: Identifier): String = { - id match { - case IoIdentifier => s"self.stream" - case RootIdentifier => s"_root" - case ParentIdentifier => s"_parent" - case _ => s"self.${idToStr(id)}" + // generate helper method with name from variant type (to convert enum into variant and call variant method inside) + // only if there is only single variant + // if more than 1 - Kaitai will do casting + if (types.size == 1) { + val types_set = scala.collection.mutable.Set[String]() + val attrs_set = scala.collection.mutable.Set[String]() + types.foreach(t => { + val typeName = kaitaiTypeToNativeType(Some(id), typeProvider.nowClass, t, cleanTypename = true) + if (types_set.add(typeName)) { + t match { + case ut: UserType => + ut.classSpec.get.seq.foreach( + attr => { + val attrName = attr.id + if (attrs_set.add(idToStr(attrName))) { + out.puts(s"impl $enum_typeName {") + out.inc + val fn = idToStr(attrName) + var nativeType = kaitaiTypeToNativeType(Some(attrName), typeProvider.nowClass, attr.dataTypeComposite, cleanTypename = true) + var nativeTypeEx = kaitaiTypeToNativeType(Some(attrName), typeProvider.nowClass, attr.dataTypeComposite) + val typeNameEx = kaitaiTypeToNativeType(Some(id), typeProvider.nowClass, t) + val x = if (typeNameEx.startsWith("Option<")) "x.as_ref().unwrap()" else "x" + var clone = "" + if (nativeTypeEx.startsWith("OptRc<")) { + nativeType = s"$nativeTypeEx" + clone = ".clone()" + } else + nativeType = s"Ref<$nativeType>" + out.puts(s"pub fn $fn(&self) -> $nativeType {") + out.inc + out.puts("match self {") + out.inc + out.puts(s"$enum_typeName::$typeName(x) => $x.$fn.borrow()$clone,") + //out.puts("_ => panic!(\"wrong variant: {:?}\", self),") + out.dec + out.puts("}") + out.dec + out.puts("}") + out.dec + out.puts("}") + } + } + ) + case _ => + } + } + }) } + } - override def publicMemberName(id: Identifier) = idToStr(id) + def switchVariantName(id: Identifier, attrType: DataType): String = + attrType match { + // TODO: Not exhaustive + case Int1Type(false) => "U1" + case IntMultiType(false, Width2, _) => "U2" + case IntMultiType(false, Width4, _) => "U4" + case IntMultiType(false, Width8, _) => "U8" + + case Int1Type(true) => "S1" + case IntMultiType(true, Width2, _) => "S2" + case IntMultiType(true, Width4, _) => "S4" + case IntMultiType(true, Width8, _) => "S8" + + case FloatMultiType(Width4, _) => "F4" + case FloatMultiType(Width8, _) => "F8" + + case BitsType(_,_) => "Bits" + case _: BooleanType => "Boolean" + case CalcIntType => "Int" + case CalcFloatType => "Float" + case _: StrType => "String" + case _: BytesType => "Bytes" - override def localTemporaryName(id: Identifier): String = s"$$_t_${idToStr(id)}" + case t: UserType => + kaitaiTypeToNativeType( + Some(id), + typeProvider.nowClass, + t, + cleanTypename = true + ) + case t: EnumType => + kaitaiTypeToNativeType( + Some(id), + typeProvider.nowClass, + t, + excludeOptionWrapper = true + ) + case t: ArrayType => s"Arr${switchVariantName(id, t.elType)}" + } - override def paramName(id: Identifier): String = s"${idToStr(id)}" + override def ksErrorName(err: KSError): String = RustCompiler.ksErrorName(err) - def kaitaiType2NativeType(attrType: DataType): String = { - attrType match { - case Int1Type(false) => "u8" - case IntMultiType(false, Width2, _) => "u16" - case IntMultiType(false, Width4, _) => "u32" - case IntMultiType(false, Width8, _) => "u64" + override def attrValidateExpr( + attr: AttrLikeSpec, + checkExpr: Ast.expr, + err: KSError, + errArgs: List[Ast.expr] + ): Unit = { + val srcPathStr = translator.translate(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + val validationKind = RustCompiler.validationErrorKind(err.asInstanceOf[ValidationError]) + out.puts(s"if !(${expression(checkExpr)}) {") + out.inc + out.puts(s"""return Err(${ksErrorName(err)}(ValidationFailedError { kind: $validationKind, src_path: $srcPathStr.to_string() }));""") + out.dec + out.puts("}") + } - case Int1Type(true) => "i8" - case IntMultiType(true, Width2, _) => "i16" - case IntMultiType(true, Width4, _) => "i32" - case IntMultiType(true, Width8, _) => "i64" + override def attrParse2( + id: Identifier, + dataType: DataType, + io: String, + rep: RepeatSpec, + isRaw: Boolean, + defEndian: Option[FixedEndian], + assignTypeOpt: Option[DataType] = None + ): Unit = { + dataType match { + case t: EnumType => + val expr = + t.basedOn match { + case inst: ReadableType => + s"($io.read_${inst.apiCall(defEndian)}()? as i64).try_into()?" + case BitsType(width: Int, bitEndian) => + s"($io.read_bits_int_${bitEndian.toSuffix}($width)? as i64).try_into()?" + } + handleAssignment(id, expr, rep, isRaw) + case _ => + super.attrParse2(id, dataType, io, rep, isRaw, defEndian, assignTypeOpt) + } + } - case FloatMultiType(Width4, _) => "f32" - case FloatMultiType(Width8, _) => "f64" + override def classToString(toStringExpr: Ast.expr): Unit = { + importList.add("use std::fmt;") + out.puts(s"impl fmt::Display for ${classTypeName(typeProvider.nowClass)} {") + out.inc + out.puts(s"fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {") + out.inc + out.puts(s"""write!(f, "{}", ${translator.translate(toStringExpr)})""") + out.dec + out.puts("}") + out.dec + out.puts("}") + } +} - case BitsType(_, _) => "u64" +object RustCompiler + extends LanguageCompilerStatic + with StreamStructNames + with UpperCamelCaseClasses + with ExceptionNames { + override def getCompiler(tp: ClassTypeProvider, + config: RuntimeConfig): LanguageCompiler = + new RustCompiler(tp, config) - case _: BooleanType => "bool" - case CalcIntType => "i32" - case CalcFloatType => "f64" + var in_reader = false - case _: StrType => "String" - case _: BytesType => "Vec" + def self_name(): String = { + if (in_reader) "self_rc" else "self" + } - case t: UserType => t.classSpec match { - case Some(cs) => s"Box<${type2class(cs.name)}>" - case None => s"Box<${type2class(t.name)}>" - } + def privateMemberName(id: Identifier, writeAccess: Boolean = false): String = id match { + case IoIdentifier => "_io" + case RootIdentifier => "_root" + case ParentIdentifier => "_parent" + case _ => + val n = s"${self_name()}.${idToStr(id)}" + if (writeAccess) + s"$n.borrow_mut()" + else + s"$n.borrow()" + } - case t: EnumType => t.enumSpec match { - case Some(cs) => s"Box<${type2class(cs.name)}>" - case None => s"Box<${type2class(t.name)}>" - } + def idToStr(id: Identifier): String = id match { + case SpecialIdentifier(n) => n + case NamedIdentifier(n) => n + case InstanceIdentifier(n) => n + case NumberedIdentifier(idx) => s"${NumberedIdentifier.TEMPLATE}$idx" + case RawIdentifier(inner) => s"${idToStr(inner)}_raw" // use suffix naming, easy to replace, like in anyField() + case IoStorageIdentifier(inner) => s"${idToStr(inner)}_io" // same here + } - case at: ArrayType => s"Vec<${kaitaiType2NativeType(at.elType)}>" + @tailrec + def rootClassTypeName(c: ClassSpec, isRecurse: Boolean = false): String = { + if (!isRecurse && c.isTopLevel) + "Self" + else if (c.isTopLevel) + classTypeName(c) + else + rootClassTypeName(c.upClass.get, isRecurse = true) + } - case KaitaiStreamType | OwnedKaitaiStreamType => s"Option>" - case KaitaiStructType | CalcKaitaiStructType(_) => s"Option>" + override def kstreamName = "KStream" + override def kstructName = "KStruct" - case st: SwitchType => kaitaiType2NativeType(st.combinedType) - } + def kstructUnitName = "KStructUnit" + + override def ksErrorName(err: KSError): String = err match { + case EndOfStreamError => "KError::Eof" + case UndecidedEndiannessError => "KError::UndecidedEndianness" + case ConversionError => "KError::CastError" + case _: ValidationError => "KError::ValidationFailed" } - def kaitaiType2Default(attrType: DataType): String = { - attrType match { - case Int1Type(false) => "0" - case IntMultiType(false, Width2, _) => "0" - case IntMultiType(false, Width4, _) => "0" - case IntMultiType(false, Width8, _) => "0" + def validationErrorKind(err: ValidationError): String = { + val kind = err match { + case _: ValidationNotEqualError => "NotEqual" + case _: ValidationLessThanError => "LessThan" + case _: ValidationGreaterThanError => "GreaterThan" + case _: ValidationNotAnyOfError => "NotAnyOf" + case _: ValidationNotInEnumError => "NotInEnum" + case _: ValidationExprError => "Expr" + } + s"ValidationKind::$kind" + } - case Int1Type(true) => "0" - case IntMultiType(true, Width2, _) => "0" - case IntMultiType(true, Width4, _) => "0" - case IntMultiType(true, Width8, _) => "0" + def classTypeName(c: ClassSpec): String = + s"${types2class(c.name)}" - case FloatMultiType(Width4, _) => "0" - case FloatMultiType(Width8, _) => "0" + def types2class(names: List[String]): String = + // TODO: Use `mod` to scope types instead of weird names + names.map(x => type2class(x)).mkString("_") - case BitsType(_, _) => "0" + def kaitaiTypeToNativeType(id: Option[Identifier], + cs: ClassSpec, + attrType: DataType, + excludeOptionWrapper: Boolean = false, + cleanTypename: Boolean = false): String = + attrType match { + // TODO: Not exhaustive + case _: NumericType | _: BooleanType | _: StrType | _: BytesType => + kaitaiPrimitiveToNativeType(attrType) - case _: BooleanType => "false" - case CalcIntType => "0" - case CalcFloatType => "0" + case t: UserType => + val baseName = t.classSpec match { + case Some(spec) => types2class(spec.name) + case None => types2class(t.name) + } + if (cleanTypename) + baseName + else + s"OptRc<$baseName>" + + case t: EnumType => + val baseName = t.enumSpec match { + case Some(spec) => s"${types2class(spec.name)}" + case None => s"${types2class(t.name)}" + } + baseName + + case t: ArrayType => + s"Vec<${kaitaiTypeToNativeType(id, cs, t.elType, excludeOptionWrapper = true)}>" + + case _: SwitchType => + val typeName = id.get match { + case name: NamedIdentifier => + s"${types2class(cs.name ::: List(name.name))}" + case name: InstanceIdentifier => + s"${types2class(cs.name ::: List(name.name))}" + case _ => kstructUnitName + } - case _: StrType => "\"\"" - case _: BytesType => "vec!()" + if (excludeOptionWrapper) typeName else s"Option<$typeName>" - case t: UserType => "Default::default()" - case t: EnumType => "Default::default()" + case KaitaiStreamType => "BytesReader" + case CalcKaitaiStructType(_) => kstructUnitName + } - case ArrayTypeInStream(inType) => "vec!()" + def kaitaiPrimitiveToNativeType(attrType: DataType): String = attrType match { + case Int1Type(false) => "u8" + case IntMultiType(false, Width2, _) => "u16" + case IntMultiType(false, Width4, _) => "u32" + case IntMultiType(false, Width8, _) => "u64" - case KaitaiStreamType | OwnedKaitaiStreamType => "None" - case KaitaiStructType => "None" + case Int1Type(true) => "i8" + case IntMultiType(true, Width2, _) => "i16" + case IntMultiType(true, Width4, _) => "i32" + case IntMultiType(true, Width8, _) => "i64" - case _: SwitchType => "" - // TODO - } - } + case FloatMultiType(Width4, _) => "f32" + case FloatMultiType(Width8, _) => "f64" - def type2class(names: List[String]) = types2classRel(names) + case BitsType(_,_) => "u64" - def type2classAbs(names: List[String]) = - names.mkString("::") + case _: BooleanType => "bool" + case CalcIntType => "i32" + case CalcFloatType => "f64" - override def ksErrorName(err: KSError): String = RustCompiler.ksErrorName(err) -} + case _: StrType => "String" + case _: BytesType => "Vec" -object RustCompiler extends LanguageCompilerStatic - with StreamStructNames - with UpperCamelCaseClasses - with ExceptionNames { - override def getCompiler( - tp: ClassTypeProvider, - config: RuntimeConfig - ): LanguageCompiler = new RustCompiler(tp, config) - - override def kstructName = "&Option>" - override def kstreamName = "&mut S" - override def ksErrorName(err: KSError): String = ??? - - def types2class(typeName: Ast.typeId) = { - typeName.names.map(type2class).mkString( - if (typeName.absolute) "__" else "", - "__", - "" - ) + case ArrayTypeInStream(inType) => s"Vec<${kaitaiPrimitiveToNativeType(inType)}>" } - - def types2classRel(names: List[String]) = - names.map(type2class).mkString("__") } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala index af1c5efd6..788d49182 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala @@ -1,19 +1,31 @@ package io.kaitai.struct.translators +import io.kaitai.struct.format._ +import io.kaitai.struct.datatype._ import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.exprlang.Ast.expr -import io.kaitai.struct.format.{EnumSpec, Identifier} +import io.kaitai.struct.format.{EnumSpec, Identifier, IoIdentifier, ParentIdentifier, RootIdentifier} import io.kaitai.struct.languages.RustCompiler -import io.kaitai.struct.{RuntimeConfig, Utils} +import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils} + +class RustTranslator(provider: TypeProvider, config: RuntimeConfig) + extends BaseTranslator(provider) { + + import RustCompiler._ + + var lastFoundMemberClass: ClassSpec = provider.nowClass -class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends BaseTranslator(provider) { override def doByteArrayLiteral(arr: Seq[Byte]): String = - "vec!([" + arr.map((x) => - "%0#2x".format(x & 0xff) - ).mkString(", ") + "])" + "vec![" + arr.map(x => "%0#2xu8".format(x & 0xff)).mkString(", ") + "]" override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = - s"pack('C*', ${elts.map(translate).mkString(", ")})" + "vec![" + elts.map(translate).mkString(", ") + "]" + override def doArrayLiteral(t: DataType, value: Seq[Ast.expr]): String = { + t match { + case CalcStrType => "vec![" + value.map(v => translate(v)).mkString(".to_string(), ") + ".to_string()]" + case _ => "vec![" + value.map(v => translate(v)).mkString(", ") + "]" + } + } override val asciiCharQuoteMap: Map[Char, String] = Map( '\t' -> "\\t", @@ -23,89 +35,478 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base '\\' -> "\\\\" ) + override def strLiteralGenericCC(code: Char): String = + strLiteralUnicode(code) + override def strLiteralUnicode(code: Char): String = "\\u{%x}".format(code.toInt) - override def doLocalName(s: String) = { - s match { - case Identifier.ITERATOR => "tmpa" - case Identifier.ITERATOR2 => "tmpb" - case Identifier.INDEX => "i" - case _ => s"self.${doName(s)}" + def isSignedIntType(dt: DataType): Boolean = dt match { + case Int1Type(true) => true + case IntMultiType(true, _, _) => true + case CalcIntType => true + case _ => false + } + + def isAllDigits(x: String) = x forall Character.isDigit + + override def genericBinOp(left: Ast.expr, + op: Ast.operator, + right: Ast.expr, + extPrec: Int): String = { + val lt = detectType(left) + val rt = detectType(right) + val tl = translate(left) + val tr = translate(right) + + if (isSignedIntType(lt) && isSignedIntType(rt) && op == Ast.operator.Mod) { + s"modulo($tl as i64, $tr as i64)" + } else if (isSignedIntType(lt) && isSignedIntType(rt) && op == Ast.operator.RShift) { + // Arithmetic right shift on signed integer types, logical right shift on unsigned integer types + val ct = RustCompiler.kaitaiPrimitiveToNativeType(TypeDetector.combineTypes(lt, rt)) + s"((($tl as u64) >> $tr) as $ct)" + } else { + if (lt == rt && isAllDigits(tl) && isAllDigits(tr)) { + // let rust decide final type + s"($tl ${binOp(op)} $tr)" + } else { + val ct = RustCompiler.kaitaiPrimitiveToNativeType(TypeDetector.combineTypes(lt, rt)) + s"(($tl as $ct) ${binOp(op)} ($tr as $ct))" + } + } + } + + def unwrap(s: String): String = s + ".as_ref().unwrap()" + + override def doName(s: String): String = s match { + case Identifier.PARENT => s + case _ => + val refOpt = "^Option<.*".r + val memberFound = findMember(s) + val f = s"$s()" + if (memberFound.isDefined) { + memberFound.get match { + case vis: ValueInstanceSpec => + val aType = RustCompiler.kaitaiTypeToNativeType(Some(vis.id), provider.nowClass, vis.dataTypeComposite) + aType match { + case refOpt() => unwrap(s"$f?") + case _ => s"$f?" + } + case as: AttrSpec => + val aType = RustCompiler.kaitaiTypeToNativeType(Some(as.id), provider.nowClass, as.dataTypeComposite) + aType match { + case refOpt() => + if (!enum_numeric_only(as.dataTypeComposite)) { + unwrap(f) + } else f + case _ => f + } + case pd: ParamDefSpec => + val aType = RustCompiler.kaitaiTypeToNativeType(Some(pd.id), provider.nowClass, pd.dataTypeComposite) + aType match { + case refOpt() => unwrap(f) + case _ => f + } + case pis: ParseInstanceSpec => + val aType = RustCompiler.kaitaiTypeToNativeType(Some(pis.id), provider.nowClass, pis.dataTypeComposite) + aType match { + case refOpt() => unwrap(s"$f?") + case _ => s"$f?" + } + case _ => + f + } + } + else { + f + } + } + + def updateLastFoundMemberClass(dt: DataType) { + if (dt.isInstanceOf[UserType]) { + val s = dt.asInstanceOf[UserType] + if (s.classSpec.isDefined) { + lastFoundMemberClass = s.classSpec.get + } + } + } + + def resetLastFoundMemberClass() { + lastFoundMemberClass = provider.nowClass + } + + def findMember(attrName: String, c: ClassSpec = lastFoundMemberClass): Option[MemberSpec] = { + def findInClass(inClass: ClassSpec): Option[MemberSpec] = { + + inClass.seq.foreach { el => + if (idToStr(el.id) == attrName) { + updateLastFoundMemberClass(el.dataType) + return Some(el) + } + } + + inClass.params.foreach { el => + if (idToStr(el.id) == attrName) { + updateLastFoundMemberClass(el.dataType) + return Some(el) + } + } + + inClass.instances.foreach { case (instName, instSpec) => + if (idToStr(instName) == attrName) { + updateLastFoundMemberClass(instSpec.dataType) + return Some(instSpec) + } + } + + inClass.types.foreach{ t => + for { found <- findInClass(t._2) } + return Some(found) + } + None + } + + attrName match { + case Identifier.PARENT | Identifier.IO => + return None + case _ => + for { ms <- findInClass(c) } + return Some(ms) + + provider.asInstanceOf[ClassTypeProvider].allClasses.foreach { cls => + for { ms <- findInClass(cls._2) } + return Some(ms) + } + } + None + } + + def get_instance(cs: ClassSpec, s: String): Option[InstanceSpec] = { + var found : Option[InstanceSpec] = None + // look for instance + cs.instances.foreach { case (instName, instSpec) => + if (idToStr(instName) == s) { + found = Some(instSpec) + } + } + // look deeper + if (found.isEmpty) { + cs.types.foreach { + case (_, typeSpec) => + found = get_instance(typeSpec, s) + if (found.isDefined) { + return found + } + } + } + found + } + + override def anyField(value: expr, attrName: String): String = { + resetLastFoundMemberClass() + val t = translate(value) + var a = doName(attrName) + attrName match { + case Identifier.PARENT => a = a + unwrap(".get_value().borrow().upgrade()") + case _ => + } + var r = "" + if (need_deref(attrName)) { + if (t.charAt(0) == '*') { + r = s"$t.$a" + } else { + r = s"*$t.$a" + } + } else { + if (t.charAt(0) == '*') { + r = s"${t.substring(1)}.$a" + } else { + r = s"$t.$a" + } + } + r + } + + def rem_vec_amp(s: String): String = { + if (s.startsWith("&vec!")) { + s.substring(1) + } else { + s + } + } + + def ensure_vec_amp(s: String): String = { + if (s.startsWith("vec!")) { + s"&$s" + } else { + s + } + } + + def ensure_amp(s: String): String = { + if (s.charAt(0) == '&') { + s + } else { + s"&$s" + } + } + + def remove_deref(s: String): String = { + if (s.charAt(0) == '*') { + s.substring(1) + } else { + s } } - override def doName(s: String) = s + def ensure_deref(s: String): String = { + if (s.startsWith(self_name())) { + s"*$s" + } else { + s + } + } - override def doEnumByLabel(enumSpec: EnumSpec, label: String): String = { - val enumClass = types2classAbs(enumSpec.name) - s"$enumClass::${Utils.upperUnderscoreCase(label)}" + def enum_numeric_only(dataType: DataType): Boolean = { + var types : Set[DataType] = Set() + var enum_typename = false + dataType match { + case st: SwitchType => + types = st.cases.values.toSet + enum_typename = true + //case _: EnumType => return true + case _ => return false + } + var enum_only_numeric = true + types.foreach { + case _: NumericType => // leave unchanged + case _ => enum_only_numeric = false + } + enum_only_numeric } + + def is_copy_type(dataType: DataType): Boolean = dataType match { + case _: SwitchType => false + case _: UserType => false + case _: BytesType => false + case _: ArrayType => false + case _: StrType => false + case _: EnumType => false + case _ => true + } + + def need_deref(s: String, c: ClassSpec = provider.nowClass): Boolean = { + var deref = false + val memberFound = findMember(s, c) + if (memberFound.isDefined ) { + val spec = memberFound.get + spec match { + case _: AttrSpec | _: ParamDefSpec => + deref = !enum_numeric_only(spec.dataTypeComposite) + case _: ValueInstanceSpec | _: ParseInstanceSpec => + deref = true + case _ => + } + } + deref + } + + override def doLocalName(s: String): String = s match { + case Identifier.ITERATOR => "_tmpa" + case Identifier.ITERATOR2 => "_tmpb" + case Identifier.INDEX => "_i" + case Identifier.IO => s"${RustCompiler.privateMemberName(IoIdentifier)}" + case Identifier.ROOT => "_r" + case Identifier.PARENT => unwrap("_prc") + case _ => + // reset "looking for variable" context + resetLastFoundMemberClass() + val n = doName(s) + val deref = !n.endsWith(".as_str()") && !n.endsWith(".as_slice()") && need_deref(s) + if (deref) { + s"*${self_name()}.$n" + } else { + s"${self_name()}.$n" + } + } + override def doEnumCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"${translate(left)} ${cmpOp(op)} ${translate(right)}" + + override def doInternalName(id: Identifier): String = + s"${doLocalName(idToStr(id))}" + + override def doEnumByLabel(enumSpec: EnumSpec, label: String): String = + s"${RustCompiler.types2class(enumSpec.name)}::${Utils.upperCamelCase(label)}" + + override def doNumericCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = { + val lt = detectType(left) + val rt = detectType(right) + if (lt != rt) { + val ct = RustCompiler.kaitaiPrimitiveToNativeType(TypeDetector.combineTypes(lt, rt)) + s"((${translate(left)} as $ct) ${cmpOp(op)} (${translate(right)} as $ct))" + } else { + s"${translate(left)} ${cmpOp(op)} ${translate(right)}" + } + } + + override def doStrCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"${ensure_deref(translate(left))} ${cmpOp(op)} ${remove_deref(translate(right))}.to_string()" + override def doEnumById(enumSpec: EnumSpec, id: String): String = - // Just an integer, without any casts / resolutions - one would have to look up constants manually - id + s"($id as i64).try_into()?" override def arraySubscript(container: expr, idx: expr): String = - s"${translate(container)}[${translate(idx)}]" - override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = - "if " + translate(condition) + - " { " + translate(ifTrue) + " } else { " + - translate(ifFalse) + "}" + s"${remove_deref(translate(container))}[${translate(idx)} as usize]" + + override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = { + var to_type = "" + detectType(ifTrue) match { + case _: UserType => to_type = ".clone()" + case _: EnumType => to_type = ".clone()" + case _: StrType => to_type = ".to_string()" + case _: BytesType => to_type = ".to_vec()" + case _: CalcArrayType => to_type = ".clone()" + case _ => + } + if (to_type.isEmpty) { + s"if ${translate(condition)} { ${translate(ifTrue)} } else { ${translate(ifFalse)} }" + } else { + s"if ${translate(condition)} { ${remove_deref(translate(ifTrue))}$to_type } else { ${remove_deref(translate(ifFalse))}$to_type }" + } + } + + override def doCast(value: Ast.expr, castTypeName: DataType): String = { + val value_type = detectType(value) + if(castTypeName == value_type) + return translate(value) + + val ct = RustCompiler.kaitaiTypeToNativeType(None, provider.nowClass, castTypeName, excludeOptionWrapper = true) + var into = false + castTypeName match { + case _: UserType => into = true; + case CalcBytesType => into = true; + case _ => + } + if (into) { + s"Into::<$ct>::into(&${translate(value)})" + } else { + s"(${translate(value)} as $ct)" + } + } - // Predefined methods of various types - override def strConcat(left: expr, right: expr, extPrec: Int) = - "format!(\"{}{}\", " + translate(left) + ", " + translate(right) + ")" + override def translate(v: Ast.expr): String = { + v match { + case Ast.expr.EnumById(enumType, id, inType) => + id match { + case ifExp: Ast.expr.IfExp => + val enumSpec = provider.resolveEnum(inType, enumType.name) + val enumName = RustCompiler.types2class(enumSpec.name) + def toStr(ex: Ast.expr) = ex match { + case Ast.expr.IntNum(n) => s"$enumName::try_from($n)?" + case _ => super.translate(ex) + } + val ifTrue = toStr(ifExp.ifTrue) + val ifFalse = toStr(ifExp.ifFalse) + + "if " + translate(ifExp.condition) + s" { $ifTrue } else { $ifFalse }" + case _ => super.translate(v) + } + case _ => + super.translate(v) + } + } + + override def strConcat(left: Ast.expr, right: Ast.expr, extPrec: Int): String = + s"""format!("{}{}", ${translate(left)}, ${translate(right)})""" + + override def doInterpolatedStringLiteral(exprs: Seq[Ast.expr]): String = + if (exprs.isEmpty) { + doStringLiteral("") + } else { // format!("{expr1}{expr2}{expr3}") + var s = "format!(\"" + exprs.foreach(i => { s+= "{}" }) + s += "\", " + s += exprs.map(translate).mkString(", ") + s += ")" + s + } override def strToInt(s: expr, base: expr): String = translate(base) match { case "10" => - s"${translate(s)}.parse().unwrap()" + s"${translate(s)}.parse::().map_err(|_| KError::CastError)?" case _ => - "panic!(\"Converting from string to int in base {} is unimplemented\", " + translate(base) + ")" + s"i32::from_str_radix(${translate(s)}, ${translate(base)}).map_err(|_| KError::CastError)?" } override def enumToInt(v: expr, et: EnumType): String = - translate(v) + s"i64::from(&${translate(v)})" override def boolToInt(v: expr): String = - s"${translate(v)} as i32" + s"(${translate(v)}) as i32" override def floatToInt(v: expr): String = s"${translate(v)} as i32" override def intToStr(i: expr): String = - s"${translate(i)}.to_string()" + s"${remove_deref(translate(i))}.to_string()" override def bytesToStr(bytesExpr: String, encoding: String): String = - encoding match { - case "ASCII" => - s"String::from_utf8_lossy($bytesExpr)" - case _ => - "panic!(\"Unimplemented encoding for bytesToStr: {}\", \"" + encoding + "\")" - } + s"""bytes_to_str(&$bytesExpr, "$encoding")?""" + override def bytesLength(b: Ast.expr): String = - s"${translate(b, METHOD_PRECEDENCE)}.len()" + s"${remove_deref(translate(b))}.len()" + override def strLength(s: expr): String = - s"${translate(s, METHOD_PRECEDENCE)}.len()" - override def strReverse(s: expr): String = - s"${translate(s, METHOD_PRECEDENCE)}.graphemes(true).rev().flat_map(|g| g.chars()).collect()" + s"${remove_deref(translate(s))}.len()" + + override def strReverse(s: expr): String = { + val e = translate(s) + if (e.charAt(0) == '*') + s"reverse_string(&$e)?" + else + s"reverse_string($e)?" + } override def strSubstring(s: expr, from: expr, to: expr): String = - s"${translate(s, METHOD_PRECEDENCE)}.substring(${translate(from)}, ${translate(to)})" + s"${translate(s, METHOD_PRECEDENCE)}[${translate(from)}..${translate(to)}]" override def arrayFirst(a: expr): String = - s"${translate(a)}.first()" + s"${ensure_deref(translate(a))}.first().ok_or(KError::EmptyIterator)?" override def arrayLast(a: expr): String = - s"${translate(a)}.last()" + s"${ensure_deref(translate(a))}.last().ok_or(KError::EmptyIterator)?" override def arraySize(a: expr): String = - s"${translate(a)}.len()" - override def arrayMin(a: Ast.expr): String = - s"${translate(a)}.iter().min()" - override def arrayMax(a: Ast.expr): String = - s"${translate(a)}.iter().max()" + s"${remove_deref(translate(a))}.len()" + + def is_float_type(a: Ast.expr): Boolean = { + detectType(a) match { + case t: CalcArrayType => + t.elType match { + case _: FloatMultiType => true + case CalcFloatType => true + case _ => false + } + case t: ArrayType => + t.elType match { + case _: FloatMultiType => true + case _ => false + } + case _ => false + } + } + + override def arrayMin(a: Ast.expr): String = { + if (is_float_type(a)) { + s"${ensure_deref(translate(a))}.iter().reduce(|a, b| if (a.min(*b)) == *b {b} else {a}).ok_or(KError::EmptyIterator)?" + } else { + s"${ensure_deref(translate(a))}.iter().min().ok_or(KError::EmptyIterator)?" + } + } - def types2classAbs(names: List[String]) = - names match { - case List("kaitai_struct") => RustCompiler.kstructName - case _ => RustCompiler.types2classRel(names) + override def arrayMax(a: Ast.expr): String = { + if (is_float_type(a)) { + s"${ensure_deref(translate(a))}.iter().reduce(|a, b| if (a.max(*b)) == *b {b} else {a}).ok_or(KError::EmptyIterator)?" + } else { + s"${ensure_deref(translate(a))}.iter().max().ok_or(KError::EmptyIterator)?" } + } }