diff --git a/opam b/opam index 0a35e28..d5a855a 100644 --- a/opam +++ b/opam @@ -11,7 +11,7 @@ remove: [ ] depends: [ "ocamlfind" - "pgocaml" {< "2.0"} + "pgocaml" {>= "2.0"} "oasis" {>= "0.4.4"} "camlp4" ] diff --git a/src/sql.mli b/src/sql.mli index b1fe06f..39824bd 100644 --- a/src/sql.mli +++ b/src/sql.mli @@ -35,21 +35,26 @@ val non_nullable_witness : non_nullable nul_witness class type ['t] type_info = object method typ : 't end class type numeric_t = object method numeric : unit end +class type arrayable_t = object method arrayable : unit end + +class type ['t] array_t = object + constraint 't = < typ : 'ty; arrayable : unit; .. > + inherit ['ty option list] type_info +end class type unit_t = object inherit [unit] type_info end -class type bool_t = object inherit [bool] type_info end +class type bool_t = object inherit [bool] type_info inherit arrayable_t end class type int16_t = object inherit [int16] type_info inherit numeric_t end -class type int32_t = object inherit [int32] type_info inherit numeric_t end -class type int64_t = object inherit [int64] type_info inherit numeric_t end -class type float_t = object inherit [float] type_info inherit numeric_t end -class type string_t = object inherit [string] type_info end +class type int32_t = object inherit [int32] type_info inherit numeric_t inherit arrayable_t end +class type int64_t = object inherit [int64] type_info inherit numeric_t inherit arrayable_t end +class type float_t = object inherit [float] type_info inherit numeric_t inherit arrayable_t end +class type string_t = object inherit [string] type_info inherit arrayable_t end class type bytea_t = object inherit [bytea] type_info end class type time_t = object inherit [time] type_info end class type date_t = object inherit [date] type_info end class type timestamp_t = object inherit [timestamp] type_info end class type timestamptz_t = object inherit [timestamptz] type_info end class type interval_t = object inherit [interval] type_info end -class type int32_array_t = object inherit [int32 array] type_info end class type ['row] row_t = object inherit ['row] type_info end @@ -211,8 +216,16 @@ module Table_type : sig < get : unit; nul : 'nul; t : timestamptz_t > sql_type val interval : 'nul nul_witness -> < get : unit; nul : 'nul; t : interval_t > sql_type + val bool_array : 'nul nul_witness -> + < get : unit; nul : 'nul; t : bool_t array_t > sql_type val int32_array : 'nul nul_witness -> - < get : unit; nul : 'nul; t : int32_array_t > sql_type + < get : unit; nul : 'nul; t : int32_t array_t > sql_type + val int64_array : 'nul nul_witness -> + < get : unit; nul : 'nul; t : int64_t array_t > sql_type + val float_array : 'nul nul_witness -> + < get : unit; nul : 'nul; t : float_t array_t > sql_type + val string_array : 'nul nul_witness -> + < get : unit; nul : 'nul; t : string_t array_t > sql_type end (** final query building *) @@ -255,7 +268,11 @@ module Value : sig val timestamp : timestamp -> < t : timestamp_t; nul : _ > t val timestamptz : timestamptz -> < t : timestamptz_t; nul : _ > t val interval : interval -> < t : interval_t; nul : _ > t - val int32_array : int32 array -> < t : int32_array_t; nul : _ > t + val bool_array : bool option list -> < t : bool_t array_t; nul : _ > t + val int32_array : int32 option list -> < t : int32_t array_t; nul : _ > t + val int64_array : int64 option list -> < t : int64_t array_t; nul : _ > t + val float_array : float option list -> < t : float_t array_t; nul : _ > t + val string_array : string option list -> < t : string_t array_t; nul : _ > t end @@ -331,6 +348,8 @@ module Op : sig < t : #numeric_t as 't; nul : 'n; .. > group -> < t : 't; nul : nullable > t val md5 : < t : string_t; nul : 'n; .. > group -> < t : string_t; nul : 'n > t + val array_agg : + < t : #arrayable_t as 't; .. > group -> < t : 't array_t; nul : nullable > t (** sequence functions *) val nextval : 'a sequence -> < t : 'a; nul : non_nullable > t diff --git a/src/sql_base.ml b/src/sql_base.ml index 6dc8535..2a8692f 100644 --- a/src/sql_base.ml +++ b/src/sql_base.ml @@ -34,6 +34,10 @@ and bytea = PGOCaml.bytea and time = CalendarLib.Time.t and date = CalendarLib.Date.t and interval = CalendarLib.Calendar.Period.t -and timestamp = CalendarLib.Calendar.t +and timestamp = CalendarLib.Calendar.t and timestamptz = PGOCaml.timestamptz (* = CalendarLib.Calendar.t * CalendarLib.Time_Zone.t *) +and bool_array = PGOCaml.bool_array and int32_array = PGOCaml.int32_array +and int64_array = PGOCaml.int64_array +and float_array = PGOCaml.float_array +and string_array = PGOCaml.string_array diff --git a/src/sql_internals.ml b/src/sql_internals.ml index 3b447c6..bb99b51 100644 --- a/src/sql_internals.ml +++ b/src/sql_internals.ml @@ -74,7 +74,11 @@ and atom = | Timestamp of timestamp | Timestamptz of timestamptz | Interval of interval - | Int32_array of int32 array + | Bool_array of bool option list + | Int32_array of int32 option list + | Int64_array of int64 option list + | Float_array of float option list + | String_array of string option list | Record of untyped (* runtime object instance *) and table_name = string option * string and row_name = string @@ -97,7 +101,7 @@ and atom_type = | TTimestamp | TTimestamptz | TInterval - | TInt32_array + | TArray of atom_type | TRecord of unit generic_view and 'a record_parser = descr -> 'a result_parser @@ -128,7 +132,11 @@ let atom_type_of_string = function | "timestamp" -> TTimestamp | "timestamptz" -> TTimestamptz | "interval" -> TInterval - | "int32_array" -> TInt32_array + | "bool_array" -> TArray TBool + | "int32_array" -> TArray TInt32 + | "int64_array" -> TArray TInt64 + | "float_array" -> TArray TFloat + | "string_array" -> TArray TString | other -> failwith ("unknown sql type " ^ other) let string_of_atom_type = function | TUnit -> "unit" @@ -144,7 +152,12 @@ let string_of_atom_type = function | TTimestamp -> "timestamp" | TTimestamptz -> "timestamptz" | TInterval -> "interval" - | TInt32_array -> "int32_array" + | TArray TBool -> "boolean[]" + | TArray TInt32 -> "integer[]" + | TArray TInt64 -> "bigint[]" + | TArray TFloat -> "double precision[]" + | TArray TString -> "text[]" + | TArray _ -> assert false | TRecord _ -> "record" type query = @@ -176,7 +189,7 @@ let rec unify t t' = let unify_atom a a' = match a, a' with (* identity unifications *) | ( TUnit | TBool | TInt16 | TInt32 | TInt64 | TFloat - | TString | TBytea | TTime | TDate | TInterval | TInt32_array + | TString | TBytea | TTime | TDate | TInterval | TArray _ | TTimestamp | TTimestamptz) as t, t' when t = t' -> t | TRecord r, TRecord r' -> let fields descr = List.sort compare (List.map fst descr) in @@ -191,7 +204,7 @@ let rec unify t t' = (* failure *) | ( TUnit | TBool | TInt16 | TInt32 | TInt64 | TFloat - | TString | TBytea | TTime | TDate | TInterval | TInt32_array + | TString | TBytea | TTime | TDate | TInterval | TArray _ | TTimestamp | TTimestamptz | TRecord _), _ -> failwith (Printf.sprintf "unify (%s and %s)" diff --git a/src/sql_parsers.ml b/src/sql_parsers.ml index faa98dd..9bcfcb1 100644 --- a/src/sql_parsers.ml +++ b/src/sql_parsers.ml @@ -45,8 +45,16 @@ let timestamptzval_of_string s = pack (Timestamptz (PGOCaml.timestamptz_of_string s)) TTimestamptz let intervalval_of_string s = pack (Interval (PGOCaml.interval_of_string s)) TInterval +let bool_array_of_string s = + pack (Bool_array (PGOCaml.bool_array_of_string s)) (TArray TBool) let int32_array_of_string s = - pack (Int32_array (PGOCaml.int32_array_of_string s)) TInt32_array + pack (Int32_array (PGOCaml.int32_array_of_string s)) (TArray TInt32) +let int64_array_of_string s = + pack (Int64_array (PGOCaml.int64_array_of_string s)) (TArray TInt64) +let float_array_of_string s = + pack (Float_array (PGOCaml.float_array_of_string s)) (TArray TFloat) +let string_array_of_string s = + pack (String_array (PGOCaml.string_array_of_string s)) (TArray TString) let unit_field_parser = unsafe_parser (incr &&& unitval_of_string) let bool_field_parser = unsafe_parser (incr &&& boolval_of_string) @@ -61,7 +69,11 @@ let date_field_parser = unsafe_parser (incr &&& dateval_of_string) let timestamp_field_parser = unsafe_parser (incr &&& timestampval_of_string) let timestamptz_field_parser = unsafe_parser (incr &&& timestamptzval_of_string) let interval_field_parser = unsafe_parser (incr &&& intervalval_of_string) +let bool_array_field_parser = unsafe_parser (incr &&& bool_array_of_string) let int32_array_field_parser = unsafe_parser (incr &&& int32_array_of_string) +let int64_array_field_parser = unsafe_parser (incr &&& int64_array_of_string) +let float_array_field_parser = unsafe_parser (incr &&& float_array_of_string) +let string_array_field_parser = unsafe_parser (incr &&& string_array_of_string) let error_field_parser= unsafe_parser (ignore &&& (fun _ -> failwith "Error parser")) @@ -101,7 +113,12 @@ let parser_of_type = | TTimestamp -> timestamp_field_parser | TTimestamptz -> timestamptz_field_parser | TInterval -> interval_field_parser - | TInt32_array -> int32_array_field_parser + | TArray TBool -> bool_array_field_parser + | TArray TInt32 -> int32_array_field_parser + | TArray TInt64 -> int64_array_field_parser + | TArray TFloat -> float_array_field_parser + | TArray TString -> string_array_field_parser + | TArray _ -> assert false | TRecord t -> record_parser t in function | Non_nullable typ -> parser_of_sql_type typ diff --git a/src/sql_printers.ml b/src/sql_printers.ml index 42e477e..c2b48a1 100644 --- a/src/sql_printers.ml +++ b/src/sql_printers.ml @@ -181,7 +181,11 @@ and string_of_atom = | Timestamp i -> quote PGOCaml.string_of_timestamp i | Timestamptz i -> quote PGOCaml.string_of_timestamptz i | Interval i -> quote PGOCaml.string_of_interval i + | Bool_array js -> quote PGOCaml.string_of_bool_array js | Int32_array js -> quote PGOCaml.string_of_int32_array js + | Int64_array js -> quote PGOCaml.string_of_int64_array js + | Float_array js -> quote PGOCaml.string_of_float_array js + | String_array js -> quote PGOCaml.string_of_string_array js | Record t -> (* all records should have been expanded, that's the !atom-records flatten postcondition *) diff --git a/src/sql_public.ml b/src/sql_public.ml index 393be38..c875393 100644 --- a/src/sql_public.ml +++ b/src/sql_public.ml @@ -43,7 +43,11 @@ module Value = struct let timestamp i = Atom (Timestamp i), Non_nullable TTimestamp let timestamptz i = Atom (Timestamptz i), Non_nullable TTimestamptz let interval i = Atom (Interval i), Non_nullable TInterval - let int32_array js = Atom (Int32_array js), Non_nullable TInt32_array + let bool_array js = Atom (Bool_array js), Non_nullable (TArray TBool) + let int32_array js = Atom (Int32_array js), Non_nullable (TArray TInt32) + let int64_array js = Atom (Int64_array js), Non_nullable (TArray TInt64) + let float_array js = Atom (Float_array js), Non_nullable (TArray TFloat) + let string_array js = Atom (String_array js), Non_nullable (TArray TString) end type 'a sequence = string * atom_type @@ -124,6 +128,13 @@ module Op = struct let max (v, t) = nullable (prefixop "max" (v, t), t) let sum (v, t) = nullable (prefixop "sum" (v, t), t) let md5 (v, t) = prefixop "md5" (v, t), t + let array_agg (v, t) = + let to_array = function + | Nullable None -> Nullable None + | Nullable (Some t) + | Non_nullable t -> Nullable (Some (TArray t)) + in + prefixop "array_agg" (v, t), to_array t let label seq_name = Atom (String seq_name), Non_nullable TString let nextval (seq_name, typ) = @@ -156,7 +167,11 @@ module Table_type = struct let timestamp = _type TTimestamp let timestamptz = _type TTimestamptz let interval = _type TInterval - let int32_array = _type TInt32_array + let bool_array = _type (TArray TBool) + let int32_array = _type (TArray TInt32) + let int64_array = _type (TArray TInt64) + let float_array = _type (TArray TFloat) + let string_array = _type (TArray TString) end module View = struct diff --git a/src/sql_types.ml b/src/sql_types.ml index aa95d1d..97ba9cd 100644 --- a/src/sql_types.ml +++ b/src/sql_types.ml @@ -29,21 +29,26 @@ type non_nullable class type ['t] type_info = object method typ : 't end class type numeric_t = object method numeric : unit end +class type arrayable_t = object method arrayable : unit end + +class type ['t] array_t = object + constraint 't = < typ : 'ty; arrayable : unit; .. > + inherit ['ty option list] type_info +end class type unit_t = object inherit [unit] type_info end -class type bool_t = object inherit [bool] type_info end +class type bool_t = object inherit [bool] type_info inherit arrayable_t end class type int16_t = object inherit [int16] type_info inherit numeric_t end -class type int32_t = object inherit [int32] type_info inherit numeric_t end -class type int64_t = object inherit [int64] type_info inherit numeric_t end -class type float_t = object inherit [float] type_info inherit numeric_t end -class type string_t = object inherit [string] type_info end +class type int32_t = object inherit [int32] type_info inherit numeric_t inherit arrayable_t end +class type int64_t = object inherit [int64] type_info inherit numeric_t inherit arrayable_t end +class type float_t = object inherit [float] type_info inherit numeric_t inherit arrayable_t end +class type string_t = object inherit [string] type_info inherit arrayable_t end class type bytea_t = object inherit [bytea] type_info end class type time_t = object inherit [time] type_info end class type date_t = object inherit [date] type_info end class type timestamp_t = object inherit [timestamp] type_info end class type timestamptz_t = object inherit [timestamptz] type_info end class type interval_t = object inherit [interval] type_info end -class type int32_array_t = object inherit [int32 array] type_info end class type ['row] row_t = object inherit ['row] type_info end @@ -109,7 +114,11 @@ let get_val : < get : _; t : 'a #type_info; .. > atom -> 'a = | SQLI.Timestamp t -> !?t | SQLI.Timestamptz t -> !?t | SQLI.Interval i -> !?i + | SQLI.Bool_array js -> !?js | SQLI.Int32_array js -> !?js + | SQLI.Int64_array js -> !?js + | SQLI.Float_array js -> !?js + | SQLI.String_array js -> !?js | SQLI.Record o -> !?o let get ((r, t) : 'a t) =