From 5dbf318ea311ec4a265be63d8cc5c45bb156f2cc Mon Sep 17 00:00:00 2001 From: Mingun Date: Sun, 14 Apr 2024 00:43:24 +0500 Subject: [PATCH] Go: override methods which generates error checks instead of hacking output --- .../testtranslator/specgenerators/GoSG.scala | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala index 7d66eb1be..cb213e514 100644 --- a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala +++ b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala @@ -4,8 +4,8 @@ import _root_.io.kaitai.struct.datatype.{DataType, EndOfStreamError, KSError} import _root_.io.kaitai.struct.exprlang.Ast import _root_.io.kaitai.struct.languages.GoCompiler import _root_.io.kaitai.struct.testtranslator.{Main, TestAssert, TestEquals, TestSpec} -import _root_.io.kaitai.struct.translators.GoTranslator -import _root_.io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, StringLanguageOutputWriter, Utils} +import _root_.io.kaitai.struct.translators.{GoTranslator, TypeProvider} +import _root_.io.kaitai.struct.{ClassTypeProvider, ImportList, RuntimeConfig, StringLanguageOutputWriter, Utils} class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(spec) { /** @@ -14,15 +14,32 @@ class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(sp */ class GoOutputWriter(indentStr: String) extends StringLanguageOutputWriter(indentStr) { override def puts(s: String): Unit = { - val mangled = s.replace(REPLACER, "r.").replaceAll("return err$", "t.Fatal(err)") - super.puts(mangled) + super.puts(s.replace(REPLACER, "r.")) + } + } + + /** + * Special wrapper around translator that catches all attempts to write error + * check and turns it into assertion. + */ + class GoTestTranslator( + out: StringLanguageOutputWriter, + provider: TypeProvider, + importList: ImportList, + ) extends GoTranslator(out, provider, importList) { + override def outAddErrCheck(): Unit = { + out.puts("if err != nil {") + out.inc + out.puts("t.Fatal(err)") + out.dec + out.puts("}") } } override val out = new GoOutputWriter(indentStr) val compiler = new GoCompiler(provider, RuntimeConfig()) val className = GoCompiler.types2class(List(spec.id)) - val translator = new GoTranslator(out, provider, importList) + val translator = new GoTestTranslator(out, provider, importList) override def fileName(name: String): String = s"${name}_test.go"