Skip to content

Commit

Permalink
codegen: Semiauto schema derivation (#3671)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Apr 19, 2024
1 parent 1cf911d commit 20a526c
Show file tree
Hide file tree
Showing 18 changed files with 518 additions and 79 deletions.
7 changes: 4 additions & 3 deletions doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@ defined case-classes and endpoint definitions.
The generator currently supports these settings, you can override them in the `build.sbt`;

```eval_rst
===================================== ==================================== =======================================================================================
===================================== ==================================== ==================================================================================================
setting default value description
===================================== ==================================== =======================================================================================
===================================== ==================================== ==================================================================================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag.
openapiJsonSerdeLib circe The json serde library to use.
openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated.
===================================== ==================================== =======================================================================================
openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits).
===================================== ==================================== ==================================================================================================
```

The general usage is;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ object GenScala {
"v"
)
.orFalse
private val maxSchemasPerFileOpt: Opts[Option[Int]] =
Opts
.option[Int]("maxSchemasPerFile", "Maximum number of schemas to generate in a single file.", "m")
.orNone

private val jsonLibOpt: Opts[Option[String]] =
Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone
Expand All @@ -71,8 +75,8 @@ object GenScala {
}

val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) {
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt)
.mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs) =>
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt, maxSchemasPerFileOpt)
.mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs, maxSchemasPerFile) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
Expand All @@ -84,7 +88,8 @@ object GenScala {
targetScala3,
headTagForNames,
jsonLib.getOrElse("circe"),
validateNonDiscriminatedOneOfs
validateNonDiscriminatedOneOfs,
maxSchemasPerFile.getOrElse(400)
)
)
destFiles <- contents.toVector.traverse { case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ object BasicGenerator {
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean,
jsonSerdeLib: String,
validateNonDiscriminatedOneOfs: Boolean
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int
): Map[String, String] = {
val normalisedJsonLib = jsonSerdeLib.toLowerCase match {
case "circe" => JsonSerdeLib.Circe
Expand All @@ -47,7 +48,7 @@ object BasicGenerator {
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val GeneratedClassDefinitions(classDefns, extras) =
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
doc = doc,
Expand All @@ -56,15 +57,19 @@ object BasicGenerator {
jsonSerdeLib = normalisedJsonLib,
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs,
maxSchemasPerFile = maxSchemasPerFile
)
.getOrElse(GeneratedClassDefinitions("", None))
val isSplit = extras.nonEmpty
val internalImports =
if (isSplit)
s"""import $packagePath.$objName._
|import ${objName}JsonSerdes._""".stripMargin
else s"import $objName._"
.getOrElse(GeneratedClassDefinitions("", None, Nil))
val hasJsonSerdes = jsonSerdes.nonEmpty

val maybeJsonImport = if (hasJsonSerdes) s"\nimport $packagePath.${objName}JsonSerdes._" else ""
val maybeSchemaImport =
if (schemas.size > 1) (1 to schemas.size).map(i => s"import ${objName}Schemas$i._").mkString("\n", "\n", "")
else if (schemas.size == 1) s"\nimport ${objName}Schemas._"
else ""
val internalImports = s"import $packagePath.$objName._$maybeJsonImport$maybeSchemaImport"

val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
Expand All @@ -81,14 +86,39 @@ object BasicGenerator {
|}""".stripMargin
headTag -> taggedObj
}
val extraObj = extras.map { body =>

val jsonSerdeObj = jsonSerdes.map { body =>
s"""package $packagePath
|
|object ${objName}JsonSerdes {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(body)}
|}""".stripMargin
}

val schemaObjs = if (schemas.size > 1) schemas.zipWithIndex.map { case (body, idx) =>
val priorImports = (0 until idx).map { i => s"import $packagePath.${objName}Schemas${i + 1}._" }.mkString("\n")
val name = s"${objName}Schemas${idx + 1}"
name -> s"""package $packagePath
|
|object $name {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(priorImports)}
|${indent(2)(body)}
|}""".stripMargin
}
else if (schemas.size == 1)
Seq(s"${objName}Schemas" -> s"""package $packagePath
|
|object ${objName}Schemas {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(schemas.head)}
|}""".stripMargin)
else Nil

val endpointsInMain = endpointsByTag.getOrElse(None, "")

val maybeSpecificationExtensionKeys = doc.paths
Expand All @@ -100,21 +130,21 @@ object BasicGenerator {
val values = pairs.map(_._2)
val `type` = SpecificationExtensionRenderer.renderCombinedType(values)
val name = strippedToCamelCase(keyName)
val uncapitalisedName = name.head.toLower + name.tail
val capitalisedName = name.head.toUpper + name.tail
val uncapitalisedName = uncapitalise(name)
val capitalisedName = uncapitalisedName.capitalize
s"""type ${capitalisedName}Extension = ${`type`}
|val ${uncapitalisedName}ExtensionKey = new sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension")
|""".stripMargin
}
.mkString("\n")

val serdeImport = if (isSplit && endpointsInMain.nonEmpty) s"\nimport $packagePath.${objName}JsonSerdes._" else ""
val mainObj = s"""|
val extraImports = if (endpointsInMain.nonEmpty) s"$maybeJsonImport$maybeSchemaImport" else ""
val mainObj = s"""
|package $packagePath
|
|object $objName {
|
|${indent(2)(imports(normalisedJsonLib) + serdeImport)}
|${indent(2)(imports(normalisedJsonLib) + extraImports)}
|
|${indent(2)(classDefns)}
|
Expand All @@ -124,7 +154,7 @@ object BasicGenerator {
|
|}
|""".stripMargin
taggedObjs ++ extraObj.map(s"${objName}JsonSerdes" -> _) + (objName -> mainObj)
taggedObjs ++ jsonSerdeObj.map(s"${objName}JsonSerdes" -> _) ++ schemaObjs + (objName -> mainObj)
}

private[codegen] def imports(jsonSerdeLib: JsonSerdeLib.JsonSerdeLib): String = {
Expand Down Expand Up @@ -184,4 +214,6 @@ object BasicGenerator {
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString

def uncapitalise(name: String): String = name.head.toLower +: name.tail
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._

import scala.annotation.tailrec

case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String])
case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String], schemaRepr: Seq[String])

class ClassDefinitionGenerator {

Expand All @@ -18,7 +18,8 @@ class ClassDefinitionGenerator {
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib = JsonSerdeLib.Circe,
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
validateNonDiscriminatedOneOfs: Boolean = true
validateNonDiscriminatedOneOfs: Boolean = true,
maxSchemasPerFile: Int = 400
): Option[GeneratedClassDefinitions] = {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
Expand All @@ -40,7 +41,8 @@ class ClassDefinitionGenerator {

val adtTypes = adtInheritanceMap.flatMap(_._2).toSeq.distinct.map(name => s"sealed trait $name").mkString("", "\n", "\n")
val enumQuerySerdeHelper = if (!generatesQueryParamEnums) "" else enumQuerySerdeHelperDefn(targetScala3)
val postDefns = JsonSerdeGenerator.serdeDefs(
val schemas = SchemaGenerator.generateSchemas(doc, allSchemas, fullModelPath, jsonSerdeLib, maxSchemasPerFile)
val jsonSerdes = JsonSerdeGenerator.serdeDefs(
doc,
jsonSerdeLib,
jsonParamRefs,
Expand All @@ -63,8 +65,8 @@ class ClassDefinitionGenerator {
val helpers = (enumQuerySerdeHelper + adtTypes).linesIterator
.filterNot(_.forall(_.isWhitespace))
.mkString("\n")
// Json serdes live in a separate file from the class defns
defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, postDefns))
// Json serdes & schemas live in separate files from the class defns
defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, jsonSerdes, schemas))
}

private def mkMapParentsByChild(allOneOfSchemas: Seq[(String, OpenapiSchemaOneOf)]): Map[String, Seq[String]] =
Expand Down Expand Up @@ -219,7 +221,7 @@ class ClassDefinitionGenerator {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
val maybeCodecExtension = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaBoolean,
OpenapiSchemaEnum,
OpenapiSchemaField,
OpenapiSchemaMap,
OpenapiSchemaNumericType,
OpenapiSchemaObject,
Expand Down Expand Up @@ -86,7 +87,7 @@ object JsonSerdeGenerator {
// if lhs has some required non-nullable fields with no default that rhs will never contain, then right cannot be mistaken for left
if ((requiredL.keySet -- anyR.keySet).nonEmpty) false
else {
// otherwise, if any required field on rhs can't look like the similarly-named field on lhs, then r can't look like l
// otherwise, if any field on rhs required by lhs can't look like the similarly-named field on lhs, then r can't look like l
val rForRequiredL = anyR.filter(requiredL.keySet contains _._1)
requiredL.forall { case (k, lhsV) => rCanLookLikeL(lhsV.`type`, rForRequiredL(k).`type`) }
}
Expand Down Expand Up @@ -118,8 +119,10 @@ object JsonSerdeGenerator {
// Enum serdes are generated at the declaration site
case (_, _: OpenapiSchemaEnum) => None
// We generate the serde if it's referenced in any json model
case (name, _: OpenapiSchemaObject | _: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceNamedSerde(name))
case (name, schema: OpenapiSchemaObject) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceObjectSerde(name, schema))
case (name, schema: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceMapSerde(name, schema))
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceAdtSerde(allSchemas, schema, name, validateNonDiscriminatedOneOfs))
case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None
Expand All @@ -128,19 +131,36 @@ object JsonSerdeGenerator {
.map(_.mkString("\n"))
}

private def genCirceNamedSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
s"""implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name]
private def genCirceObjectSerde(name: String, schema: OpenapiSchemaObject): String = {
val subs = schema.properties.collect {
case (k, OpenapiSchemaField(`type`: OpenapiSchemaObject, _)) => genCirceObjectSerde(s"$name${k.capitalize}", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaArray(`type`: OpenapiSchemaObject, _), _)) =>
genCirceObjectSerde(s"$name${k.capitalize}Item", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaMap(`type`: OpenapiSchemaObject, _), _)) =>
genCirceObjectSerde(s"$name${k.capitalize}Item", `type`)
} match {
case Nil => ""
case s => s.mkString("", "\n", "\n")
}
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""${subs}implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name]
|implicit lazy val ${uncapitalisedName}JsonEncoder: io.circe.Encoder[$name] = io.circe.generic.semiauto.deriveEncoder[$name]""".stripMargin
}
private def genCirceMapSerde(name: String, schema: OpenapiSchemaMap): String = {
val subs = schema.items match {
case `type`: OpenapiSchemaObject => Some(genCirceObjectSerde(s"${name}ObjectsItem", `type`))
case _ => None
}
subs.fold("")("\n" + _)
}

private def genCirceAdtSerde(
allSchemas: Map[String, OpenapiSchemaType],
schema: OpenapiSchemaOneOf,
name: String,
validateNonDiscriminatedOneOfs: Boolean
): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)

schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
Expand Down Expand Up @@ -256,7 +276,7 @@ object JsonSerdeGenerator {
}

private def genJsoniterClassSerde(supertypes: Seq[OpenapiSchemaOneOf])(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
if (supertypes.exists(_.discriminator.isDefined))
throw new NotImplementedError(
s"A class cannot be used both in a oneOf with discriminator and at the top level when using jsoniter serdes at $name"
Expand All @@ -266,13 +286,13 @@ object JsonSerdeGenerator {
}

private def genJsoniterEnumSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""
|implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[${name}] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniteEnumConfig.withDiscriminatorFieldName(scala.None))""".stripMargin
}

private def genJsoniterNamedSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""
|implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[$name] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniterBaseConfig)""".stripMargin
}
Expand All @@ -285,7 +305,7 @@ object JsonSerdeGenerator {
validateNonDiscriminatedOneOfs: Boolean
): String = {
val fullPathPrefix = maybeFullModelPath.map(_ + ".").getOrElse("")
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
def subtypeNames = schema.types.map {
Expand Down Expand Up @@ -321,7 +341,7 @@ object JsonSerdeGenerator {
if (validateNonDiscriminatedOneOfs) checkForSoundness(allSchemas)(schema.types.map(_.asInstanceOf[OpenapiSchemaRef]))
val childNameAndSerde = schemas.collect { case ref: OpenapiSchemaRef =>
val name = ref.stripped
name -> s"${name.head.toLower +: name.tail}JsonCodec"
name -> s"${BasicGenerator.uncapitalise(name)}JsonCodec"
}
val childSerdes = childNameAndSerde.map(_._2)
val doDecode = childSerdes.mkString("List(\n ", ",\n ", ")\n") +
Expand Down
Loading

0 comments on commit 20a526c

Please sign in to comment.