From 74b20290d494e81c9cd7931847092548b7945379 Mon Sep 17 00:00:00 2001 From: Murphy Randle Date: Wed, 1 May 2024 06:44:26 -0400 Subject: [PATCH] Support type spreading in records --- ppx_src/src/Records.ml | 139 ++++++++++++++++++++++--------- ppx_src/src/Structure.ml | 5 -- ppx_src/src/Utils.ml | 5 ++ src/Decco.res | 15 +++- test/__tests__/CustomCodecs.res | 10 +-- test/__tests__/RecordSpreads.res | 97 ++++++++++----------- 6 files changed, 166 insertions(+), 105 deletions(-) diff --git a/ppx_src/src/Records.ml b/ppx_src/src/Records.ml index 015bf09..918a9e0 100644 --- a/ppx_src/src/Records.ml +++ b/ppx_src/src/Records.ml @@ -5,6 +5,8 @@ open Utils type parsedRecordFieldDeclaration = { name: string; + (* If this field was a spread, this is the name that comes after the three dots *) + spreadName: string option; key: expression; field: expression; codecs: expression option * expression option; @@ -17,31 +19,30 @@ let makeArrayOfJsonFieldsFromParsedFieldDeclarations parsedFields = [%expr [%e key], [%e BatOption.get encoder] [%e field]]) |> Exp.array -let wrapInSpreadEncoder parsedFields baseExpr = - let spreadExpr = - match List.find_opt (fun {name} -> name = "...") parsedFields with - | Some {codecs = Some otherEncoder, _} -> - (* We've encountered a spread operator here. At this point, we - want to call the encode function for the name of the thing - that's being spread, and then produce an expression that will - merge another object over the encoded spread object. +let wrapInSpreadEncoders parsedFields baseExpr = + let spreadExprs = + List.filter_map + (fun {name; codecs} -> + match (name, codecs) with + | "...", (Some otherEncoder, _) -> + (* We've encountered a spread operator here. At this point, we + want to call the encode function for the name of the thing + that's being spread, and then produce an expression that will + merge another object over the encoded spread object. - Make sure to use the text 'valueToEncode' here. It should match the value defined in - generateEncoder below. There's a comment there about why we don't pass this name in - as a parameter. *) - let otherEncoderLident = - [%expr [%e otherEncoder] (Obj.magic valueToEncode)] - in - Some [%expr Decco.unsafeMergeJsonObjectsCurried [%e otherEncoderLident]] - | _ -> None + Make sure to use the text 'valueToEncode' here. It should match the value defined in + generateEncoder below. There's a comment there about why we don't pass this name in + as a parameter. *) + let otherEncoderLident = + [%expr [%e otherEncoder] (Obj.magic valueToEncode)] + in + Some [%expr Decco.unsafeMergeObjects [%e otherEncoderLident]] + | _, _ -> None) + parsedFields in - match spreadExpr with - (* If we have a spread expression to apply, wrap the whole base encoder expression in - the function that merges it with the result of the spread *) - | Some spreadExpr -> [%expr [%e spreadExpr] [%e baseExpr]] - | None -> - (* If we're not handling a spread record, just return the base encoder expression *) - baseExpr + List.fold_right + (fun spreadExpr acc -> [%expr [%e spreadExpr] [%e acc]]) + spreadExprs baseExpr let generateEncoder parsedFields unboxed (rootTypeNameOfRecord : label) = (* If we've got a record with a spread type in it, we'll need to omit the spread @@ -74,7 +75,7 @@ let generateEncoder parsedFields unboxed (rootTypeNameOfRecord : label) = [%e makeArrayOfJsonFieldsFromParsedFieldDeclarations parsedFieldsWithoutSpread])] - |> wrapInSpreadEncoder parsedFields + |> wrapInSpreadEncoders parsedFields (* This is where the final encoder function is constructed. If you need to do something with the parameters, this is the place. *) |> Exp.fun_ Asttypes.Nolabel None constrainedFunctionArgsPattern @@ -95,25 +96,71 @@ let generateDictGet {key; codecs = _, decoder; default} = let generateDictGets decls = decls |> List.map generateDictGet |> tupleOrSingleton Exp.tuple -let generateErrorCase {key} = +let generateErrorCase {key; spreadName} = + let finalKey = + match spreadName with + | Some spreadName -> + Exp.constant (Pconst_string ("..." ^ spreadName, Location.none, None)) + | None -> key + in { pc_lhs = [%pat? Belt.Result.Error (e : Decco.decodeError)]; pc_guard = None; - pc_rhs = [%expr Belt.Result.Error {e with path = "." ^ [%e key] ^ e.path}]; + pc_rhs = + [%expr Belt.Result.Error {e with path = "." ^ [%e finalKey] ^ e.path}]; } let generateFinalRecordExpr allFieldDeclarations = - allFieldDeclarations - |> List.map (fun {name} -> (lid name, makeIdentExpr name)) - |> fun l -> [%expr Belt.Result.Ok [%e Exp.record l None]] + let fieldDeclarationsWithoutSpread = + List.filter (fun {name} -> name <> "...") allFieldDeclarations + in + (* If there's a spread on the record, it gets passed as an optional expression as the last argument + to the record constructor. I don't know why, but there you go. *) + let spreadExpressions = + List.filter_map + (fun {name; spreadName} -> + match (name, spreadName) with + | "...", Some spreadName -> + (* We found a spread! But the type system won't be happy + if we spread it directly because smaller types still can't + be spread insto larger types. We'll have to use Object.magic *) + Some (Exp.ident (lid spreadName)) + | _ -> None) + allFieldDeclarations + in + let rootObject = + List.fold_right + (fun {name} acc -> + [%expr + Decco.unsafeAddFieldToObject + [%e Exp.constant (Ast_helper.Const.string name)] + [%e makeIdentExpr name] [%e acc]]) + fieldDeclarationsWithoutSpread [%expr Js.Dict.empty ()] + in + let mergedWithSpreads = + List.fold_right + (fun spreadExpr acc -> + [%expr Decco.unsafeMergeObjects [%e spreadExpr] [%e acc]]) + spreadExpressions rootObject + in + [%expr Belt.Result.Ok (Obj.magic [%e mergedWithSpreads])] -let generateSuccessCase {name} successExpr = +let generateSuccessCase {name; spreadName} successExpr = + let actualNameToUseForOkayPayload = + match (name, spreadName) with + | "...", Some spreadName -> spreadName + | _ -> name + in { - pc_lhs = (mknoloc name |> Pat.var |> fun p -> [%pat? Belt.Result.Ok [%p p]]); + pc_lhs = + ( mknoloc actualNameToUseForOkayPayload |> Pat.var |> fun p -> + [%pat? Belt.Result.Ok [%p p]] ); pc_guard = None; pc_rhs = successExpr; } +(* Recursively generates an expression containing nested switches, first + decoding the first record items, then (if successful) the second, etc. *) let rec generateNestedSwitchesRecurse allDecls remainingDecls = let current, successExpr = match remainingDecls with @@ -121,21 +168,27 @@ let rec generateNestedSwitchesRecurse allDecls remainingDecls = | last :: [] -> (last, generateFinalRecordExpr allDecls) | first :: tail -> (first, generateNestedSwitchesRecurse allDecls tail) in + (* Normally the expression we'll switch on is getting a value from Js.Dict, + but in the case of a spread operator, ..., we're going to call the decoder + for that field instead *) + let switchExpression = + match current with + | {name = "..."; codecs = _, decoder} -> + [%expr [%e BatOption.get decoder] v] + | _ -> generateDictGet current + in [generateErrorCase current] |> List.append [generateSuccessCase current successExpr] - |> Exp.match_ (generateDictGet current) -[@@ocaml.doc - " Recursively generates an expression containing nested switches, first\n\ - \ * decoding the first record items, then (if successful) the second, etc. "] + |> Exp.match_ switchExpression let generateNestedSwitches decls = generateNestedSwitchesRecurse decls decls let generateDecoder decls unboxed = - let fieldDeclarationsWithoutSpread = - List.filter (fun {name} -> name <> "...") decls - in match unboxed with | true -> + let fieldDeclarationsWithoutSpread = + List.filter (fun {name} -> name <> "...") decls + in let {codecs; name} = List.hd fieldDeclarationsWithoutSpread in let _, d = codecs in let recordExpr = @@ -148,8 +201,7 @@ let generateDecoder decls unboxed = [%expr fun v -> match Js.Json.classify v with - | Js.Json.JSONObject dict -> - [%e generateNestedSwitches fieldDeclarationsWithoutSpread] + | Js.Json.JSONObject dict -> [%e generateNestedSwitches decls] | _ -> Decco.error "Not an object" v] let parseRecordField encodeDecodeFlags (rootTypeNameOfRecord : label) @@ -166,8 +218,15 @@ let parseRecordField encodeDecodeFlags (rootTypeNameOfRecord : label) | Ok None -> Exp.constant (Pconst_string (txt, Location.none, None)) | Error s -> fail pld_loc s in + let spreadName = + match (txt, pld_type) with + | "...", {ptyp_desc = Ptyp_constr ({txt = Lident spreadName; _}, _); _} -> + Some spreadName + | _ -> None + in { name = txt; + spreadName; key; field = Exp.field [%expr valueToEncode] (lid txt); codecs = Codecs.generateCodecs encodeDecodeFlags pld_type; diff --git a/ppx_src/src/Structure.ml b/ppx_src/src/Structure.ml index 563aa43..c46026e 100644 --- a/ppx_src/src/Structure.ml +++ b/ppx_src/src/Structure.ml @@ -3,11 +3,6 @@ open Parsetree open Ast_helper open Utils -type typeInfo = {typeName: label; typeParams: label list} - -let typeNameAndParamsToTypeDeclaration {typeName; typeParams} = - Typ.constr (lid typeName) (List.map (fun s -> Typ.var s) typeParams) - let jsJsonTypeDecl = Typ.constr (lid "Js.Json.t") [] let buildRightHandSideOfEqualSignForCodecDeclarations (paramNames : label list) diff --git a/ppx_src/src/Utils.ml b/ppx_src/src/Utils.ml index b8d39da..75ab878 100644 --- a/ppx_src/src/Utils.ml +++ b/ppx_src/src/Utils.ml @@ -138,3 +138,8 @@ let print_strings strings = Printf.printf "[%s]\n" formatted let labelToCoreType label = Ast_helper.Typ.constr (lid label) [] + +type typeInfo = {typeName: label; typeParams: label list} + +let typeNameAndParamsToTypeDeclaration {typeName; typeParams} = + Typ.constr (lid typeName) (List.map (fun s -> Typ.var s) typeParams) \ No newline at end of file diff --git a/src/Decco.res b/src/Decco.res index 691c46c..eb6b744 100644 --- a/src/Decco.res +++ b/src/Decco.res @@ -153,15 +153,24 @@ let dictFromJson = (decoder, json) => } /** - * Merges two JSON objects together. If there are any duplicate keys, the value from the second object will be used. + * Merges two javascript objects together. If there are any duplicate keys, the value from the second object will be used. * This function is type-unsafe and should be used with caution. It's here to be used by generated decoder * functions for records that use spreads in their types, and these functions are careful only to pass in - * JSON objects and not other kinds of values. + * objects and not other kinds of values. */ -let unsafeMergeJsonObjectsCurried = (a: Js.Json.t) => (b: Js.Json.t): Js.Json.t => { +let unsafeMergeObjects = (a: 'a, b: 'b): 'c => { Js.Obj.assign(a->Obj.magic, b->Obj.magic)->Obj.magic } +/** + * Adds a field to a javascript object. This function is type-unsafe and should be used with caution. It's here to be used by + * generated decoder functions for records that use spreads in their types, and these functions are careful only to pass in + * objects and not other kinds of values. + */ +let unsafeAddFieldToObject = (key: string, value: 'b, obj: 'a): 'c => { + Obj.magic(obj)->Js.Dict.set(key, value)->Obj.magic +} + module Codecs = { include Decco_Codecs let string = (stringToJson, stringFromJson) diff --git a/test/__tests__/CustomCodecs.res b/test/__tests__/CustomCodecs.res index a91f5a5..10471d2 100644 --- a/test/__tests__/CustomCodecs.res +++ b/test/__tests__/CustomCodecs.res @@ -4,8 +4,8 @@ open Jest open Expect -let intToStr = (i: int) => i->string_of_int -let intFromStr = (s: string) => s->int_of_string +let intToStr = (i: int) => i->string_of_int->Decco.stringToJson +let intFromStr = (s: Js.Json.t) => s->Decco.stringFromJson->Belt.Result.map(int_of_string) @decco type intAsStr = @decco.codec((intToStr, intFromStr)) int @@ -14,12 +14,12 @@ describe("CustomCodecs", () => { let x: intAsStr = 42 let encoded = x->intAsStr_encode - expect(encoded)->toBe("42") + expect(encoded)->toBe("42"->Decco.stringToJson) }) test("should decode", () => { - let encoded = "42" + let encoded = "42"->Decco.stringToJson let decoded = intAsStr_decode(encoded) - expect(decoded)->toBe(42) + expect(decoded)->toBe(Ok(42)) }) }) diff --git a/test/__tests__/RecordSpreads.res b/test/__tests__/RecordSpreads.res index 9086b0b..c96a8eb 100644 --- a/test/__tests__/RecordSpreads.res +++ b/test/__tests__/RecordSpreads.res @@ -1,52 +1,45 @@ -// open Jest -// open TestUtils -// open Expect - -// module A = { -// @decco -// type t = {first_name: string} -// } - -// module B = { -// @decco -// type t = {last_name: string} -// } - -// module C = { -// @decco -// type t = { -// ...A.t, -// ...B.t, -// age: int, -// } -// } - -// describe("record spreading", () => { -// test("should encode", () => { -// let a = { -// first_name: "Bob", -// } - -// let b = { -// last_name: "pizza", -// } - -// let c: c = { -// first_name: "bob", -// last_name: "pizza", -// age: 3, -// } - -// let encoded = c_encode(c) - -// // expect("123")->toBe(Js.Json.stringify(c_encode(c))) -// expect("123")->toBe("123") -// }) - -// // test("should decode", () => { -// // let encoded = "42" -// // let decoded = intAsStr_decode(encoded) -// // expect(decoded)->toBe(42) -// // }) -// }) - +open Jest +open Expect + +module A = { + @decco + type t = {first_name: string} + @decco + type b = {last_name: string} +} + +module C = { + @decco + type t = { + ...A.t, + ...A.b, + age: int, + } +} + +describe("record spreading", () => { + test("should encode", () => { + let c: C.t = { + first_name: "bob", + last_name: "pizza", + age: 3, + } + + let encoded = C.t_encode(c) + + expect(Js.Json.stringify(encoded))->toBe( + {"first_name": "bob", "last_name": "pizza", "age": 3} + ->Obj.magic + ->Js.Json.stringify, + ) + }) + + test("should decode", () => { + let json = Js.Json.parseExn(`{"first_name":"bob","last_name":"pizza","age":3}`) + let decoded = C.t_decode(json) + + expect(decoded->Belt.Result.map(x => (x.first_name, x.last_name, x.age)))->toEqual( + Ok(("bob", "pizza", 3)), + ) + }) +})