Skip to content

Commit

Permalink
improve serde gen for inline class defns
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Mar 3, 2025
1 parent 69dca11 commit b8f8e85
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object BasicGenerator {
StreamingImplementation.FS2
}

val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams, inlineDefns) =
endpointGenerator.endpointDefs(
doc,
useHeadTagForObjectNames,
Expand Down Expand Up @@ -214,6 +214,7 @@ object BasicGenerator {
|${indent(2)(queryParamSupport)}
|
|${indent(2)(classDefns)}
|${indent(2)(inlineDefns.mkString("\n"))}
|
|${indent(2)(maybeSpecificationExtensionKeys)}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ case class GeneratedEndpoints(
namesAndParamsByFile: Seq[GeneratedEndpointsForFile],
queryParamRefs: Set[String],
jsonParamRefs: Set[String],
definesEnumQueryParam: Boolean
definesEnumQueryParam: Boolean,
inlineDefns: Seq[String]
) {
def merge(that: GeneratedEndpoints): GeneratedEndpoints =
GeneratedEndpoints(
Expand All @@ -62,14 +63,16 @@ case class GeneratedEndpoints(
.toSeq,
queryParamRefs ++ that.queryParamRefs,
jsonParamRefs ++ that.jsonParamRefs,
definesEnumQueryParam || that.definesEnumQueryParam
definesEnumQueryParam || that.definesEnumQueryParam,
inlineDefns ++ that.inlineDefns
)
}
case class EndpointDefs(
endpointDecls: Map[Option[String], String],
queryOrPathParamRefs: Set[String],
jsonParamRefs: Set[String],
enumsDefinedOnEndpointParams: Boolean
enumsDefinedOnEndpointParams: Boolean,
inlineDefns: Seq[String]
)

class EndpointGenerator {
Expand Down Expand Up @@ -100,10 +103,10 @@ class EndpointGenerator {
): EndpointDefs = {
val capabilities = capabilityImpl(streamingImplementation)
val components = Option(doc.components).flatten
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) =
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam, inlineDefns) =
doc.paths
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation, doc))
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _)
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false, Nil))(_ merge _)
val endpointDecls = endpointsByFile.map { case GeneratedEndpointsForFile(k, ge) =>
val definitions = ge
.map { case GeneratedEndpoint(name, definition, maybeInlineDefns, types) =>
Expand All @@ -126,7 +129,7 @@ class EndpointGenerator {
|$allEP
|""".stripMargin
}.toMap
EndpointDefs(endpointDecls, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam)
EndpointDefs(endpointDecls, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam, inlineDefns)
}

private[codegen] def generatedEndpoints(
Expand All @@ -140,7 +143,7 @@ class EndpointGenerator {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

val (fileNamesAndParams, unflattenedParamRefs, definesParams) = p.methods
val (fileNamesAndParams, unflattenedParamRefs, inlineParamInfo) = p.methods
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)
Expand All @@ -156,7 +159,7 @@ class EndpointGenerator {
}
}

val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize))
val name = m.name(p.url)
val (pathDecl, pathTypes) = urlMapper(p.url, m.resolvedParameters)
val (securityDecl, securityTypes) = security(securitySchemes, m.security)
val (inParams, maybeLocalEnums, inTypes, inlineInDefns) =
Expand Down Expand Up @@ -202,9 +205,9 @@ class EndpointGenerator {
}
.toSet
(
(maybeTargetFileName, GeneratedEndpoint(name, definition, combine(maybeLocalEnums, inlineDefn), allTypes)),
(maybeTargetFileName, GeneratedEndpoint(name, definition, maybeLocalEnums, allTypes)),
(queryOrPathParamRefs, jsonParamRefs),
maybeLocalEnums.isDefined
(maybeLocalEnums.isDefined, inlineDefn)
)
} catch {
case e: NotImplementedError => throw e
Expand All @@ -217,11 +220,13 @@ class EndpointGenerator {
.groupBy(_._1)
.toSeq
.map { case (maybeTargetFileName, defns) => GeneratedEndpointsForFile(maybeTargetFileName, defns.map(_._2)) }
val (definesParams, inlineDefns) = inlineParamInfo.unzip
GeneratedEndpoints(
namesAndParamsByFile,
unflattenedQueryParamRefs.foldLeft(Set.empty[String])(_ ++ _),
unflattenedJsonParamRefs.foldLeft(Set.empty[String])(_ ++ _),
definesParams.contains(true)
definesParams.contains(true),
inlineDefns.flatten
)
}

Expand Down Expand Up @@ -569,20 +574,34 @@ class EndpointGenerator {
case "text/html" =>
MappedContentType("htmlBodyUtf8", "String")
case "application/json" =>
val outT = schema match {
val (outT, maybeInline) = schema match {
case st: OpenapiSchemaSimpleType =>
val (t, _) = mapSchemaSimpleTypeToType(st)
t
t -> None
case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"List[$t]"
s"List[$t]" -> None
case OpenapiSchemaMap(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"Map[String, $t]"
s"Map[String, $t]" -> None
case schemaRef: OpenapiSchemaObject if schemaRef.properties.forall(_._2.`type`.isInstanceOf[OpenapiSchemaSimpleType]) =>
val inlineClassName = endpointName.capitalize + position
val properties = schemaRef.properties.map { case (k, v) =>
val (st, nb) = mapSchemaSimpleTypeToType(v.`type`.asInstanceOf[OpenapiSchemaSimpleType], multipartForm = true)
val default = v.default
.map(j => " = " + DefaultValueRenderer.render(Map.empty, v.`type`, schemaRef.required.contains(k) || nb, RenderConfig())(j))
.getOrElse("")
s"$k: $st$default"
}
val inlineClassDefn =
s"""case class $inlineClassName (
|${indent(2)(properties.mkString(",\n"))}
|)""".stripMargin
inlineClassName -> Some(inlineClassDefn)
case x => bail(s"Can't create non-simple or array params as output (found $x)")
}
val req = if (required) outT else s"Option[$outT]"
MappedContentType(s"jsonBody[$req]", req)
MappedContentType(s"jsonBody[$req]", req, maybeInline)

case "multipart/form-data" =>
schema match {
Expand All @@ -591,7 +610,6 @@ class EndpointGenerator {
case schemaRef: OpenapiSchemaRef =>
val (t, _) = mapSchemaSimpleTypeToType(schemaRef, multipartForm = true)
MappedContentType(s"multipartBody[$t]", t)
// sack
case schemaRef: OpenapiSchemaObject if schemaRef.properties.forall(_._2.`type`.isInstanceOf[OpenapiSchemaStringType]) =>
val inlineClassName = endpointName.capitalize + position
val properties = schemaRef.properties.map { case (k, v) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,32 +265,52 @@ object JsonSerdeGenerator {
|implicit def optionCodec[T: $jsoniterPkgCore.JsonValueCodec]: $jsoniterPkgCore.JsonValueCodec[Option[T]] =
| $jsoniterPkgMacros.JsonCodecMaker.make[Option[T]]
|""".stripMargin
doc.components
.map(_.schemas.flatMap {
val docSchemas = doc.components.toSeq.flatMap(_.schemas)
val pathSchemas =
doc.paths.flatMap(p =>
p.methods.flatMap(m =>
m.responses
.flatMap(_.content)
.filter(o => o.contentType == "application/json" && o.schema.isInstanceOf[OpenapiSchemaObject])
.map(c => (m.name(p.url) + "Response", c.schema, true)) ++
m.requestBody.toSeq
.flatMap(_.content)
.filter(o => o.contentType == "application/json" && o.schema.isInstanceOf[OpenapiSchemaObject])
.map(c => (m.name(p.url) + "Request", c.schema, true))
)
)
(docSchemas.map { case (n, t) => (n, t, false) } ++ pathSchemas)
.flatMap {
// For standard objects, generate the schema if it's a 'top level' json schema or if it's referenced as a subtype of an ADT without a discriminator
case (name, _: OpenapiSchemaObject) =>
case (name, _: OpenapiSchemaObject, isJson) =>
val supertypes =
adtInheritanceMap.get(name).getOrElse(Nil).map(allSchemas.apply).collect { case oneOf: OpenapiSchemaOneOf => oneOf }
if (jsonParamRefs.contains(name) || supertypes.exists(_.discriminator.isEmpty)) Some(genJsoniterClassSerde(supertypes)(name))
adtInheritanceMap.getOrElse(name, Nil).map(allSchemas.apply).collect { case oneOf: OpenapiSchemaOneOf => oneOf }
if (isJson || jsonParamRefs.contains(name) || supertypes.exists(_.discriminator.isEmpty))
Some(genJsoniterClassSerde(supertypes)(name))
else None
// For named maps or seqs, only generate the schema if it's a 'top level' json schema
case (name, _: OpenapiSchemaMap) if jsonParamRefs.contains(name) =>
// For named maps, only generate the schema if it's a 'top level' json schema
case (name, _: OpenapiSchemaMap, isJson) if jsonParamRefs.contains(name) || isJson =>
Some(genJsoniterNamedSerde(name))
// For enums, generate the serde if it's referenced in any json model
case (name, _: OpenapiSchemaEnum) if allTransitiveJsonParamRefs.contains(name) =>
case (name, _: OpenapiSchemaEnum, _) if allTransitiveJsonParamRefs.contains(name) =>
Some(genJsoniterEnumSerde(name))
// For ADTs, generate the serde if it's referenced in any json model
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
case (name, schema: OpenapiSchemaOneOf, _) if allTransitiveJsonParamRefs.contains(name) =>
Some(generateJsoniterAdtSerde(allSchemas, name, schema, validateNonDiscriminatedOneOfs))
case (
_,
_: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaArray | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf |
_: OpenapiSchemaAny
_: OpenapiSchemaAny,
_
) =>
None
case (n, x) => throw new NotImplementedError(s"Only objects, enums, maps, arrays and oneOf supported! (for $n found ${x})")
})
.map(jsonSerdeHelpers + additionalExplicitSerdes + _.mkString("\n"))
case (n, x, _) => throw new NotImplementedError(s"Only objects, enums, maps, arrays and oneOf supported! (for $n found ${x})")
}
.foldLeft(Option.empty[String]) {
case (Some(a), b) => Some(a + "\n" + b)
case (None, a) => Some(a)
}
.map(jsonSerdeHelpers + additionalExplicitSerdes + _)
}

private def genJsoniterClassSerde(supertypes: Seq[OpenapiSchemaOneOf])(name: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cats.implicits.toTraverseOps
import cats.syntax.either._
import OpenapiSchemaType.OpenapiSchemaRef
import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.strippedToCamelCase
// https://swagger.io/specification/
object OpenapiModels {

Expand Down Expand Up @@ -50,6 +51,7 @@ object OpenapiModels {
operationId: Option[String] = None,
specificationExtensions: Map[String, Json] = Map.empty
) {
def name(url: String) = strippedToCamelCase(operationId.getOrElse(methodType + url.capitalize))
def resolvedParameters: Seq[OpenapiParameter] = parameters.collect { case Resolved(t) => t }
def withResolvedParentParameters(
pMap: Map[String, OpenapiParameter],
Expand Down

0 comments on commit b8f8e85

Please sign in to comment.