Skip to content

Commit

Permalink
add zio-json support to openapi-codegen (#3728)
Browse files Browse the repository at this point in the history
  • Loading branch information
oker1 authored May 9, 2024
1 parent 585434d commit 5632d8b
Show file tree
Hide file tree
Showing 12 changed files with 629 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
import sttp.tapir.codegen.openapi.models.SpecificationExtensionRenderer

object JsonSerdeLib extends Enumeration {
val Circe, Jsoniter = Value
val Circe, Jsoniter, Zio = Value
type JsonSerdeLib = Value
}

Expand All @@ -40,6 +40,7 @@ object BasicGenerator {
val normalisedJsonLib = jsonSerdeLib.toLowerCase match {
case "circe" => JsonSerdeLib.Circe
case "jsoniter" => JsonSerdeLib.Jsoniter
case "zio" => JsonSerdeLib.Zio
case _ =>
System.err.println(
s"!!! Unrecognised value $jsonSerdeLib for json serde lib -- should be one of circe, jsoniter. Defaulting to circe !!!"
Expand Down Expand Up @@ -166,6 +167,9 @@ object BasicGenerator {
"""import sttp.tapir.json.jsoniter._
|import com.github.plokhotnyuk.jsoniter_scala.macros._
|import com.github.plokhotnyuk.jsoniter_scala.core._""".stripMargin
case JsonSerdeLib.Zio =>
"""import sttp.tapir.json.zio._
|import zio.json._""".stripMargin
}
s"""import sttp.tapir._
|import sttp.tapir.model._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ object JsonSerdeGenerator {
if (fullModelPath.isEmpty) None else Some(fullModelPath),
validateNonDiscriminatedOneOfs
)
case JsonSerdeLib.Zio => genZioSerdes(doc, allSchemas, allTransitiveJsonParamRefs, validateNonDiscriminatedOneOfs)
}
}

Expand Down Expand Up @@ -367,4 +368,118 @@ object JsonSerdeGenerator {
serde
}
}

///
/// Zio
///
private def genZioSerdes(
doc: OpenapiDocument,
allSchemas: Map[String, OpenapiSchemaType],
allTransitiveJsonParamRefs: Set[String],
validateNonDiscriminatedOneOfs: Boolean
): Option[String] = {
doc.components
.map(_.schemas.flatMap {
// Enum serdes are generated at the declaration site
case (_, _: OpenapiSchemaEnum) => None
// We generate the serde if it's referenced in any json model
case (name, schema: OpenapiSchemaObject) if allTransitiveJsonParamRefs.contains(name) =>
Some(genZioObjectSerde(name, schema))
case (name, schema: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) =>
Some(genZioMapSerde(name, schema))
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
Some(genZioAdtSerde(allSchemas, schema, name, validateNonDiscriminatedOneOfs))
case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None
case (n, x) => throw new NotImplementedError(s"Only objects, enums, maps and oneOf supported! (for $n found ${x})")
})
.map(_.mkString("\n"))
}

private def genZioObjectSerde(name: String, schema: OpenapiSchemaObject): String = {
val subs = schema.properties.collect {
case (k, OpenapiSchemaField(`type`: OpenapiSchemaObject, _)) => genZioObjectSerde(s"$name${k.capitalize}", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaArray(`type`: OpenapiSchemaObject, _), _)) =>
genZioObjectSerde(s"$name${k.capitalize}Item", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaMap(`type`: OpenapiSchemaObject, _), _)) =>
genZioObjectSerde(s"$name${k.capitalize}Item", `type`)
} match {
case Nil => ""
case s => s.mkString("", "\n", "\n")
}
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""${subs}implicit lazy val ${uncapitalisedName}JsonDecoder: zio.json.JsonDecoder[$name] = zio.json.DeriveJsonDecoder.gen[$name]
|implicit lazy val ${uncapitalisedName}JsonEncoder: zio.json.JsonEncoder[$name] = zio.json.DeriveJsonEncoder.gen[$name]""".stripMargin
}

private def genZioMapSerde(name: String, schema: OpenapiSchemaMap): String = {
val subs = schema.items match {
case `type`: OpenapiSchemaObject => Some(genZioObjectSerde(s"${name}ObjectsItem", `type`))
case _ => None
}
subs.fold("")("\n" + _)
}

private def genZioAdtSerde(
allSchemas: Map[String, OpenapiSchemaType],
schema: OpenapiSchemaOneOf,
name: String,
validateNonDiscriminatedOneOfs: Boolean
): String = {
val uncapitalisedName = BasicGenerator.uncapitalise(name)

schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
val subtypeNames = schema.types.map {
case ref: OpenapiSchemaRef => ref.stripped
case other => throw new IllegalArgumentException(s"oneOf subtypes must be refs to explicit schema models, found $other for $name")
}
val schemaToJsonMapping = discriminator.mapping match {
case Some(mapping) =>
mapping.map { case (jsonValue, fullRef) => fullRef.stripPrefix("#/components/schemas/") -> jsonValue }
case None => subtypeNames.map(s => s -> s).toMap
}
val encoders = subtypeNames
.map { t =>
val jsonTypeName = schemaToJsonMapping(t)
s"""case x: $t => zio.json.ast.Json.decoder.decodeJson(zio.json.JsonEncoder[$t].encodeJson(x)).getOrElse(throw new RuntimeException("Unable to encode tagged ADT type ${name} to json")).mapObject(_.add("${discriminator.propertyName}", zio.json.ast.Json.Str("$jsonTypeName")))"""
}
.mkString("\n")
val decoders = subtypeNames
.map { t => s"""case zio.json.ast.Json.Str("${schemaToJsonMapping(t)}") => zio.json.JsonDecoder[$t].fromJsonAST(json)""" }
.mkString("\n")
s"""implicit lazy val ${uncapitalisedName}JsonEncoder: zio.json.JsonEncoder[$name] = zio.json.JsonEncoder[zio.json.ast.Json].contramap {
|${indent(2)(encoders)}
|}
|implicit lazy val ${uncapitalisedName}JsonDecoder: zio.json.JsonDecoder[$name] = zio.json.JsonDecoder[zio.json.ast.Json].mapOrFail {
| case [email protected](fields) =>
| (fields.find(_._1 == "type") match {
| case None => Left("Unable to decode json to tagged ADT type ${name}")
| case Some(r) => Right(r._2)
| }).flatMap {
|${indent(6)(decoders)}
| case _ => Left("Unable to decode json to tagged ADT type ${name}")
| }
| case _ => Left("Unable to decode json to tagged ADT type ${name}")
|}""".stripMargin
case OpenapiSchemaOneOf(_, None) =>
val subtypeNames = schema.types.map {
case ref: OpenapiSchemaRef => ref.stripped
case other => throw new IllegalArgumentException(s"oneOf subtypes must be refs to explicit schema models, found $other for $name")
}
if (validateNonDiscriminatedOneOfs) checkForSoundness(allSchemas)(schema.types.map(_.asInstanceOf[OpenapiSchemaRef]))
val encoders = subtypeNames.map(t => s"case x: $t => zio.json.JsonEncoder[$t].unsafeEncode(x, indent, out)").mkString("\n")
val decoders = subtypeNames.map(t => s"zio.json.JsonDecoder[$t].asInstanceOf[zio.json.JsonDecoder[$name]]").mkString(",\n")
s"""implicit lazy val ${uncapitalisedName}JsonEncoder: zio.json.JsonEncoder[$name] = new zio.json.JsonEncoder[$name] {
| override def unsafeEncode(v: $name, indent: Option[Int], out: zio.json.internal.Write): Unit = {
| v match {
|${indent(6)(encoders)}
| }
| }
|}
|implicit lazy val ${uncapitalisedName}JsonDecoder: zio.json.JsonDecoder[$name] =
| List[zio.json.JsonDecoder[$name]](
|${indent(4)(decoders)}
| ).reduceLeft(_ orElse _)""".stripMargin
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

package sttp.tapir.generated

object TapirGeneratedEndpoints {

import sttp.tapir._
import sttp.tapir.model._
import sttp.tapir.generic.auto._
import sttp.tapir.json.zio._
import zio.json._

import sttp.tapir.generated.TapirGeneratedEndpointsJsonSerdes._
import TapirGeneratedEndpointsSchemas._

sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD1 (
s: String,
i: Option[Int] = None,
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD1 (
s: String,
i: Option[Int] = None,
d: Option[Double] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping
case class SubtypeWithoutD3 (
s: String,
i: Option[Int] = None,
d: Option[Double] = None,
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithoutD2 (
a: Seq[String],
absent: Option[String] = None
) extends ADTWithoutDiscriminator
case class SubtypeWithD2 (
s: String,
a: Option[Seq[String]] = None
) extends ADTWithDiscriminator with ADTWithDiscriminatorNoMapping



lazy val putAdtTest =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))


lazy val generatedEndpoints = List(putAdtTest, postAdtTest)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package sttp.tapir.generated

object TapirGeneratedEndpointsJsonSerdes {
import sttp.tapir.generated.TapirGeneratedEndpoints._
import sttp.tapir.generic.auto._
implicit lazy val aDTWithDiscriminatorJsonEncoder: zio.json.JsonEncoder[ADTWithDiscriminator] = zio.json.JsonEncoder[zio.json.ast.Json].contramap {
case x: SubtypeWithD1 => zio.json.ast.Json.decoder.decodeJson(zio.json.JsonEncoder[SubtypeWithD1].encodeJson(x)).getOrElse(throw new RuntimeException("Unable to encode tagged ADT type ADTWithDiscriminator to json")).mapObject(_.add("type", zio.json.ast.Json.Str("SubA")))
case x: SubtypeWithD2 => zio.json.ast.Json.decoder.decodeJson(zio.json.JsonEncoder[SubtypeWithD2].encodeJson(x)).getOrElse(throw new RuntimeException("Unable to encode tagged ADT type ADTWithDiscriminator to json")).mapObject(_.add("type", zio.json.ast.Json.Str("SubB")))
}
implicit lazy val aDTWithDiscriminatorJsonDecoder: zio.json.JsonDecoder[ADTWithDiscriminator] = zio.json.JsonDecoder[zio.json.ast.Json].mapOrFail {
case [email protected](fields) =>
(fields.find(_._1 == "type") match {
case None => Left("Unable to decode json to tagged ADT type ADTWithDiscriminator")
case Some(r) => Right(r._2)
}).flatMap {
case zio.json.ast.Json.Str("SubA") => zio.json.JsonDecoder[SubtypeWithD1].fromJsonAST(json)
case zio.json.ast.Json.Str("SubB") => zio.json.JsonDecoder[SubtypeWithD2].fromJsonAST(json)
case _ => Left("Unable to decode json to tagged ADT type ADTWithDiscriminator")
}
case _ => Left("Unable to decode json to tagged ADT type ADTWithDiscriminator")
}
implicit lazy val subtypeWithoutD1JsonDecoder: zio.json.JsonDecoder[SubtypeWithoutD1] = zio.json.DeriveJsonDecoder.gen[SubtypeWithoutD1]
implicit lazy val subtypeWithoutD1JsonEncoder: zio.json.JsonEncoder[SubtypeWithoutD1] = zio.json.DeriveJsonEncoder.gen[SubtypeWithoutD1]
implicit lazy val subtypeWithD1JsonDecoder: zio.json.JsonDecoder[SubtypeWithD1] = zio.json.DeriveJsonDecoder.gen[SubtypeWithD1]
implicit lazy val subtypeWithD1JsonEncoder: zio.json.JsonEncoder[SubtypeWithD1] = zio.json.DeriveJsonEncoder.gen[SubtypeWithD1]
implicit lazy val aDTWithDiscriminatorNoMappingJsonEncoder: zio.json.JsonEncoder[ADTWithDiscriminatorNoMapping] = zio.json.JsonEncoder[zio.json.ast.Json].contramap {
case x: SubtypeWithD1 => zio.json.ast.Json.decoder.decodeJson(zio.json.JsonEncoder[SubtypeWithD1].encodeJson(x)).getOrElse(throw new RuntimeException("Unable to encode tagged ADT type ADTWithDiscriminatorNoMapping to json")).mapObject(_.add("type", zio.json.ast.Json.Str("SubtypeWithD1")))
case x: SubtypeWithD2 => zio.json.ast.Json.decoder.decodeJson(zio.json.JsonEncoder[SubtypeWithD2].encodeJson(x)).getOrElse(throw new RuntimeException("Unable to encode tagged ADT type ADTWithDiscriminatorNoMapping to json")).mapObject(_.add("type", zio.json.ast.Json.Str("SubtypeWithD2")))
}
implicit lazy val aDTWithDiscriminatorNoMappingJsonDecoder: zio.json.JsonDecoder[ADTWithDiscriminatorNoMapping] = zio.json.JsonDecoder[zio.json.ast.Json].mapOrFail {
case [email protected](fields) =>
(fields.find(_._1 == "type") match {
case None => Left("Unable to decode json to tagged ADT type ADTWithDiscriminatorNoMapping")
case Some(r) => Right(r._2)
}).flatMap {
case zio.json.ast.Json.Str("SubtypeWithD1") => zio.json.JsonDecoder[SubtypeWithD1].fromJsonAST(json)
case zio.json.ast.Json.Str("SubtypeWithD2") => zio.json.JsonDecoder[SubtypeWithD2].fromJsonAST(json)
case _ => Left("Unable to decode json to tagged ADT type ADTWithDiscriminatorNoMapping")
}
case _ => Left("Unable to decode json to tagged ADT type ADTWithDiscriminatorNoMapping")
}
implicit lazy val subtypeWithoutD3JsonDecoder: zio.json.JsonDecoder[SubtypeWithoutD3] = zio.json.DeriveJsonDecoder.gen[SubtypeWithoutD3]
implicit lazy val subtypeWithoutD3JsonEncoder: zio.json.JsonEncoder[SubtypeWithoutD3] = zio.json.DeriveJsonEncoder.gen[SubtypeWithoutD3]
implicit lazy val subtypeWithoutD2JsonDecoder: zio.json.JsonDecoder[SubtypeWithoutD2] = zio.json.DeriveJsonDecoder.gen[SubtypeWithoutD2]
implicit lazy val subtypeWithoutD2JsonEncoder: zio.json.JsonEncoder[SubtypeWithoutD2] = zio.json.DeriveJsonEncoder.gen[SubtypeWithoutD2]
implicit lazy val subtypeWithD2JsonDecoder: zio.json.JsonDecoder[SubtypeWithD2] = zio.json.DeriveJsonDecoder.gen[SubtypeWithD2]
implicit lazy val subtypeWithD2JsonEncoder: zio.json.JsonEncoder[SubtypeWithD2] = zio.json.DeriveJsonEncoder.gen[SubtypeWithD2]
implicit lazy val aDTWithoutDiscriminatorJsonEncoder: zio.json.JsonEncoder[ADTWithoutDiscriminator] = new zio.json.JsonEncoder[ADTWithoutDiscriminator] {
override def unsafeEncode(v: ADTWithoutDiscriminator, indent: Option[Int], out: zio.json.internal.Write): Unit = {
v match {
case x: SubtypeWithoutD1 => zio.json.JsonEncoder[SubtypeWithoutD1].unsafeEncode(x, indent, out)
case x: SubtypeWithoutD2 => zio.json.JsonEncoder[SubtypeWithoutD2].unsafeEncode(x, indent, out)
case x: SubtypeWithoutD3 => zio.json.JsonEncoder[SubtypeWithoutD3].unsafeEncode(x, indent, out)
}
}
}
implicit lazy val aDTWithoutDiscriminatorJsonDecoder: zio.json.JsonDecoder[ADTWithoutDiscriminator] =
List[zio.json.JsonDecoder[ADTWithoutDiscriminator]](
zio.json.JsonDecoder[SubtypeWithoutD1].asInstanceOf[zio.json.JsonDecoder[ADTWithoutDiscriminator]],
zio.json.JsonDecoder[SubtypeWithoutD2].asInstanceOf[zio.json.JsonDecoder[ADTWithoutDiscriminator]],
zio.json.JsonDecoder[SubtypeWithoutD3].asInstanceOf[zio.json.JsonDecoder[ADTWithoutDiscriminator]]
).reduceLeft(_ orElse _)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sttp.tapir.generated

object TapirGeneratedEndpointsSchemas {
import sttp.tapir.generated.TapirGeneratedEndpoints._
import sttp.tapir.generic.auto._
implicit lazy val subtypeWithD1TapirSchema: sttp.tapir.Schema[SubtypeWithD1] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithD2TapirSchema: sttp.tapir.Schema[SubtypeWithD2] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithoutD1TapirSchema: sttp.tapir.Schema[SubtypeWithoutD1] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithoutD2TapirSchema: sttp.tapir.Schema[SubtypeWithoutD2] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithoutD3TapirSchema: sttp.tapir.Schema[SubtypeWithoutD3] = sttp.tapir.Schema.derived
implicit lazy val aDTWithDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithDiscriminator] = {
val derived = implicitly[sttp.tapir.generic.Derived[sttp.tapir.Schema[ADTWithDiscriminator]]].value
derived.schemaType match {
case s: sttp.tapir.SchemaType.SCoproduct[_] => derived.copy(schemaType = s.addDiscriminatorField(
sttp.tapir.FieldName("type"),
sttp.tapir.Schema.string,
Map(
"SubA" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD1")),
"SubB" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD2"))
)
))
case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminator should be a coproduct")
}
}
implicit lazy val aDTWithDiscriminatorNoMappingTapirSchema: sttp.tapir.Schema[ADTWithDiscriminatorNoMapping] = {
val derived = implicitly[sttp.tapir.generic.Derived[sttp.tapir.Schema[ADTWithDiscriminatorNoMapping]]].value
derived.schemaType match {
case s: sttp.tapir.SchemaType.SCoproduct[_] => derived.copy(schemaType = s.addDiscriminatorField(
sttp.tapir.FieldName("type"),
sttp.tapir.Schema.string,
Map(
"SubtypeWithD1" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD1")),
"SubtypeWithD2" -> sttp.tapir.SchemaType.SRef(sttp.tapir.Schema.SName("sttp.tapir.generated.TapirGeneratedEndpoints.SubtypeWithD2"))
)
))
case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminatorNoMapping should be a coproduct")
}
}
implicit lazy val aDTWithoutDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithoutDiscriminator] = sttp.tapir.Schema.derived
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
lazy val root = (project in file("."))
.enablePlugins(OpenapiCodegenPlugin)
.settings(
scalaVersion := "2.13.13",
version := "0.1",
openapiJsonSerdeLib := "zio"
)

libraryDependencies ++= Seq(
"com.softwaremill.sttp.tapir" %% "tapir-json-zio" % "1.10.0",
"com.softwaremill.sttp.tapir" %% "tapir-openapi-docs" % "1.10.0",
"com.softwaremill.sttp.apispec" %% "openapi-circe-yaml" % "0.8.0",
"org.scalatest" %% "scalatest" % "3.2.18" % Test,
"com.softwaremill.sttp.tapir" %% "tapir-sttp-stub-server" % "1.10.0" % Test
)

import scala.io.Source

TaskKey[Unit]("check") := {
def check(generatedFileName: String, expectedFileName: String) = {
val generatedCode =
Source.fromFile(s"target/scala-2.13/src_managed/main/sbt-openapi-codegen/$generatedFileName").getLines.mkString("\n")
val expectedCode = Source.fromFile(expectedFileName).getLines.mkString("\n")
val generatedTrimmed =
generatedCode.linesIterator.zipWithIndex.filterNot(_._1.forall(_.isWhitespace)).map { case (a, i) => a.trim -> i }.toSeq
val expectedTrimmed = expectedCode.linesIterator.filterNot(_.forall(_.isWhitespace)).map(_.trim).toSeq
if (generatedTrimmed.size != expectedTrimmed.size)
sys.error(s"expected ${expectedTrimmed.size} non-empty lines, found ${generatedTrimmed.size}")
generatedTrimmed.zip(expectedTrimmed).foreach { case ((a, i), b) =>
if (a != b) sys.error(s"Generated code in file $generatedCode did not match (expected '$b' on line $i, found '$a')")
}
}
Seq(
"TapirGeneratedEndpoints.scala" -> "Expected.scala.txt",
"TapirGeneratedEndpointsJsonSerdes.scala" -> "ExpectedJsonSerdes.scala.txt",
"TapirGeneratedEndpointsSchemas.scala" -> "ExpectedSchemas.scala.txt"
).foreach { case (generated, expected) => check(generated, expected) }
()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version=1.9.9
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
val pluginVersion = System.getProperty("plugin.version")
if (pluginVersion == null)
throw new RuntimeException("""|
|
|The system property 'plugin.version' is not defined.
|Specify this property using the scriptedLaunchOpts -D.
|
|""".stripMargin)
else addSbtPlugin("com.softwaremill.sttp.tapir" % "sbt-openapi-codegen" % pluginVersion)
}
Loading

0 comments on commit 5632d8b

Please sign in to comment.