diff --git a/lib-driver/caqti_driver_postgresql.ml b/lib-driver/caqti_driver_postgresql.ml index 6c58b9ae..afb86e71 100644 --- a/lib-driver/caqti_driver_postgresql.ml +++ b/lib-driver/caqti_driver_postgresql.ml @@ -178,71 +178,84 @@ let rec set_binary_params set_binary_params bp t2 %> set_binary_params bp t3 | Caqti_type.Custom {rep; _} -> set_binary_params bp rep -let rec encode_field - : type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result = - fun ~uri field_type x -> - (match field_type with - | Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x) - | Caqti_type.Int -> Ok (string_of_int x) - | Caqti_type.Int32 -> Ok (Int32.to_string x) - | Caqti_type.Int64 -> Ok (Int64.to_string x) - | Caqti_type.Float -> Ok (sprintf "%.17g" x) - | Caqti_type.String -> Ok x - | Caqti_type.Octets -> Ok x - | Caqti_type.Pdate -> Ok (iso8601_of_pdate x) - | Caqti_type.Ptime -> - Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x) - | Caqti_type.Ptime_span -> - Ok (Pg_ext.string_of_ptime_span x) - | _ -> - (match Caqti_type.Field.coding driver_info field_type with - | None -> Error (Caqti_error.encode_missing ~uri ~field_type ()) - | Some (Caqti_type.Field.Coding {rep; encode; _}) -> - (match encode x with - | Ok y -> encode_field ~uri rep y - | Error msg -> - let msg = Caqti_error.Msg msg in - let typ = Caqti_type.field field_type in - Error (Caqti_error.encode_rejected ~uri ~typ msg)))) - -let rec copy_encode_field - : type a. uri: Uri.t -> - a Caqti_type.field -> - bool -> - (string -> string) -> - a -> - (string, _) result = - fun ~uri field_type field_is_binary binary_escape x -> - let escape_and_quote s = - "\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\"" - in - (match field_type with - | Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x) - | Caqti_type.Int -> Ok (string_of_int x) - | Caqti_type.Int32 -> Ok (Int32.to_string x) - | Caqti_type.Int64 -> Ok (Int64.to_string x) - | Caqti_type.Float -> Ok (sprintf "%.17g" x) - | Caqti_type.String -> - let x = if field_is_binary then binary_escape x else x in - Ok (escape_and_quote x) - | Caqti_type.Octets -> - let x = if field_is_binary then binary_escape x else x in - Ok (escape_and_quote x) - | Caqti_type.Pdate -> Ok (iso8601_of_pdate x) - | Caqti_type.Ptime -> - Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x) - | Caqti_type.Ptime_span -> - Ok (Pg_ext.string_of_ptime_span x) - | _ -> - (match Caqti_type.Field.coding driver_info field_type with - | None -> Error (Caqti_error.encode_missing ~uri ~field_type ()) - | Some (Caqti_type.Field.Coding {rep; encode; _}) -> - (match encode x with - | Ok y -> copy_encode_field ~uri rep field_is_binary binary_escape y - | Error msg -> - let msg = Caqti_error.Msg msg in - let typ = Caqti_type.field field_type in - Error (Caqti_error.encode_rejected ~uri ~typ msg)))) +module type STRING_ENCODER = sig + val encode_string : string -> string + val encode_octets : string -> string +end + +module Make_encoder (String_encoder : STRING_ENCODER) = struct + open String_encoder + + let rec encode_field + : type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result = + fun ~uri field_type x -> + (match field_type with + | Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x) + | Caqti_type.Int -> Ok (string_of_int x) + | Caqti_type.Int32 -> Ok (Int32.to_string x) + | Caqti_type.Int64 -> Ok (Int64.to_string x) + | Caqti_type.Float -> Ok (sprintf "%.17g" x) + | Caqti_type.String -> Ok (encode_string x) + | Caqti_type.Octets -> Ok (encode_octets x) + | Caqti_type.Pdate -> Ok (iso8601_of_pdate x) + | Caqti_type.Ptime -> + Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x) + | Caqti_type.Ptime_span -> + Ok (Pg_ext.string_of_ptime_span x) + | _ -> + (match Caqti_type.Field.coding driver_info field_type with + | None -> Error (Caqti_error.encode_missing ~uri ~field_type ()) + | Some (Caqti_type.Field.Coding {rep; encode; _}) -> + (match encode x with + | Ok y -> encode_field ~uri rep y + | Error msg -> + let msg = Caqti_error.Msg msg in + let typ = Caqti_type.field field_type in + Error (Caqti_error.encode_rejected ~uri ~typ msg)))) + + let rec encode' + : type a. uri: Uri.t -> string array -> + a Caqti_type.t -> a -> int -> (int, _) result = + fun ~uri params t x -> + (match t, x with + | Caqti_type.Unit, () -> fun i -> Ok i + | Caqti_type.Field ft, fv -> fun i -> + (match encode_field ~uri ft fv with + | Ok s -> params.(i) <- s; Ok (i + 1) + | Error _ as r -> r) + | Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t) + | Caqti_type.Option t, Some x -> + encode' ~uri params t x + | Caqti_type.Tup2 (t0, t1), (x0, x1) -> + encode' ~uri params t0 x0 %>? + encode' ~uri params t1 x1 + | Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) -> + encode' ~uri params t0 x0 %>? + encode' ~uri params t1 x1 %>? + encode' ~uri params t2 x2 + | Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) -> + encode' ~uri params t0 x0 %>? + encode' ~uri params t1 x1 %>? + encode' ~uri params t2 x2 %>? + encode' ~uri params t3 x3 + | Caqti_type.Custom {rep; encode = encode_custom; _}, x -> fun i -> + (match encode_custom x with + | Ok y -> + encode' ~uri params rep y i + | Error msg -> + let msg = Caqti_error.Msg msg in + Error (Caqti_error.encode_rejected ~uri ~typ:t msg))) + + let encode ~uri params t x = + (match encode' ~uri params t x 0 with + | Ok n -> assert (n = Array.length params); Ok () + | Error _ as r -> r) +end + +module Param_encoder = Make_encoder (struct + let encode_string s = s + let encode_octets s = s +end) let rec decode_field : type a. uri: Uri.t -> a Caqti_type.field -> string -> (a, _) result = @@ -296,72 +309,6 @@ let rec decode_field Error (Caqti_error.decode_rejected ~uri ~typ msg)) | Error _ as r -> r))) -let rec encode_param - : type a. uri: Uri.t -> string array -> - a Caqti_type.t -> a -> int -> (int, _) result = - fun ~uri params t x -> - (match t, x with - | Caqti_type.Unit, () -> fun i -> Ok i - | Caqti_type.Field ft, fv -> fun i -> - (match encode_field ~uri ft fv with - | Ok s -> params.(i) <- s; Ok (i + 1) - | Error _ as r -> r) - | Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t) - | Caqti_type.Option t, Some x -> encode_param ~uri params t x - | Caqti_type.Tup2 (t0, t1), (x0, x1) -> - encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1 - | Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) -> - encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1 %>? - encode_param ~uri params t2 x2 - | Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) -> - encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1 %>? - encode_param ~uri params t2 x2 %>? encode_param ~uri params t3 x3 - | Caqti_type.Custom {rep; encode; _}, x -> fun i -> - (match encode x with - | Ok y -> encode_param ~uri params rep y i - | Error msg -> - let msg = Caqti_error.Msg msg in - Error (Caqti_error.encode_rejected ~uri ~typ:t msg))) - -let rec copy_encode_param - : type a. uri: Uri.t -> - string array -> - bool array -> - (string -> string) -> - a Caqti_type.t -> - a -> - int -> - (int, _) result = - fun ~uri params binary_params binary_escape t x -> - (match t, x with - | Caqti_type.Unit, () -> fun i -> Ok i - | Caqti_type.Field ft, fv -> fun i -> - (match copy_encode_field ~uri ft binary_params.(i) binary_escape fv with - | Ok s -> params.(i) <- s; Ok (i + 1) - | Error _ as r -> r) - | Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t) - | Caqti_type.Option t, Some x -> - copy_encode_param ~uri params binary_params binary_escape t x - | Caqti_type.Tup2 (t0, t1), (x0, x1) -> - copy_encode_param ~uri params binary_params binary_escape t0 x0 %>? - copy_encode_param ~uri params binary_params binary_escape t1 x1 - | Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) -> - copy_encode_param ~uri params binary_params binary_escape t0 x0 %>? - copy_encode_param ~uri params binary_params binary_escape t1 x1 %>? - copy_encode_param ~uri params binary_params binary_escape t2 x2 - | Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) -> - copy_encode_param ~uri params binary_params binary_escape t0 x0 %>? - copy_encode_param ~uri params binary_params binary_escape t1 x1 %>? - copy_encode_param ~uri params binary_params binary_escape t2 x2 %>? - copy_encode_param ~uri params binary_params binary_escape t3 x3 - | Caqti_type.Custom {rep; encode; _}, x -> fun i -> - (match encode x with - | Ok y -> - copy_encode_param ~uri params binary_params binary_escape rep y i - | Error msg -> - let msg = Caqti_error.Msg msg in - Error (Caqti_error.encode_rejected ~uri ~typ:t msg))) - let rec decode_row' : type b. uri: Uri.t -> Pg.result * int -> b Caqti_type.t -> int -> (int * b, _) result = @@ -542,6 +489,12 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct struct open Db + module Copy_encoder = Make_encoder (struct + let encode_string s = + "\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\"" + let encode_octets s = db#escape_bytea s + end) + let using_db_ref = ref false let using_db f = H.assert_single_use using_db_ref f let in_transaction = ref false @@ -695,9 +648,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct let nbp = set_binary_params binary_params param_type 0 in assert (nbp = param_length); let params = Array.make param_length Pg.null in - (match encode_param ~uri params param_type param 0 with - | Ok n -> - assert (n = param_length); + (match Param_encoder.encode ~uri params param_type param with + | Ok () -> query_oneshot ~params ~binary_params query >|=? fun result -> Ok (query, result) | Error _ as r -> @@ -715,9 +667,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct {query; param_length; binary_params} in let params = Array.make prepared.param_length Pg.null in - (match encode_param ~uri params param_type param 0 with - | Ok n -> - assert (n = prepared.param_length); + (match Param_encoder.encode ~uri params param_type param with + | Ok () -> query_prepared query_id prepared params >|=? fun result -> Ok (prepared.query, result) | Error _ as r -> @@ -774,7 +725,6 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct let binary_params = Array.make param_length false in let nbp = set_binary_params binary_params row_type 0 in assert (nbp = param_length); - let binary_escape = db#escape_bytea in let fail msg = return (Error (Caqti_error.request_failed ~uri ~query (Caqti_error.Msg msg))) @@ -799,19 +749,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct in let copy_row row = let params = Array.make param_length Pg.null in - let num_encoded_params = - copy_encode_param - ~uri - params - binary_params - binary_escape - row_type - row - 0 - in - (match num_encoded_params with - | Ok n -> - assert (n = param_length); + (match Copy_encoder.encode ~uri params row_type row with + | Ok () -> return (Ok (String.concat "," (Array.to_list params))) | Error _ as r -> return r)