diff --git a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala index a148abc43..9b29ae504 100644 --- a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala @@ -75,6 +75,7 @@ class RustClassCompiler( curClass.instances.foreach { case (instName, instSpec) => compileInstance(curClass.name, instName, instSpec, curClass.meta.endian) } + lang.instanceDeclFooter(curClass.name) } override def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, endian: Option[Endianness]): Unit = { diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala index 38886e3b4..8f3edacae 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala @@ -1,605 +1,712 @@ package io.kaitai.struct.languages -import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils, _} import io.kaitai.struct.datatype.DataType._ -import io.kaitai.struct.datatype.{CalcEndian, DataType, FixedEndian, InheritedEndian} +import io.kaitai.struct.datatype.{DataType, Endianness, FixedEndian} import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.format.{NoRepeat, RepeatEos, RepeatExpr, RepeatSpec, _} +import io.kaitai.struct.format.{RepeatSpec, _} import io.kaitai.struct.languages.components._ -import io.kaitai.struct.translators.{RustTranslator, TypeDetector} +import io.kaitai.struct.translators.RustTranslator +import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig} class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) extends LanguageCompiler(typeProvider, config) + with AllocateIOLocalVar + with EveryReadIsExpression + with FixedContentsUsingArrayByteLiteral with ObjectOrientedLanguage - with 
UpperCamelCaseClasses with SingleOutputFile - with AllocateIOLocalVar + with UpperCamelCaseClasses with UniversalFooter - with UniversalDoc - with FixedContentsUsingArrayByteLiteral - with EveryReadIsExpression { + with UniversalDoc { import RustCompiler._ + override val translator: RustTranslator = + new RustTranslator(typeProvider, config) + override def innerClasses = false override def innerEnums = false - override val translator: RustTranslator = new RustTranslator(typeProvider, config) - - override def universalFooter: Unit = { - out.dec - out.puts("}") - } - - override def outImports(topClass: ClassSpec) = - importList.toList.map((x) => s"use $x;").mkString("", "\n", "\n") - override def indent: String = " " + override def outFileName(topClassName: String): String = s"$topClassName.rs" + override def outImports(topClass: ClassSpec): String = + importList.toList + .map(i => s"#[allow(unused_imports)]\nuse $i;") + .mkString("", "\n", "\n") + override def fileHeader(topClassName: String): Unit = { outHeader.puts(s"// $headerComment") outHeader.puts - importList.add("std::option::Option") - importList.add("std::boxed::Box") - importList.add("std::io::Result") - importList.add("std::io::Cursor") - importList.add("std::vec::Vec") - importList.add("std::default::Default") - importList.add("kaitai_struct::KaitaiStream") - importList.add("kaitai_struct::KaitaiStruct") + importList.add( + "kaitai::{KError, KResult, KStream, KStruct, KStructUnit, TypedStack}" + ) + importList.add("kaitai::{kf32_max, kf64_max, kf32_min, kf64_min}") + importList.add("std::convert::{TryFrom, TryInto}") + } + + override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = + importList.add( + s"crate::${classSpec.name.last}::${type2class(classSpec.name.last)}" + ) + override def classHeader(name: List[String]): Unit = { out.puts - } + out.puts("#[allow(non_camel_case_types)]") + out.puts("#[derive(Default, Debug, PartialEq)]") + out.puts(s"pub struct 
${classTypeName(typeProvider.nowClass)} {") + out.inc - override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = { - val name = type2class(classSpec.name.last) - val pkg = type2classAbs(classSpec.name) - - importList.add(s"$pkg::$name") - } + // Because we can't predict whether opaque types will need lifetimes as a type parameter, + // everyone gets a phantom data marker + out.puts(s"_phantom: std::marker::PhantomData<&$streamLife ()>,") - override def classHeader(name: List[String]): Unit = - classHeader(name, Some(kstructName)) + typeProvider.nowClass.params.foreach { p => + // Make sure the parameter is imported if necessary + p.dataType match { + case u: UserType => if (u.isOpaque && u.classSpec.isDefined) opaqueClassDeclaration(u.classSpec.get) + case _ => () + } - def classHeader(name: List[String], parentClass: Option[String]): Unit = { - out.puts("#[derive(Default)]") - out.puts(s"pub struct ${type2class(name)} {") + // Declare parameters as if they were attributes + attributeDeclaration(p.id, p.dataType, isNullable = false) + } } - override def classFooter(name: List[String]): Unit = universalFooter + // Intentional no-op; Rust has already ended the struct definition by the time we reach this + override def classFooter(name: List[String]): Unit = {} - override def classConstructorHeader(name: List[String], parentType: DataType, rootClassName: List[String], isHybrid: Boolean, params: List[ParamDefSpec]): Unit = { - out.puts("}") - out.puts + override def classConstructorHeader(name: List[String], + parentType: DataType, + rootClassName: List[String], + isHybrid: Boolean, + params: List[ParamDefSpec]): Unit = { - out.puts(s"impl KaitaiStruct for ${type2class(name)} {") - out.inc - - // Parameter names - val pIo = paramName(IoIdentifier) - val pParent = paramName(ParentIdentifier) - val pRoot = paramName(RootIdentifier) - - // Types - val tIo = kstreamName - val tParent = kaitaiType2NativeType(parentType) - - out.puts(s"fn new(stream: &mut S,") - 
out.puts(s" _parent: &Option>,") - out.puts(s" _root: &Option>)") - out.puts(s" -> Result") - out.inc - out.puts(s"where Self: Sized {") + // Unlike other OOP languages, implementing an interface happens outside the struct declaration. + universalFooter + + // If there are any switch types in the struct definition, create the enums for them + typeProvider.nowClass.seq.foreach( + a => + a.dataType match { + case st: SwitchType => switchTypeEnum(a.id, st) + case _ => () + } + ) + typeProvider.nowClass.instances.foreach( + i => + i._2.dataTypeComposite match { + case st: SwitchType => switchTypeEnum(i._1, st) + case _ => () + } + ) - out.puts(s"let mut s: Self = Default::default();") + out.puts( + s"impl<$readLife, $streamLife: $readLife> $kstructName<$readLife, $streamLife> for ${classTypeName(typeProvider.nowClass)} {" + ) + out.inc + out.puts(s"type Root = ${rootClassTypeName(typeProvider.nowClass)};") + out.puts( + s"type ParentStack = ${parentStackTypeName(typeProvider.nowClass)};" + ) out.puts + } - out.puts(s"s.stream = stream;") + override def runRead(): Unit = out.puts(s"// runRead()") - out.puts(s"s.read(stream, _parent, _root)?;") - out.puts + override def runReadCalc(): Unit = out.puts(s"// runReadCalc()") - out.puts("Ok(s)") + override def readHeader(endian: Option[FixedEndian], + isEmpty: Boolean): Unit = { + out.puts(s"fn read(") + out.inc + out.puts(s"&mut self,") + out.puts(s"${privateMemberName(IoIdentifier)}: &$streamLife S,") + out.puts( + s"${privateMemberName(RootIdentifier)}: Option<&$readLife Self::Root>," + ) + out.puts( + s"${privateMemberName(ParentIdentifier)}: TypedStack" + ) out.dec - out.puts("}") - out.puts + out.puts(s") -> KResult<$streamLife, ()> {") + out.inc + + // If there aren't any attributes to parse, we need to end the read implementation here + if (typeProvider.nowClass.seq.isEmpty) + endRead() } - override def runRead(): Unit = { + override def readFooter(): Unit = out.puts(s"// readFooter()") - } + override def 
attributeDeclaration(attrName: Identifier, + attrType: DataType, + isNullable: Boolean): Unit = { + val typeName = attrName match { + // For keeping lifetimes simple, we don't store _io, _root, or _parent with the struct + case IoIdentifier | RootIdentifier | ParentIdentifier => return + case _ => + kaitaiTypeToNativeType(attrName, typeProvider.nowClass, attrType) + } - override def runReadCalc(): Unit = { - + out.puts(s"pub ${idToStr(attrName)}: $typeName,") } - override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { - out.puts - out.puts(s"fn read(&mut self,") - out.puts(s" stream: &mut S,") - out.puts(s" _parent: &Option>,") - out.puts(s" _root: &Option>)") - out.puts(s" -> Result<()>") - out.inc - out.puts(s"where Self: Sized {") + // Intentional no-op; Rust handles ownership, so don't worry about reader methods + override def attributeReader(attrName: Identifier, + attrType: DataType, + isNullable: Boolean): Unit = {} + + override def attrParse(attr: AttrLikeSpec, + id: Identifier, + defEndian: Option[Endianness]): Unit = { + super.attrParse(attr, id, defEndian) + + // Detect if this is the last attribute parse and finish the read method + if (typeProvider.nowClass.seq.nonEmpty && typeProvider.nowClass.seq.last.id == id) + endRead() } - override def readFooter(): Unit = { - out.puts + def endRead(): Unit = { out.puts("Ok(())") out.dec out.puts("}") } - override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { - attrName match { - case ParentIdentifier | RootIdentifier | IoIdentifier => - // just ignore it for now - case IoIdentifier => - out.puts(s" stream: ${kaitaiType2NativeType(attrType)},") - case _ => - out.puts(s" pub ${idToStr(attrName)}: ${kaitaiType2NativeType(attrType)},") - } - } + override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = + out.puts(s"// attrParseHybrid(${leProc()}, ${beProc()})") - override def attributeReader(attrName: Identifier, attrType: 
DataType, isNullable: Boolean): Unit = { - - } + override def condIfHeader(expr: Ast.expr): Unit = { + // TODO: Actual implementation, this is a shim to enable compiling + out.puts("{") + out.inc - override def universalDoc(doc: DocSpec): Unit = { - if (doc.summary.isDefined) { - out.puts - out.puts("/*") - doc.summary.foreach((summary) => out.putsLines(" * ", summary)) - out.puts(" */") - } + out.puts(s"// condIfHeader($expr)") } - override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = { - out.puts("if ($this->_m__is_le) {") - out.inc - leProc() - out.dec - out.puts("} else {") + override def condRepeatEosHeader(id: Identifier, + io: String, + dataType: DataType, + needRaw: Boolean): Unit = { + // TODO: Actual implementation, this is a shim to enable compiling + out.puts("{") out.inc - beProc() - out.dec - out.puts("}") - } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") - - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = { - val srcName = privateMemberName(varSrc) - val destName = privateMemberName(varDest) + out.puts(s"// condRepeatEosHeader($id, $io, $dataType, $needRaw)") + } - proc match { - case ProcessXor(xorValue) => - val procName = translator.detectType(xorValue) match { - case _: IntType => "processXorOne" - case _: BytesType => "processXorMany" - } - out.puts(s"$destName = $kstreamName::$procName($srcName, ${expression(xorValue)});") - case ProcessZlib => - out.puts(s"$destName = $kstreamName::processZlib($srcName);") - case ProcessRotate(isLeft, rotValue) => - val expr = if (isLeft) { - expression(rotValue) - } else { - s"8 - (${expression(rotValue)})" - } - out.puts(s"$destName = $kstreamName::processRotateLeft($srcName, $expr, 1);") - case ProcessCustom(name, args) => - val procClass = if (name.length == 1) { - val onlyName = name.head - val className = 
type2class(onlyName) - importList.add(s"$onlyName::$className") - className - } else { - val pkgName = type2classAbs(name.init) - val className = type2class(name.last) - importList.add(s"$pkgName::$className") - s"$pkgName::$className" - } + override def condRepeatExprHeader(id: Identifier, + io: String, + dataType: DataType, + needRaw: Boolean, + repeatExpr: Ast.expr): Unit = { + // TODO: Actual implementation, this is a shim to enable compiling + out.puts("{") + out.inc - out.puts(s"let _process = $procClass::new(${args.map(expression).mkString(", ")});") - out.puts(s"$destName = _process.decode($srcName);") - } + out.puts( + s"// condRepeatExprHeader($id, $io, $dataType, $needRaw, $repeatExpr)" + ) } - override def allocateIO(id: Identifier, rep: RepeatSpec): String = { - val memberName = privateMemberName(id) - - val args = rep match { - case RepeatEos | RepeatExpr(_) => s"$memberName.last()" - case RepeatUntil(_) => translator.doLocalName(Identifier.ITERATOR2) - case NoRepeat => memberName - } + override def condRepeatUntilHeader(id: Identifier, + io: String, + dataType: DataType, + needRaw: Boolean, + repeatExpr: Ast.expr): Unit = { + // TODO: Actual implementation, this is a shim to enable compiling + out.puts("{") + out.inc - out.puts(s"let mut io = Cursor::new($args);") - "io" + out.puts( + s"// condRepeatUntilHeader($id, $io, $dataType, $needRaw, $repeatExpr)" + ) } - override def useIO(ioEx: Ast.expr): String = { - out.puts(s"let mut io = ${expression(ioEx)};") - "io" + override def condRepeatUntilFooter(id: Identifier, + io: String, + dataType: DataType, + needRaw: Boolean, + repeatExpr: Ast.expr): Unit = { + out.puts( + s"// condRepeatUntilFooter($id, $io, $dataType, $needRaw, $repeatExpr)" + ) + out.dec + out.puts("} {}") } - override def pushPos(io: String): Unit = - out.puts(s"let _pos = $io.pos();") + override def attrProcess(proc: ProcessExpr, + varSrc: Identifier, + varDest: Identifier): Unit = + out.puts(s"// attrProcess($proc, $varSrc, 
$varDest)") + + override def useIO(ioEx: Ast.expr): String = s"// useIO($ioEx)" + + override def pushPos(io: String): Unit = out.puts(s"// pushPos($io)") override def seek(io: String, pos: Ast.expr): Unit = - out.puts(s"$io.seek(${expression(pos)});") + out.puts(s"// seek($io, $pos)") - override def popPos(io: String): Unit = - out.puts(s"$io.seek(_pos);") + override def popPos(io: String): Unit = out.puts(s"// popPos($io)") override def alignToByte(io: String): Unit = - out.puts(s"$io.alignToByte();") + out.puts(s"${privateMemberName(IoIdentifier)}.align_to_byte()?;") - override def condIfHeader(expr: Ast.expr): Unit = { - out.puts(s"if ${expression(expr)} {") - out.inc - } + override def privateMemberName(id: Identifier): String = + RustCompiler.privateMemberName(id) - override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean): Unit = { - if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = [];") - out.puts(s"${privateMemberName(id)} = [];") - out.puts(s"while !$io.isEof() {") + override def instanceDeclHeader(className: List[String]): Unit = { + out.puts( + s"impl<$readLife, $streamLife: $readLife> ${classTypeName(typeProvider.nowClass)} {" + ) out.inc } - override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = { - out.puts(s"${privateMemberName(id)}.push($expr);") + override def instanceDeclFooter(className: List[String]): Unit = + universalFooter + + override def universalFooter: Unit = { + out.dec + out.puts("}") } - override def condRepeatEosFooter: Unit = { - super.condRepeatEosFooter + override def instanceDeclaration(attrName: InstanceIdentifier, + attrType: DataType, + isNullable: Boolean): Unit = { + val typeName = kaitaiTypeToNativeType( + attrName, + typeProvider.nowClass, + attrType, + excludeOptionWrapper = true + ) + attrType match { + case _: ArrayType => out.puts(s"pub ${idToStr(attrName)}: $typeName,") + case _ => out.puts(s"pub ${idToStr(attrName)}: Option<$typeName>,") 
+ } } - override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, repeatExpr: Ast.expr): Unit = { - if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();") - out.puts(s"${privateMemberName(id)} = vec!();") - out.puts(s"for i in 0..${expression(repeatExpr)} {") + override def idToStr(id: Identifier): String = RustCompiler.idToStr(id) + + override def instanceHeader(className: List[String], + instName: InstanceIdentifier, + dataType: DataType, + isNullable: Boolean): Unit = { + + out.puts(s"fn ${idToStr(instName)}(") + out.inc + out.puts("&mut self,") + out.puts(s"${privateMemberName(IoIdentifier)}: &$streamLife S,") + out.puts( + s"${privateMemberName(RootIdentifier)}: Option<&$readLife ${rootClassTypeName(typeProvider.nowClass)}>," + ) + out.puts( + s"${privateMemberName(ParentIdentifier)}: TypedStack<${parentStackTypeName(typeProvider.nowClass)}>" + ) + out.dec + val typeName = kaitaiTypeToNativeType( + instName, + typeProvider.nowClass, + dataType, + excludeOptionWrapper = true + ) + out.puts(s") -> KResult<$streamLife, $typeName> {") out.inc } - override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = { - out.puts(s"${privateMemberName(id)}.push($expr);") + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, + dataType: DataType): Unit = + out.puts(s"// instanceCheckCacheAndReturn($instName, $dataType)") + + override def instanceReturn(instName: InstanceIdentifier, + attrType: DataType): Unit = { + out.puts("panic!(\"Instance calculation not yet supported.\");") + out.puts(s"// instanceReturn($instName, $attrType)") } - override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { - if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();") - out.puts(s"${privateMemberName(id)} = vec!();") - out.puts("while {") + override def enumDeclaration(curClass: List[String], + 
enumName: String, + enumColl: Seq[(Long, EnumValueSpec)]): Unit = { + + val enumClass = types2class(curClass ::: List(enumName)) + + // Set up the actual enum definition + out.puts(s"#[allow(non_camel_case_types)]") + out.puts(s"#[derive(Debug, PartialEq)]") + out.puts(s"pub enum $enumClass {") out.inc - } - override def handleAssignmentRepeatUntil(id: Identifier, expr: String, isRaw: Boolean): Unit = { - val tempVar = if (isRaw) { - translator.doLocalName(Identifier.ITERATOR2) - } else { - translator.doLocalName(Identifier.ITERATOR) + enumColl.foreach { + case (_, label) => + if (label.doc.summary.isDefined) + universalDoc(label.doc) + + out.puts(s"${type2class(label.name)},") } - out.puts(s"let $tempVar = $expr;") - out.puts(s"${privateMemberName(id)}.append($expr);") - } - override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { - typeProvider._currentIteratorType = Some(dataType) - out.puts(s"!(${expression(untilExpr)})") out.dec - out.puts("} { }") - } + out.puts("}") - override def handleAssignmentSimple(id: Identifier, expr: String): Unit = { - out.puts(s"${privateMemberName(id)} = $expr;") - } + // Set up parsing enums from the underlying value + out.puts(s"impl TryFrom for $enumClass {") - override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { - dataType match { - case t: ReadableType => - s"$io.read_${t.apiCall(defEndian)}()?" - case blt: BytesLimitType => - s"$io.read_bytes(${expression(blt.size)})?" - case _: BytesEosType => - s"$io.read_bytes_full()?" - case BytesTerminatedType(terminator, include, consume, eosError, _) => - s"$io.read_bytes_term($terminator, $include, $consume, $eosError)?" - case BitsType1 => - s"$io.read_bits_int(1)? != 0" - case BitsType(width: Int) => - s"$io.read_bits_int($width)?" 
- case t: UserType => - val addParams = Utils.join(t.args.map((a) => translator.translate(a)), "", ", ", ", ") - val addArgs = if (t.isOpaque) { - "" - } else { - val parent = t.forcedParent match { - case Some(USER_TYPE_NO_PARENT) => "null" - case Some(fp) => translator.translate(fp) - case None => "self" - } - val addEndian = t.classSpec.get.meta.endian match { - case Some(InheritedEndian) => s", ${privateMemberName(EndianIdentifier)}" - case _ => "" - } - s", $parent, ${privateMemberName(RootIdentifier)}$addEndian" - } - - s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(self.stream, self, _root)?)" + out.inc + // We typically need the lifetime in KError for returning byte slices from stream; + // because we can only return `UnknownVariant` which contains a Copy type, it's safe + // to declare that the error type is `'static` + out.puts(s"type Error = KError<'static>;") + out.puts(s"fn try_from(flag: i64) -> KResult<'static, $enumClass> {") + + out.inc + out.puts(s"match flag {") + + out.inc + enumColl.foreach { + case (value, label) => + out.puts(s"$value => Ok($enumClass::${type2class(label.name)}),") } + out.puts("_ => Err(KError::UnknownVariant(flag)),") + out.dec + + out.puts(s"}") + out.dec + out.puts(s"}") + out.dec + out.puts(s"}") + out.puts } - override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean): String = { - val expr1 = padRight match { - case Some(padByte) => s"$kstreamName::bytesStripRight($expr0, $padByte)" - case None => expr0 - } - val expr2 = terminator match { - case Some(term) => s"$kstreamName::bytesTerminate($expr1, $term, $include)" - case None => expr1 - } - expr2 + override def universalDoc(doc: DocSpec): Unit = { + out.puts(s"// universalDoc()") } - var switchIfs = false - val NAME_SWITCH_ON = Ast.expr.Name(Ast.identifier(Identifier.SWITCH_ON)) + override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = + out.puts(s"// 
handleAssignmentRepeatEos($id, $expr)") - override def switchStart(id: Identifier, on: Ast.expr): Unit = { - val onType = translator.detectType(on) + override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = + out.puts(s"// handleAssignmentRepeatExpr($id, $expr)") - switchIfs = onType match { - case _: ArrayType | _: BytesType => true - case _ => false - } + override def handleAssignmentRepeatUntil(id: Identifier, + expr: String, + isRaw: Boolean): Unit = + out.puts(s"// handleAssignmentRepeatUntil($id, $expr, $isRaw)") - if (!switchIfs) { - out.puts(s"match ${expression(on)} {") - out.inc + override def handleAssignmentSimple(id: Identifier, expr: String): Unit = { + val seqId = typeProvider.nowClass.seq.find(s => s.id == id) + + if (seqId.isDefined) seqId.get.dataType match { + case _: EnumType => + out.puts( + s"${privateMemberName(id)} = Some(($expr as i64).try_into()?);" + ) + case _: UserType | _: SwitchType | _: BytesLimitType => + out.puts(s"// handleAssignmentSimple($id, $expr)") + case _ => out.puts(s"${privateMemberName(id)} = $expr;") } } - def switchCmpExpr(condition: Ast.expr): String = - expression( - Ast.expr.Compare( - NAME_SWITCH_ON, - Ast.cmpop.Eq, - condition - ) - ) - - override def switchCaseFirstStart(condition: Ast.expr): Unit = { - if (switchIfs) { - out.puts(s"if ${switchCmpExpr(condition)} {") - out.inc - } else { - switchCaseStart(condition) + override def parseExpr(dataType: DataType, + assignType: DataType, + io: String, + defEndian: Option[FixedEndian]): String = + dataType match { + case IntMultiType(_, _, None) => "panic!(\"Unable to parse unknown-endian integers\")" + case t: ReadableType => s"$io.read_${t.apiCall(defEndian)}()?" + case _: BytesEosType => s"$io.read_bytes_full()?" + case b: BytesTerminatedType => + s"$io.read_bytes_term(${b.terminator}, ${b.include}, ${b.consume}, ${b.eosError})?" + case b: BytesLimitType => s"$io.read_bytes(${expression(b.size)} as usize)?" 
+ case BitsType1 => s"$io.read_bits_int(1)? != 0" + case BitsType(width) => s"$io.read_bits_int($width)?" + case _ => s"// parseExpr($dataType, $assignType, $io, $defEndian)" } - } - override def switchCaseStart(condition: Ast.expr): Unit = { - if (switchIfs) { - out.puts(s"elss if ${switchCmpExpr(condition)} {") - out.inc - } else { - out.puts(s"${expression(condition)} => {") - out.inc + override def bytesPadTermExpr(expr0: String, + padRight: Option[Int], + terminator: Option[Int], + include: Boolean): String = { + val ioId = privateMemberName(IoIdentifier) + val expr = padRight match { + case Some(p) => s"$ioId.bytes_strip_right($expr0, $p)" + case None => expr0 } - } - override def switchCaseEnd(): Unit = { - if (switchIfs) { - out.dec - out.puts("}") - } else { - out.dec - out.puts("},") + terminator match { + case Some(term) => s"$ioId.bytes_terminate($expr, $term, $include)" + case None => expr } } - override def switchElseStart(): Unit = { - if (switchIfs) { - out.puts("else {") - out.inc - } else { - out.puts("_ => {") - out.inc - } - } + override def attrFixedContentsParse(attrName: Identifier, + contents: String): Unit = + out.puts(s"// attrFixedContentsParse($attrName, $contents)") - override def switchElseEnd(): Unit = { - out.dec - out.puts("}") - } + override def publicMemberName(id: Identifier): String = + s"// publicMemberName($id)" - override def switchEnd(): Unit = universalFooter + override def localTemporaryName(id: Identifier): String = + s"// localTemporaryName($id)" - override def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = { - out.puts(s" pub ${idToStr(attrName)}: Option<${kaitaiType2NativeType(attrType)}>,") - } + override def switchStart(id: Identifier, on: Ast.expr): Unit = + out.puts(s"// switchStart($id, $on)") - override def instanceDeclHeader(className: List[String]): Unit = { - out.dec - out.puts("}") - out.puts + override def switchCaseStart(condition: Ast.expr): Unit = + 
out.puts(s"// switchCaseStart($condition)") - out.puts(s"impl ${type2class(className)} {") - out.inc - } + override def switchCaseEnd(): Unit = out.puts(s"// switchCaseEnd()") - override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { - out.puts(s"fn ${idToStr(instName)}(&mut self) -> ${kaitaiType2NativeType(dataType)} {") - out.inc - } + override def switchElseStart(): Unit = out.puts(s"// switchElseStart()") - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { - out.puts(s"if let Some(x) = ${privateMemberName(instName)} {") - out.inc - out.puts("return x;") - out.dec - out.puts("}") - out.puts - } + override def switchEnd(): Unit = out.puts(s"// switchEnd()") - override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { - out.puts(s"return ${privateMemberName(instName)};") + override def extraAttrForIO(id: Identifier, + rep: RepeatSpec): List[AttrSpec] = { + out.puts(s"// extraAttrForIO($id, $rep)") + Nil } - override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { - val enumClass = type2class(curClass ::: List(enumName)) + override def allocateIO(varName: Identifier, rep: RepeatSpec): String = + s"// allocateIO($varName, $rep)" - out.puts(s"enum $enumClass {") + def switchTypeEnum(id: Identifier, st: SwitchType): Unit = { + // Because Rust can't handle `AnyType` in the type hierarchy, + // we generate an enum with all possible variations + val typeName = kaitaiTypeToNativeType( + id, + typeProvider.nowClass, + st, + excludeOptionWrapper = true + ) + out.puts("#[allow(non_camel_case_types)]") + out.puts("#[derive(Debug, PartialEq)]") + out.puts(s"pub enum $typeName {") out.inc - - enumColl.foreach { case (id, label) => - universalDoc(label.doc) - out.puts(s"${value2Const(label.name)},") - } + + val types = st.cases.values.toSet + types.foreach(t => { 
+ // Because this switch type will itself be in an option, we can exclude it from user types + val variantName = switchVariantName(id, t) + val typeName = kaitaiTypeToNativeType( + id, + typeProvider.nowClass, + t, + excludeOptionWrapper = true + ) + out.puts(s"$variantName($typeName),") + }) out.dec out.puts("}") } - def value2Const(label: String) = label.toUpperCase + def switchVariantName(id: Identifier, attrType: DataType): String = + attrType match { + // TODO: Not exhaustive + case Int1Type(false) => "U1" + case IntMultiType(false, Width2, _) => "U2" + case IntMultiType(false, Width4, _) => "U4" + case IntMultiType(false, Width8, _) => "U8" + + case Int1Type(true) => "S1" + case IntMultiType(true, Width2, _) => "S2" + case IntMultiType(true, Width4, _) => "S4" + case IntMultiType(true, Width8, _) => "S8" + + case FloatMultiType(Width4, _) => "F4" + case FloatMultiType(Width8, _) => "F8" + + case BitsType(_) => "Bits" + case _: BooleanType => "Boolean" + case CalcIntType => "Int" + case CalcFloatType => "Float" + case _: StrType => "String" + case _: BytesType => "Bytes" - def idToStr(id: Identifier): String = { - id match { - case SpecialIdentifier(name) => name - case NamedIdentifier(name) => Utils.lowerCamelCase(name) - case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" - case InstanceIdentifier(name) => Utils.lowerCamelCase(name) - case RawIdentifier(innerId) => "_raw_" + idToStr(innerId) + case t: UserType => + kaitaiTypeToNativeType( + id, + typeProvider.nowClass, + t, + excludeOptionWrapper = true, + excludeLifetime = true, + excludeBox = true + ) + case t: EnumType => + kaitaiTypeToNativeType( + id, + typeProvider.nowClass, + t, + excludeOptionWrapper = true + ) + case t: ArrayType => s"Arr${switchVariantName(id, t.elType)}" } +} + +object RustCompiler + extends LanguageCompilerStatic + with StreamStructNames + with UpperCamelCaseClasses { + override def getCompiler(tp: ClassTypeProvider, + config: RuntimeConfig): LanguageCompiler = 
+ new RustCompiler(tp, config) + + override def kstreamName = "KStream" + + def privateMemberName(id: Identifier): String = id match { + case IoIdentifier => "_io" + case RootIdentifier => "_root" + case ParentIdentifier => "_parent" + case _ => s"self.${idToStr(id)}" } - override def privateMemberName(id: Identifier): String = { - id match { - case IoIdentifier => s"self.stream" - case RootIdentifier => s"_root" - case ParentIdentifier => s"_parent" - case _ => s"self.${idToStr(id)}" - } + def idToStr(id: Identifier): String = id match { + case SpecialIdentifier(n) => n + case NamedIdentifier(n) => n + case InstanceIdentifier(n) => n + case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" + case RawIdentifier(inner) => s"raw_${idToStr(inner)}" } - override def publicMemberName(id: Identifier) = idToStr(id) + def rootClassTypeName(c: ClassSpec, isRecurse: Boolean = false): String = { + if (!isRecurse && c.isTopLevel) + "Self" + else if (c.isTopLevel) + classTypeName(c) + else + rootClassTypeName(c.upClass.get, isRecurse = true) + } - override def localTemporaryName(id: Identifier): String = s"$$_t_${idToStr(id)}" + def parentStackTypeName(c: ClassSpec): String = { + if (c.isTopLevel) + s"($kstructUnitName)" + else + s"(&$streamLife ${classTypeName(c.upClass.get)}, <${classTypeName(c.upClass.get)} as $kstructName<$readLife, $streamLife>>::ParentStack)" + } - override def paramName(id: Identifier): String = s"${idToStr(id)}" - - def kaitaiType2NativeType(attrType: DataType): String = { - attrType match { - case Int1Type(false) => "u8" - case IntMultiType(false, Width2, _) => "u16" - case IntMultiType(false, Width4, _) => "u32" - case IntMultiType(false, Width8, _) => "u64" + override def kstructName = s"KStruct" - case Int1Type(true) => "i8" - case IntMultiType(true, Width2, _) => "i16" - case IntMultiType(true, Width4, _) => "i32" - case IntMultiType(true, Width8, _) => "i64" + def readLife = "'r" - case FloatMultiType(Width4, _) => "f32" - case 
FloatMultiType(Width8, _) => "f64" + def kstructUnitName = "KStructUnit" - case BitsType(_) => "u64" + def classTypeName(c: ClassSpec): String = + s"${types2class(c.name)}<$streamLife>" - case _: BooleanType => "bool" - case CalcIntType => "i32" - case CalcFloatType => "f64" + def streamLife = "'s" - case _: StrType => "String" - case _: BytesType => "Vec" + def types2class(names: List[String]): String = + // TODO: Use `mod` to scope types instead of weird names + names.map(x => type2class(x)).mkString("_") - case t: UserType => t.classSpec match { - case Some(cs) => s"Box<${type2class(cs.name)}>" - case None => s"Box<${type2class(t.name)}>" - } - - case t: EnumType => t.enumSpec match { - case Some(cs) => s"Box<${type2class(cs.name)}>" - case None => s"Box<${type2class(t.name)}>" - } + def lifetimeParam(d: DataType): String = + if (containsReferences(d)) s"<$streamLife>" else "" + + def containsReferences(d: DataType): Boolean = containsReferences(d, None) - case ArrayType(inType) => s"Vec<${kaitaiType2NativeType(inType)}>" + def containsReferences(c: ClassSpec, + originating: Option[ClassSpec]): Boolean = + c.seq.exists(t => containsReferences(t.dataType, originating)) || + c.instances.exists( + i => containsReferences(i._2.dataTypeComposite, originating) + ) - case KaitaiStreamType => s"Option>" - case KaitaiStructType | CalcKaitaiStructType => s"Option>" - - case st: SwitchType => kaitaiType2NativeType(st.combinedType) + def containsReferences(d: DataType, originating: Option[ClassSpec]): Boolean = + d match { + case _: BytesType | _: StrType => true + case t: UserType => true + /* + t.classSpec match { + // Recursive types may need references, but the recursion itself + // will be handled by `Box<>`, so doesn't need a reference + case Some(inner) if originating.contains(inner) => false + case Some(inner) => containsReferences(inner, originating.orElse(Some(inner))) + case None => false + } + */ + case t: ArrayType => containsReferences(t.elType, originating) + 
case st: SwitchType => + st.cases.values.exists(t => containsReferences(t, originating)) + case _ => false } - } - - def kaitaiType2Default(attrType: DataType): String = { - attrType match { - case Int1Type(false) => "0" - case IntMultiType(false, Width2, _) => "0" - case IntMultiType(false, Width4, _) => "0" - case IntMultiType(false, Width8, _) => "0" - case Int1Type(true) => "0" - case IntMultiType(true, Width2, _) => "0" - case IntMultiType(true, Width4, _) => "0" - case IntMultiType(true, Width8, _) => "0" + def kaitaiTypeToNativeType(id: Identifier, + cs: ClassSpec, + attrType: DataType, + excludeOptionWrapper: Boolean = false, + excludeLifetime: Boolean = false, + excludeBox: Boolean = false): String = + attrType match { + // TODO: Not exhaustive + case _: NumericType => kaitaiPrimitiveToNativeType(attrType) + case _: BooleanType => kaitaiPrimitiveToNativeType(attrType) + case _: StrType => kaitaiPrimitiveToNativeType(attrType) + case _: BytesType => kaitaiPrimitiveToNativeType(attrType) - case FloatMultiType(Width4, _) => "0" - case FloatMultiType(Width8, _) => "0" + case t: UserType => + val baseName = t.classSpec match { + case Some(spec) => types2class(spec.name) + case None => types2class(t.name) + } + val lifetime = if (!excludeLifetime) s"<$streamLife>" else "" + + // Because we can't predict if opaque types will recurse, we have to box them + val typeName = + if (!excludeBox && t.isOpaque) s"Box<$baseName$lifetime>" + else s"$baseName$lifetime" + if (excludeOptionWrapper) typeName else s"Option<$typeName>" + + case t: EnumType => + val typeName = t.enumSpec match { + case Some(spec) => s"${types2class(spec.name)}" + case None => s"${types2class(t.name)}" + } + if (excludeOptionWrapper) typeName else s"Option<$typeName>" + + case t: ArrayType => + s"Vec<${kaitaiTypeToNativeType(id, cs, t.elType, excludeOptionWrapper = true, excludeLifetime = excludeLifetime)}>" + + case st: SwitchType => + val types = st.cases.values.toSet + val lifetime = + if 
(!excludeLifetime && types.exists(containsReferences)) + s"<$streamLife>" + else "" + val typeName = id match { + case name: NamedIdentifier => + s"${types2class(cs.name ::: List(name.name))}$lifetime" + case name: InstanceIdentifier => + s"${types2class(cs.name ::: List(name.name))}$lifetime" + case _ => kstructUnitName + } - case BitsType(_) => "0" + if (excludeOptionWrapper) typeName else s"Option<$typeName>" - case _: BooleanType => "false" - case CalcIntType => "0" - case CalcFloatType => "0" + case KaitaiStreamType => kstreamName + } - case _: StrType => "\"\"" - case _: BytesType => "vec!()" + def kaitaiPrimitiveToNativeType(attrType: DataType): String = attrType match { + case Int1Type(false) => "u8" + case IntMultiType(false, Width2, _) => "u16" + case IntMultiType(false, Width4, _) => "u32" + case IntMultiType(false, Width8, _) => "u64" - case t: UserType => "Default::default()" - case t: EnumType => "Default::default()" + case Int1Type(true) => "i8" + case IntMultiType(true, Width2, _) => "i16" + case IntMultiType(true, Width4, _) => "i32" + case IntMultiType(true, Width8, _) => "i64" - case ArrayType(inType) => "vec!()" + case FloatMultiType(Width4, _) => "f32" + case FloatMultiType(Width8, _) => "f64" - case KaitaiStreamType => "None" - case KaitaiStructType => "None" - - case _: SwitchType => "" - // TODO - } - } - - def type2class(names: List[String]) = types2classRel(names) + case BitsType(_) => "u64" - def type2classAbs(names: List[String]) = - names.mkString("::") -} + case _: BooleanType => "bool" + case CalcIntType => "i32" + case CalcFloatType => "f64" -object RustCompiler extends LanguageCompilerStatic - with StreamStructNames - with UpperCamelCaseClasses { - override def getCompiler( - tp: ClassTypeProvider, - config: RuntimeConfig - ): LanguageCompiler = new RustCompiler(tp, config) - - override def kstructName = "&Option>" - override def kstreamName = "&mut S" - - def types2class(typeName: Ast.typeId) = { - 
typeName.names.map(type2class).mkString( - if (typeName.absolute) "__" else "", - "__", - "" - ) + case _: StrType => s"&$streamLife str" + case _: BytesType => s"&$streamLife [u8]" } - - def types2classRel(names: List[String]) = - names.map(type2class).mkString("__") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala index f41214463..161d31afc 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala @@ -120,6 +120,8 @@ abstract class LanguageCompiler( def alignToByte(io: String): Unit def instanceDeclHeader(className: List[String]): Unit = {} + + def instanceDeclFooter(className: List[String]): Unit = {} def instanceClear(instName: InstanceIdentifier): Unit = {} def instanceSetCalculated(instName: InstanceIdentifier): Unit = {} def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = attributeDeclaration(attrName, attrType, isNullable) diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala index 267ea06b4..c3e43a74e 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala @@ -3,15 +3,17 @@ package io.kaitai.struct.translators import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.exprlang.Ast.expr -import io.kaitai.struct.format.Identifier +import io.kaitai.struct.format.{Identifier, InstanceIdentifier, IoIdentifier, NamedIdentifier, ParentIdentifier, RootIdentifier} import io.kaitai.struct.languages.RustCompiler import io.kaitai.struct.{RuntimeConfig, Utils} -class RustTranslator(provider: TypeProvider, 
config: RuntimeConfig) extends BaseTranslator(provider) { +class RustTranslator(provider: TypeProvider, config: RuntimeConfig) + extends BaseTranslator(provider) { + + import RustCompiler._ + override def doByteArrayLiteral(arr: Seq[Byte]): String = - "vec!([" + arr.map((x) => - "%0#2x".format(x & 0xff) - ).mkString(", ") + "])" + "&[" + arr.map(x => "%0#2x".format(x & 0xff)).mkString(", ") + "]" override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = s"pack('C*', ${elts.map(translate).mkString(", ")})" @@ -26,7 +28,9 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base override def strLiteralUnicode(code: Char): String = "\\u{%x}".format(code.toInt) - override def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = { + override def numericBinOp(left: Ast.expr, + op: Ast.operator, + right: Ast.expr) = { (detectType(left), detectType(right), op) match { case (_: IntType, _: IntType, Ast.operator.Div) => s"${translate(left)} / ${translate(right)}" @@ -42,26 +46,41 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base case Identifier.ITERATOR => "tmpa" case Identifier.ITERATOR2 => "tmpb" case Identifier.INDEX => "i" - case _ => s"self.${doName(s)}" + case Identifier.IO => s"${RustCompiler.privateMemberName(IoIdentifier)}" + case Identifier.ROOT => s"${RustCompiler.privateMemberName(RootIdentifier)}.ok_or(KError::MissingRoot)?" + case Identifier.PARENT => + // TODO: How to handle _parent._parent? 
+ s"${RustCompiler.privateMemberName(ParentIdentifier)}.peek()" + case _ => + if (provider.nowClass.seq.exists(a => a.id != IoIdentifier && a.id == NamedIdentifier(s))) { + // If the name is part of the `seq` parse list, it's safe to return as-is + s"self.${doName(s)}" + } else if (provider.nowClass.instances.contains(InstanceIdentifier(s))) { + // It's an instance, we need to safely handle lookup + s"self.${doName(s)}(${privateMemberName(IoIdentifier)}, ${privateMemberName(RootIdentifier)}, ${privateMemberName(ParentIdentifier)})?" + } else { + // TODO: Is it possible to reach this block? RawIdentifier? + s"self.${doName(s)}" + } } } override def doName(s: String) = s - override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = { - val enumClass = types2classAbs(enumTypeAbs) - s"$enumClass::${label.toUpperCase}" - } + override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = + s"${RustCompiler.types2class(enumTypeAbs)}::${Utils.upperCamelCase(label)}" + override def doEnumById(enumTypeAbs: List[String], id: String) = // Just an integer, without any casts / resolutions - one would have to look up constants manually id override def doSubscript(container: expr, idx: expr): String = - s"${translate(container)}[${translate(idx)}]" + s"${translate(container)}[${translate(idx)} as usize]" + override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = "if " + translate(condition) + - " { " + translate(ifTrue) + " } else { " + - translate(ifFalse) + "}" + " { " + translate(ifTrue) + " } else { " + + translate(ifFalse) + "}" // Predefined methods of various types override def strConcat(left: Ast.expr, right: Ast.expr): String = @@ -72,7 +91,9 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base case "10" => s"${translate(s)}.parse().unwrap()" case _ => - "panic!(\"Converting from string to int in base {} is unimplemented\"" + translate(base) + ")" + "panic!(\"Converting from string to 
int in base {} is unimplemented\", " + translate( + base + ) + ")" } override def enumToInt(v: expr, et: EnumType): String = @@ -96,10 +117,13 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = translate(encoding) match { case "\"ASCII\"" => - s"String::from_utf8_lossy($bytesExpr)" + // Currently has issues because the `&str` created doesn't outlive the function, + // will likely need to decode *as* a string or handle specially elsewhere + // s"&String::from_utf8_lossy($bytesExpr)" + "panic!(\"Unresolved lifetime issues with string parsing\")" case _ => "panic!(\"Unimplemented encoding for bytesToStr: {}\", " + - translate(encoding) + ")" + translate(encoding) + ")" } override def bytesLength(b: Ast.expr): String = s"${translate(b)}.len()" @@ -120,10 +144,4 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base s"${translate(a)}.iter().min()" override def arrayMax(a: Ast.expr): String = s"${translate(a)}.iter().max()" - - def types2classAbs(names: List[String]) = - names match { - case List("kaitai_struct") => RustCompiler.kstructName - case _ => RustCompiler.types2classRel(names) - } }