Skip to content

Commit

Permalink
Support type spreading in records
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmurphy committed May 1, 2024
1 parent c138d06 commit 74b2029
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 105 deletions.
139 changes: 99 additions & 40 deletions ppx_src/src/Records.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ open Utils

type parsedRecordFieldDeclaration = {
name: string;
(* If this field was a spread, this is the name that comes after the three dots *)
spreadName: string option;
key: expression;
field: expression;
codecs: expression option * expression option;
Expand All @@ -17,31 +19,30 @@ let makeArrayOfJsonFieldsFromParsedFieldDeclarations parsedFields =
[%expr [%e key], [%e BatOption.get encoder] [%e field]])
|> Exp.array

let wrapInSpreadEncoder parsedFields baseExpr =
let spreadExpr =
match List.find_opt (fun {name} -> name = "...") parsedFields with
| Some {codecs = Some otherEncoder, _} ->
(* We've encountered a spread operator here. At this point, we
want to call the encode function for the name of the thing
that's being spread, and then produce an expression that will
merge another object over the encoded spread object.
let wrapInSpreadEncoders parsedFields baseExpr =
let spreadExprs =
List.filter_map
(fun {name; codecs} ->
match (name, codecs) with
| "...", (Some otherEncoder, _) ->
(* We've encountered a spread operator here. At this point, we
want to call the encode function for the name of the thing
that's being spread, and then produce an expression that will
merge another object over the encoded spread object.
Make sure to use the text 'valueToEncode' here. It should match the value defined in
generateEncoder below. There's a comment there about why we don't pass this name in
as a parameter. *)
let otherEncoderLident =
[%expr [%e otherEncoder] (Obj.magic valueToEncode)]
in
Some [%expr Decco.unsafeMergeJsonObjectsCurried [%e otherEncoderLident]]
| _ -> None
Make sure to use the text 'valueToEncode' here. It should match the value defined in
generateEncoder below. There's a comment there about why we don't pass this name in
as a parameter. *)
let otherEncoderLident =
[%expr [%e otherEncoder] (Obj.magic valueToEncode)]
in
Some [%expr Decco.unsafeMergeObjects [%e otherEncoderLident]]
| _, _ -> None)
parsedFields
in
match spreadExpr with
(* If we have a spread expression to apply, wrap the whole base encoder expression in
the function that merges it with the result of the spread *)
| Some spreadExpr -> [%expr [%e spreadExpr] [%e baseExpr]]
| None ->
(* If we're not handling a spread record, just return the base encoder expression *)
baseExpr
List.fold_right
(fun spreadExpr acc -> [%expr [%e spreadExpr] [%e acc]])
spreadExprs baseExpr

let generateEncoder parsedFields unboxed (rootTypeNameOfRecord : label) =
(* If we've got a record with a spread type in it, we'll need to omit the spread
Expand Down Expand Up @@ -74,7 +75,7 @@ let generateEncoder parsedFields unboxed (rootTypeNameOfRecord : label) =
[%e
makeArrayOfJsonFieldsFromParsedFieldDeclarations
parsedFieldsWithoutSpread])]
|> wrapInSpreadEncoder parsedFields
|> wrapInSpreadEncoders parsedFields
(* This is where the final encoder function is constructed. If
you need to do something with the parameters, this is the place. *)
|> Exp.fun_ Asttypes.Nolabel None constrainedFunctionArgsPattern
Expand All @@ -95,47 +96,99 @@ let generateDictGet {key; codecs = _, decoder; default} =
let generateDictGets decls =
decls |> List.map generateDictGet |> tupleOrSingleton Exp.tuple

let generateErrorCase {key} =
let generateErrorCase {key; spreadName} =
let finalKey =
match spreadName with
| Some spreadName ->
Exp.constant (Pconst_string ("..." ^ spreadName, Location.none, None))
| None -> key
in
{
pc_lhs = [%pat? Belt.Result.Error (e : Decco.decodeError)];
pc_guard = None;
pc_rhs = [%expr Belt.Result.Error {e with path = "." ^ [%e key] ^ e.path}];
pc_rhs =
[%expr Belt.Result.Error {e with path = "." ^ [%e finalKey] ^ e.path}];
}

let generateFinalRecordExpr allFieldDeclarations =
allFieldDeclarations
|> List.map (fun {name} -> (lid name, makeIdentExpr name))
|> fun l -> [%expr Belt.Result.Ok [%e Exp.record l None]]
let fieldDeclarationsWithoutSpread =
List.filter (fun {name} -> name <> "...") allFieldDeclarations
in
(* If there's a spread on the record, it gets passed as an optional expression as the last argument
to the record constructor. I don't know why, but there you go. *)
let spreadExpressions =
List.filter_map
(fun {name; spreadName} ->
match (name, spreadName) with
| "...", Some spreadName ->
(* We found a spread! But the type system won't be happy
if we spread it directly because smaller types still can't
be spread insto larger types. We'll have to use Object.magic *)
Some (Exp.ident (lid spreadName))
| _ -> None)
allFieldDeclarations
in
let rootObject =
List.fold_right
(fun {name} acc ->
[%expr
Decco.unsafeAddFieldToObject
[%e Exp.constant (Ast_helper.Const.string name)]
[%e makeIdentExpr name] [%e acc]])
fieldDeclarationsWithoutSpread [%expr Js.Dict.empty ()]
in
let mergedWithSpreads =
List.fold_right
(fun spreadExpr acc ->
[%expr Decco.unsafeMergeObjects [%e spreadExpr] [%e acc]])
spreadExpressions rootObject
in
[%expr Belt.Result.Ok (Obj.magic [%e mergedWithSpreads])]

let generateSuccessCase {name} successExpr =
let generateSuccessCase {name; spreadName} successExpr =
let actualNameToUseForOkayPayload =
match (name, spreadName) with
| "...", Some spreadName -> spreadName
| _ -> name
in
{
pc_lhs = (mknoloc name |> Pat.var |> fun p -> [%pat? Belt.Result.Ok [%p p]]);
pc_lhs =
( mknoloc actualNameToUseForOkayPayload |> Pat.var |> fun p ->
[%pat? Belt.Result.Ok [%p p]] );
pc_guard = None;
pc_rhs = successExpr;
}

(* Recursively generates an expression containing nested switches, first
decoding the first record items, then (if successful) the second, etc. *)
let rec generateNestedSwitchesRecurse allDecls remainingDecls =
let current, successExpr =
match remainingDecls with
| [] -> failwith "Decco internal error: [] not expected"
| last :: [] -> (last, generateFinalRecordExpr allDecls)
| first :: tail -> (first, generateNestedSwitchesRecurse allDecls tail)
in
(* Normally the expression we'll switch on is getting a value from Js.Dict,
but in the case of a spread operator, ..., we're going to call the decoder
for that field instead *)
let switchExpression =
match current with
| {name = "..."; codecs = _, decoder} ->
[%expr [%e BatOption.get decoder] v]
| _ -> generateDictGet current
in
[generateErrorCase current]
|> List.append [generateSuccessCase current successExpr]
|> Exp.match_ (generateDictGet current)
[@@ocaml.doc
" Recursively generates an expression containing nested switches, first\n\
\ * decoding the first record items, then (if successful) the second, etc. "]
|> Exp.match_ switchExpression

let generateNestedSwitches decls = generateNestedSwitchesRecurse decls decls

let generateDecoder decls unboxed =
let fieldDeclarationsWithoutSpread =
List.filter (fun {name} -> name <> "...") decls
in
match unboxed with
| true ->
let fieldDeclarationsWithoutSpread =
List.filter (fun {name} -> name <> "...") decls
in
let {codecs; name} = List.hd fieldDeclarationsWithoutSpread in
let _, d = codecs in
let recordExpr =
Expand All @@ -148,8 +201,7 @@ let generateDecoder decls unboxed =
[%expr
fun v ->
match Js.Json.classify v with
| Js.Json.JSONObject dict ->
[%e generateNestedSwitches fieldDeclarationsWithoutSpread]
| Js.Json.JSONObject dict -> [%e generateNestedSwitches decls]
| _ -> Decco.error "Not an object" v]

let parseRecordField encodeDecodeFlags (rootTypeNameOfRecord : label)
Expand All @@ -166,8 +218,15 @@ let parseRecordField encodeDecodeFlags (rootTypeNameOfRecord : label)
| Ok None -> Exp.constant (Pconst_string (txt, Location.none, None))
| Error s -> fail pld_loc s
in
let spreadName =
match (txt, pld_type) with
| "...", {ptyp_desc = Ptyp_constr ({txt = Lident spreadName; _}, _); _} ->
Some spreadName
| _ -> None
in
{
name = txt;
spreadName;
key;
field = Exp.field [%expr valueToEncode] (lid txt);
codecs = Codecs.generateCodecs encodeDecodeFlags pld_type;
Expand Down
5 changes: 0 additions & 5 deletions ppx_src/src/Structure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ open Parsetree
open Ast_helper
open Utils

type typeInfo = {typeName: label; typeParams: label list}

let typeNameAndParamsToTypeDeclaration {typeName; typeParams} =
Typ.constr (lid typeName) (List.map (fun s -> Typ.var s) typeParams)

let jsJsonTypeDecl = Typ.constr (lid "Js.Json.t") []

let buildRightHandSideOfEqualSignForCodecDeclarations (paramNames : label list)
Expand Down
5 changes: 5 additions & 0 deletions ppx_src/src/Utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@ let print_strings strings =
Printf.printf "[%s]\n" formatted

let labelToCoreType label = Ast_helper.Typ.constr (lid label) []

type typeInfo = {typeName: label; typeParams: label list}

let typeNameAndParamsToTypeDeclaration {typeName; typeParams} =
Typ.constr (lid typeName) (List.map (fun s -> Typ.var s) typeParams)
15 changes: 12 additions & 3 deletions src/Decco.res
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,24 @@ let dictFromJson = (decoder, json) =>
}

/**
* Merges two JSON objects together. If there are any duplicate keys, the value from the second object will be used.
* Merges two javascript objects together. If there are any duplicate keys, the value from the second object will be used.
* This function is type-unsafe and should be used with caution. It's here to be used by generated decoder
* functions for records that use spreads in their types, and these functions are careful only to pass in
* JSON objects and not other kinds of values.
* objects and not other kinds of values.
*/
let unsafeMergeJsonObjectsCurried = (a: Js.Json.t) => (b: Js.Json.t): Js.Json.t => {
let unsafeMergeObjects = (a: 'a, b: 'b): 'c => {
Js.Obj.assign(a->Obj.magic, b->Obj.magic)->Obj.magic
}

/**
* Adds a field to a javascript object. This function is type-unsafe and should be used with caution. It's here to be used by
* generated decoder functions for records that use spreads in their types, and these functions are careful only to pass in
* objects and not other kinds of values.
*/
let unsafeAddFieldToObject = (key: string, value: 'b, obj: 'a): 'c => {
Obj.magic(obj)->Js.Dict.set(key, value)->Obj.magic
}

module Codecs = {
include Decco_Codecs
let string = (stringToJson, stringFromJson)
Expand Down
10 changes: 5 additions & 5 deletions test/__tests__/CustomCodecs.res
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
open Jest
open Expect

let intToStr = (i: int) => i->string_of_int
let intFromStr = (s: string) => s->int_of_string
let intToStr = (i: int) => i->string_of_int->Decco.stringToJson
let intFromStr = (s: Js.Json.t) => s->Decco.stringFromJson->Belt.Result.map(int_of_string)

@decco type intAsStr = @decco.codec((intToStr, intFromStr)) int

Expand All @@ -14,12 +14,12 @@ describe("CustomCodecs", () => {
let x: intAsStr = 42

let encoded = x->intAsStr_encode
expect(encoded)->toBe("42")
expect(encoded)->toBe("42"->Decco.stringToJson)
})

test("should decode", () => {
let encoded = "42"
let encoded = "42"->Decco.stringToJson
let decoded = intAsStr_decode(encoded)
expect(decoded)->toBe(42)
expect(decoded)->toBe(Ok(42))
})
})
97 changes: 45 additions & 52 deletions test/__tests__/RecordSpreads.res
Original file line number Diff line number Diff line change
@@ -1,52 +1,45 @@
// open Jest
// open TestUtils
// open Expect

// module A = {
// @decco
// type t = {first_name: string}
// }

// module B = {
// @decco
// type t = {last_name: string}
// }

// module C = {
// @decco
// type t = {
// ...A.t,
// ...B.t,
// age: int,
// }
// }

// describe("record spreading", () => {
// test("should encode", () => {
// let a = {
// first_name: "Bob",
// }

// let b = {
// last_name: "pizza",
// }

// let c: c = {
// first_name: "bob",
// last_name: "pizza",
// age: 3,
// }

// let encoded = c_encode(c)

// // expect("123")->toBe(Js.Json.stringify(c_encode(c)))
// expect("123")->toBe("123")
// })

// // test("should decode", () => {
// // let encoded = "42"
// // let decoded = intAsStr_decode(encoded)
// // expect(decoded)->toBe(42)
// // })
// })

open Jest
open Expect

module A = {
@decco
type t = {first_name: string}
@decco
type b = {last_name: string}
}

module C = {
@decco
type t = {
...A.t,
...A.b,
age: int,
}
}

describe("record spreading", () => {
test("should encode", () => {
let c: C.t = {
first_name: "bob",
last_name: "pizza",
age: 3,
}

let encoded = C.t_encode(c)

expect(Js.Json.stringify(encoded))->toBe(
{"first_name": "bob", "last_name": "pizza", "age": 3}
->Obj.magic
->Js.Json.stringify,
)
})

test("should decode", () => {
let json = Js.Json.parseExn(`{"first_name":"bob","last_name":"pizza","age":3}`)
let decoded = C.t_decode(json)

expect(decoded->Belt.Result.map(x => (x.first_name, x.last_name, x.age)))->toEqual(
Ok(("bob", "pizza", 3)),
)
})
})

0 comments on commit 74b2029

Please sign in to comment.