Skip to content

Commit

Permalink
Deduplicate parameter and CSV encoder in PG driver.
Browse files Browse the repository at this point in the history
  • Loading branch information
paurkedal committed Dec 8, 2019
1 parent efbe7de commit e9befd6
Showing 1 changed file with 90 additions and 151 deletions.
241 changes: 90 additions & 151 deletions lib-driver/caqti_driver_postgresql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,71 +178,84 @@ let rec set_binary_params
set_binary_params bp t2 %> set_binary_params bp t3
| Caqti_type.Custom {rep; _} -> set_binary_params bp rep

let rec encode_field
: type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result =
fun ~uri field_type x ->
(match field_type with
| Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x)
| Caqti_type.Int -> Ok (string_of_int x)
| Caqti_type.Int32 -> Ok (Int32.to_string x)
| Caqti_type.Int64 -> Ok (Int64.to_string x)
| Caqti_type.Float -> Ok (sprintf "%.17g" x)
| Caqti_type.String -> Ok x
| Caqti_type.Octets -> Ok x
| Caqti_type.Pdate -> Ok (iso8601_of_pdate x)
| Caqti_type.Ptime ->
Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x)
| Caqti_type.Ptime_span ->
Ok (Pg_ext.string_of_ptime_span x)
| _ ->
(match Caqti_type.Field.coding driver_info field_type with
| None -> Error (Caqti_error.encode_missing ~uri ~field_type ())
| Some (Caqti_type.Field.Coding {rep; encode; _}) ->
(match encode x with
| Ok y -> encode_field ~uri rep y
| Error msg ->
let msg = Caqti_error.Msg msg in
let typ = Caqti_type.field field_type in
Error (Caqti_error.encode_rejected ~uri ~typ msg))))

let rec copy_encode_field
: type a. uri: Uri.t ->
a Caqti_type.field ->
bool ->
(string -> string) ->
a ->
(string, _) result =
fun ~uri field_type field_is_binary binary_escape x ->
let escape_and_quote s =
"\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\""
in
(match field_type with
| Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x)
| Caqti_type.Int -> Ok (string_of_int x)
| Caqti_type.Int32 -> Ok (Int32.to_string x)
| Caqti_type.Int64 -> Ok (Int64.to_string x)
| Caqti_type.Float -> Ok (sprintf "%.17g" x)
| Caqti_type.String ->
let x = if field_is_binary then binary_escape x else x in
Ok (escape_and_quote x)
| Caqti_type.Octets ->
let x = if field_is_binary then binary_escape x else x in
Ok (escape_and_quote x)
| Caqti_type.Pdate -> Ok (iso8601_of_pdate x)
| Caqti_type.Ptime ->
Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x)
| Caqti_type.Ptime_span ->
Ok (Pg_ext.string_of_ptime_span x)
| _ ->
(match Caqti_type.Field.coding driver_info field_type with
| None -> Error (Caqti_error.encode_missing ~uri ~field_type ())
| Some (Caqti_type.Field.Coding {rep; encode; _}) ->
(match encode x with
| Ok y -> copy_encode_field ~uri rep field_is_binary binary_escape y
| Error msg ->
let msg = Caqti_error.Msg msg in
let typ = Caqti_type.field field_type in
Error (Caqti_error.encode_rejected ~uri ~typ msg))))
module type STRING_ENCODER = sig
val encode_string : string -> string
val encode_octets : string -> string
end

module Make_encoder (String_encoder : STRING_ENCODER) = struct
open String_encoder

let rec encode_field
: type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result =
fun ~uri field_type x ->
(match field_type with
| Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x)
| Caqti_type.Int -> Ok (string_of_int x)
| Caqti_type.Int32 -> Ok (Int32.to_string x)
| Caqti_type.Int64 -> Ok (Int64.to_string x)
| Caqti_type.Float -> Ok (sprintf "%.17g" x)
| Caqti_type.String -> Ok (encode_string x)
| Caqti_type.Octets -> Ok (encode_octets x)
| Caqti_type.Pdate -> Ok (iso8601_of_pdate x)
| Caqti_type.Ptime ->
Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x)
| Caqti_type.Ptime_span ->
Ok (Pg_ext.string_of_ptime_span x)
| _ ->
(match Caqti_type.Field.coding driver_info field_type with
| None -> Error (Caqti_error.encode_missing ~uri ~field_type ())
| Some (Caqti_type.Field.Coding {rep; encode; _}) ->
(match encode x with
| Ok y -> encode_field ~uri rep y
| Error msg ->
let msg = Caqti_error.Msg msg in
let typ = Caqti_type.field field_type in
Error (Caqti_error.encode_rejected ~uri ~typ msg))))

let rec encode'
: type a. uri: Uri.t -> string array ->
a Caqti_type.t -> a -> int -> (int, _) result =
fun ~uri params t x ->
(match t, x with
| Caqti_type.Unit, () -> fun i -> Ok i
| Caqti_type.Field ft, fv -> fun i ->
(match encode_field ~uri ft fv with
| Ok s -> params.(i) <- s; Ok (i + 1)
| Error _ as r -> r)
| Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t)
| Caqti_type.Option t, Some x ->
encode' ~uri params t x
| Caqti_type.Tup2 (t0, t1), (x0, x1) ->
encode' ~uri params t0 x0 %>?
encode' ~uri params t1 x1
| Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) ->
encode' ~uri params t0 x0 %>?
encode' ~uri params t1 x1 %>?
encode' ~uri params t2 x2
| Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) ->
encode' ~uri params t0 x0 %>?
encode' ~uri params t1 x1 %>?
encode' ~uri params t2 x2 %>?
encode' ~uri params t3 x3
| Caqti_type.Custom {rep; encode = encode_custom; _}, x -> fun i ->
(match encode_custom x with
| Ok y ->
encode' ~uri params rep y i
| Error msg ->
let msg = Caqti_error.Msg msg in
Error (Caqti_error.encode_rejected ~uri ~typ:t msg)))

let encode ~uri params t x =
(match encode' ~uri params t x 0 with
| Ok n -> assert (n = Array.length params); Ok ()
| Error _ as r -> r)
end

module Param_encoder = Make_encoder (struct
let encode_string s = s
let encode_octets s = s
end)

let rec decode_field
: type a. uri: Uri.t -> a Caqti_type.field -> string -> (a, _) result =
Expand Down Expand Up @@ -296,72 +309,6 @@ let rec decode_field
Error (Caqti_error.decode_rejected ~uri ~typ msg))
| Error _ as r -> r)))

let rec encode_param
: type a. uri: Uri.t -> string array ->
a Caqti_type.t -> a -> int -> (int, _) result =
fun ~uri params t x ->
(match t, x with
| Caqti_type.Unit, () -> fun i -> Ok i
| Caqti_type.Field ft, fv -> fun i ->
(match encode_field ~uri ft fv with
| Ok s -> params.(i) <- s; Ok (i + 1)
| Error _ as r -> r)
| Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t)
| Caqti_type.Option t, Some x -> encode_param ~uri params t x
| Caqti_type.Tup2 (t0, t1), (x0, x1) ->
encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1
| Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) ->
encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1 %>?
encode_param ~uri params t2 x2
| Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) ->
encode_param ~uri params t0 x0 %>? encode_param ~uri params t1 x1 %>?
encode_param ~uri params t2 x2 %>? encode_param ~uri params t3 x3
| Caqti_type.Custom {rep; encode; _}, x -> fun i ->
(match encode x with
| Ok y -> encode_param ~uri params rep y i
| Error msg ->
let msg = Caqti_error.Msg msg in
Error (Caqti_error.encode_rejected ~uri ~typ:t msg)))

let rec copy_encode_param
: type a. uri: Uri.t ->
string array ->
bool array ->
(string -> string) ->
a Caqti_type.t ->
a ->
int ->
(int, _) result =
fun ~uri params binary_params binary_escape t x ->
(match t, x with
| Caqti_type.Unit, () -> fun i -> Ok i
| Caqti_type.Field ft, fv -> fun i ->
(match copy_encode_field ~uri ft binary_params.(i) binary_escape fv with
| Ok s -> params.(i) <- s; Ok (i + 1)
| Error _ as r -> r)
| Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t)
| Caqti_type.Option t, Some x ->
copy_encode_param ~uri params binary_params binary_escape t x
| Caqti_type.Tup2 (t0, t1), (x0, x1) ->
copy_encode_param ~uri params binary_params binary_escape t0 x0 %>?
copy_encode_param ~uri params binary_params binary_escape t1 x1
| Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) ->
copy_encode_param ~uri params binary_params binary_escape t0 x0 %>?
copy_encode_param ~uri params binary_params binary_escape t1 x1 %>?
copy_encode_param ~uri params binary_params binary_escape t2 x2
| Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) ->
copy_encode_param ~uri params binary_params binary_escape t0 x0 %>?
copy_encode_param ~uri params binary_params binary_escape t1 x1 %>?
copy_encode_param ~uri params binary_params binary_escape t2 x2 %>?
copy_encode_param ~uri params binary_params binary_escape t3 x3
| Caqti_type.Custom {rep; encode; _}, x -> fun i ->
(match encode x with
| Ok y ->
copy_encode_param ~uri params binary_params binary_escape rep y i
| Error msg ->
let msg = Caqti_error.Msg msg in
Error (Caqti_error.encode_rejected ~uri ~typ:t msg)))

let rec decode_row'
: type b. uri: Uri.t -> Pg.result * int ->
b Caqti_type.t -> int -> (int * b, _) result =
Expand Down Expand Up @@ -542,6 +489,12 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
struct
open Db

module Copy_encoder = Make_encoder (struct
let encode_string s =
"\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\""
let encode_octets s = db#escape_bytea s
end)

let using_db_ref = ref false
let using_db f = H.assert_single_use using_db_ref f
let in_transaction = ref false
Expand Down Expand Up @@ -695,9 +648,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
let nbp = set_binary_params binary_params param_type 0 in
assert (nbp = param_length);
let params = Array.make param_length Pg.null in
(match encode_param ~uri params param_type param 0 with
| Ok n ->
assert (n = param_length);
(match Param_encoder.encode ~uri params param_type param with
| Ok () ->
query_oneshot ~params ~binary_params query >|=? fun result ->
Ok (query, result)
| Error _ as r ->
Expand All @@ -715,9 +667,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
{query; param_length; binary_params}
in
let params = Array.make prepared.param_length Pg.null in
(match encode_param ~uri params param_type param 0 with
| Ok n ->
assert (n = prepared.param_length);
(match Param_encoder.encode ~uri params param_type param with
| Ok () ->
query_prepared query_id prepared params >|=? fun result ->
Ok (prepared.query, result)
| Error _ as r ->
Expand Down Expand Up @@ -774,7 +725,6 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
let binary_params = Array.make param_length false in
let nbp = set_binary_params binary_params row_type 0 in
assert (nbp = param_length);
let binary_escape = db#escape_bytea in
let fail msg =
return
(Error (Caqti_error.request_failed ~uri ~query (Caqti_error.Msg msg)))
Expand All @@ -799,19 +749,8 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
in
let copy_row row =
let params = Array.make param_length Pg.null in
let num_encoded_params =
copy_encode_param
~uri
params
binary_params
binary_escape
row_type
row
0
in
(match num_encoded_params with
| Ok n ->
assert (n = param_length);
(match Copy_encoder.encode ~uri params row_type row with
| Ok () ->
return (Ok (String.concat "," (Array.to_list params)))
| Error _ as r ->
return r)
Expand Down

0 comments on commit e9befd6

Please sign in to comment.