Skip to content

Commit

Permalink
Go: override methods which generates error checks instead of hacking …
Browse files Browse the repository at this point in the history
…output
  • Loading branch information
Mingun committed Jul 19, 2024
1 parent cccb05d commit 5dbf318
Showing 1 changed file with 22 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import _root_.io.kaitai.struct.datatype.{DataType, EndOfStreamError, KSError}
import _root_.io.kaitai.struct.exprlang.Ast
import _root_.io.kaitai.struct.languages.GoCompiler
import _root_.io.kaitai.struct.testtranslator.{Main, TestAssert, TestEquals, TestSpec}
import _root_.io.kaitai.struct.translators.GoTranslator
import _root_.io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, StringLanguageOutputWriter, Utils}
import _root_.io.kaitai.struct.translators.{GoTranslator, TypeProvider}
import _root_.io.kaitai.struct.{ClassTypeProvider, ImportList, RuntimeConfig, StringLanguageOutputWriter, Utils}

class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(spec) {
/**
Expand All @@ -14,15 +14,32 @@ class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(sp
*/
class GoOutputWriter(indentStr: String) extends StringLanguageOutputWriter(indentStr) {
override def puts(s: String): Unit = {
val mangled = s.replace(REPLACER, "r.").replaceAll("return err$", "t.Fatal(err)")
super.puts(mangled)
super.puts(s.replace(REPLACER, "r."))
}
}

/**
* Special wrapper around translator that catches all attempts to write error
* check and turns it into assertion.
*/
class GoTestTranslator(
out: StringLanguageOutputWriter,
provider: TypeProvider,
importList: ImportList,
) extends GoTranslator(out, provider, importList) {
override def outAddErrCheck(): Unit = {
out.puts("if err != nil {")
out.inc
out.puts("t.Fatal(err)")
out.dec
out.puts("}")
}
}

override val out = new GoOutputWriter(indentStr)
val compiler = new GoCompiler(provider, RuntimeConfig())
val className = GoCompiler.types2class(List(spec.id))
val translator = new GoTranslator(out, provider, importList)
val translator = new GoTestTranslator(out, provider, importList)

override def fileName(name: String): String = s"${name}_test.go"

Expand Down

0 comments on commit 5dbf318

Please sign in to comment.