Skip to content

Commit

Permalink
Merge pull request #14396 from MinaProtocol/feature/pickles-js-caching
Browse files Browse the repository at this point in the history
Cache prover keys
  • Loading branch information
mitschabaude authored Oct 31, 2023
2 parents 28aef4f + 5aff374 commit fbd97fe
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/lib/crypto/proof-systems
89 changes: 50 additions & 39 deletions src/lib/pickles/cache.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ module Step = struct
[@@warning "-4"]
end

type storable =
(Key.Proving.t, Backend.Tick.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
( Key.Verification.t
, Kimchi_bindings.Protocol.VerifierIndex.Fp.t )
Key_cache.Sync.Disk_storable.t

let storable =
Key_cache.Sync.Disk_storable.simple Key.Proving.to_string
(fun (_, header, _, cs) ~path ->
Expand Down Expand Up @@ -83,9 +91,8 @@ module Step = struct
(Kimchi_bindings.Protocol.VerifierIndex.Fp.write (Some true) x)
header path ) )

let read_or_generate ~prev_challenges cache k_p k_v typ return_typ main =
let s_p = storable in
let s_v = vk_storable in
let read_or_generate ~prev_challenges cache ?(s_p = storable) k_p
?(s_v = vk_storable) k_v typ return_typ main =
let open Impls.Step in
let pk =
lazy
Expand Down Expand Up @@ -154,6 +161,12 @@ module Wrap = struct
end
end

type storable =
(Key.Proving.t, Backend.Tock.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
(Key.Verification.t, Verification_key.t) Key_cache.Sync.Disk_storable.t

let storable =
Key_cache.Sync.Disk_storable.simple Key.Proving.to_string
(fun (_, header, cs) ~path ->
Expand Down Expand Up @@ -181,10 +194,42 @@ module Wrap = struct
(Kimchi_bindings.Protocol.Index.Fq.write (Some true) t.index)
header path ) )

let read_or_generate ~prev_challenges cache k_p k_v typ return_typ main =
let vk_storable =
Key_cache.Sync.Disk_storable.simple Key.Verification.to_string
(fun (_, header, _cs) ~path ->
Or_error.try_with_join (fun () ->
let open Or_error.Let_syntax in
let%map header_read, index =
Snark_keys_header.read_with_header
~read_data:(fun ~offset:_ path ->
Binable.of_string
(module Verification_key.Stable.Latest)
(In_channel.read_all path) )
path
in
[%test_eq: int] header.header_version header_read.header_version ;
[%test_eq: Snark_keys_header.Kind.t] header.kind header_read.kind ;
[%test_eq: Snark_keys_header.Constraint_constants.t]
header.constraint_constants header_read.constraint_constants ;
[%test_eq: string] header.constraint_system_hash
header_read.constraint_system_hash ;
index ) )
(fun (_, header, _) t path ->
Or_error.try_with (fun () ->
Snark_keys_header.write_with_header
~expected_max_size_log2:33 (* 8 GB should be enough *)
~append_data:(fun path ->
Out_channel.with_file ~append:true path ~f:(fun file ->
Out_channel.output_string file
(Binable.to_string
(module Verification_key.Stable.Latest)
t ) ) )
header path ) )

let read_or_generate ~prev_challenges cache ?(s_p = storable) k_p
?(s_v = vk_storable) k_v typ return_typ main =
let module Vk = Verification_key in
let open Impls.Wrap in
let s_p = storable in
let pk =
lazy
(let k = Lazy.force k_p in
Expand All @@ -208,40 +253,6 @@ module Wrap = struct
let vk =
lazy
(let k_v = Lazy.force k_v in
let s_v =
Key_cache.Sync.Disk_storable.simple Key.Verification.to_string
(fun (_, header, _cs) ~path ->
Or_error.try_with_join (fun () ->
let open Or_error.Let_syntax in
let%map header_read, index =
Snark_keys_header.read_with_header
~read_data:(fun ~offset:_ path ->
Binable.of_string
(module Vk.Stable.Latest)
(In_channel.read_all path) )
path
in
[%test_eq: int] header.header_version
header_read.header_version ;
[%test_eq: Snark_keys_header.Kind.t] header.kind
header_read.kind ;
[%test_eq: Snark_keys_header.Constraint_constants.t]
header.constraint_constants
header_read.constraint_constants ;
[%test_eq: string] header.constraint_system_hash
header_read.constraint_system_hash ;
index ) )
(fun (_, header, _) t path ->
Or_error.try_with (fun () ->
Snark_keys_header.write_with_header
~expected_max_size_log2:33 (* 8 GB should be enough *)
~append_data:(fun path ->
Out_channel.with_file ~append:true path ~f:(fun file ->
Out_channel.output_string file
(Binable.to_string (module Vk.Stable.Latest) t) )
)
header path ) )
in
match Key_cache.Sync.read cache s_v k_v with
| Ok (vk, d) ->
(vk, d)
Expand Down
36 changes: 34 additions & 2 deletions src/lib/pickles/cache.mli
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ module Step : sig
* Snark_keys_header.t
* int
* Backend.Tick.R1CS_constraint_system.t

val to_string : t -> string
end

module Verification : sig
Expand All @@ -17,13 +19,29 @@ module Step : sig
* Snark_keys_header.t
* int
* Core_kernel.Md5.t

val to_string : t -> string
end
end

type storable =
(Key.Proving.t, Backend.Tick.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
( Key.Verification.t
, Kimchi_bindings.Protocol.VerifierIndex.Fp.t )
Key_cache.Sync.Disk_storable.t

val storable : storable

val vk_storable : vk_storable

val read_or_generate :
prev_challenges:int
-> Key_cache.Spec.t list
-> ?s_p:storable
-> Key.Proving.t lazy_t
-> ?s_v:vk_storable
-> Key.Verification.t lazy_t
-> ('a, 'b) Impls.Step.Typ.t
-> ('c, 'd) Impls.Step.Typ.t
Expand All @@ -43,6 +61,8 @@ module Wrap : sig
Core_kernel.Type_equal.Id.Uid.t
* Snark_keys_header.t
* Backend.Tock.R1CS_constraint_system.t

val to_string : t -> string
end

module Verification : sig
Expand All @@ -59,11 +79,23 @@ module Wrap : sig
end
end

type storable =
(Key.Proving.t, Backend.Tock.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
(Key.Verification.t, Verification_key.t) Key_cache.Sync.Disk_storable.t

val storable : storable

val vk_storable : vk_storable

val read_or_generate :
prev_challenges:Core_kernel.Int.t
-> Key_cache.Spec.t list
-> Key.Proving.t Core_kernel.Lazy.t
-> Key.Verification.t Core_kernel.Lazy.t
-> ?s_p:storable
-> Key.Proving.t lazy_t
-> ?s_v:vk_storable
-> Key.Verification.t lazy_t
-> ('a, 'b) Impls.Wrap.Typ.t
-> ('c, 'd) Impls.Wrap.Typ.t
-> ('a -> unit -> 'c)
Expand Down
40 changes: 31 additions & 9 deletions src/lib/pickles/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,22 @@ type ('max_proofs_verified, 'branches, 'prev_varss) wrap_main_generic =
*)
}

module Storables = struct
type t =
{ step_storable : Cache.Step.storable
; step_vk_storable : Cache.Step.vk_storable
; wrap_storable : Cache.Wrap.storable
; wrap_vk_storable : Cache.Wrap.vk_storable
}

let default =
{ step_storable = Cache.Step.storable
; step_vk_storable = Cache.Step.vk_storable
; wrap_storable = Cache.Wrap.storable
; wrap_vk_storable = Cache.Wrap.vk_storable
}
end

module Make
(Arg_var : Statement_var_intf)
(Arg_value : Statement_value_intf)
Expand Down Expand Up @@ -340,6 +356,7 @@ struct
type var value prev_varss prev_valuess widthss heightss max_proofs_verified branches.
self:(var, value, max_proofs_verified, branches) Tag.t
-> cache:Key_cache.Spec.t list
-> storables:Storables.t
-> proof_cache:Proof_cache.t option
-> ?disk_keys:
(Cache.Step.Key.Verification.t, branches) Vector.t
Expand Down Expand Up @@ -378,10 +395,13 @@ struct
* _
* _
* _ =
fun ~self ~cache ~proof_cache ?disk_keys
?(return_early_digest_exception = false) ?override_wrap_domain
?override_wrap_main ~branches:(module Branches) ~max_proofs_verified
~name ~constraint_constants ~public_input ~auxiliary_typ ~choices () ->
fun ~self ~cache
~storables:
{ step_storable; step_vk_storable; wrap_storable; wrap_vk_storable }
~proof_cache ?disk_keys ?(return_early_digest_exception = false)
?override_wrap_domain ?override_wrap_main ~branches:(module Branches)
~max_proofs_verified ~name ~constraint_constants ~public_input
~auxiliary_typ ~choices () ->
let snark_keys_header kind constraint_system_hash =
{ Snark_keys_header.header_version = Snark_keys_header.header_version
; kind
Expand Down Expand Up @@ -595,7 +615,7 @@ struct
Common.time "step read or generate" (fun () ->
Cache.Step.read_or_generate
~prev_challenges:(Nat.to_int (fst b.proofs_verified))
cache k_p k_v
cache ~s_p:step_storable k_p ~s_v:step_vk_storable k_v
(Snarky_backendless.Typ.unit ())
typ main )
in
Expand Down Expand Up @@ -671,7 +691,8 @@ struct
let r =
Common.time "wrap read or generate " (fun () ->
Cache.Wrap.read_or_generate (* Due to Wrap_hack *)
~prev_challenges:2 cache disk_key_prover disk_key_verifier typ
~prev_challenges:2 cache ~s_p:wrap_storable disk_key_prover
~s_v:wrap_vk_storable disk_key_verifier typ
(Snarky_backendless.Typ.unit ())
main )
in
Expand Down Expand Up @@ -938,6 +959,7 @@ let compile_with_wrap_main_override_promise :
type var value a_var a_value ret_var ret_value auxiliary_var auxiliary_value prev_varss prev_valuess widthss heightss max_proofs_verified branches.
?self:(var, value, max_proofs_verified, branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, branches) Vector.t
Expand Down Expand Up @@ -991,8 +1013,8 @@ let compile_with_wrap_main_override_promise :
(* This function is an adapter between the user-facing Pickles.compile API
and the underlying Make(_).compile function which builds the circuits.
*)
fun ?self ?(cache = []) ?proof_cache ?disk_keys
?(return_early_digest_exception = false) ?override_wrap_domain
fun ?self ?(cache = []) ?(storables = Storables.default) ?proof_cache
?disk_keys ?(return_early_digest_exception = false) ?override_wrap_domain
?override_wrap_main ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices () ->
let self =
Expand Down Expand Up @@ -1061,7 +1083,7 @@ let compile_with_wrap_main_override_promise :
in
let provers, wrap_vk, wrap_disk_key, cache_handle =
M.compile ~return_early_digest_exception ~self ~proof_cache ~cache
?disk_keys ?override_wrap_domain ?override_wrap_main ~branches
~storables ?disk_keys ?override_wrap_domain ?override_wrap_main ~branches
~max_proofs_verified ~name ~public_input ~auxiliary_typ
~constraint_constants
~choices:(fun ~self -> conv_irs (choices ~self))
Expand Down
12 changes: 12 additions & 0 deletions src/lib/pickles/compile.mli
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,24 @@ type ('max_proofs_verified, 'branches, 'prev_varss) wrap_main_generic =
*)
}

module Storables : sig
type t =
{ step_storable : Cache.Step.storable
; step_vk_storable : Cache.Step.vk_storable
; wrap_storable : Cache.Wrap.storable
; wrap_vk_storable : Cache.Wrap.vk_storable
}

val default : t
end

(** This compiles a series of inductive rules defining a set into a proof
system for proving membership in that set, with a prover corresponding
to each inductive rule. *)
val compile_with_wrap_main_override_promise :
?self:('var, 'value, 'max_proofs_verified, 'branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, 'branches) Vector.t
Expand Down
24 changes: 13 additions & 11 deletions src/lib/pickles/pickles.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ module Make_str (_ : Wire_types.Concrete) = struct
module Step_main_inputs = Step_main_inputs
module Step_verifier = Step_verifier
module Proof_cache = Proof_cache
module Cache = Cache
module Storables = Compile.Storables

exception Return_digest = Compile.Return_digest

Expand Down Expand Up @@ -306,22 +308,22 @@ module Make_str (_ : Wire_types.Concrete) = struct
let compile_with_wrap_main_override_promise =
Compile.compile_with_wrap_main_override_promise

let compile_promise ?self ?cache ?proof_cache ?disk_keys
let compile_promise ?self ?cache ?storables ?proof_cache ?disk_keys
?return_early_digest_exception ?override_wrap_domain ~public_input
~auxiliary_typ ~branches ~max_proofs_verified ~name ~constraint_constants
~choices () =
compile_with_wrap_main_override_promise ?self ?cache ?proof_cache ?disk_keys
?return_early_digest_exception ?override_wrap_domain ~public_input
~auxiliary_typ ~branches ~max_proofs_verified ~name ~constraint_constants
~choices ()

let compile ?self ?cache ?proof_cache ?disk_keys ?override_wrap_domain
compile_with_wrap_main_override_promise ?self ?cache ?storables ?proof_cache
?disk_keys ?return_early_digest_exception ?override_wrap_domain
~public_input ~auxiliary_typ ~branches ~max_proofs_verified ~name
~constraint_constants ~choices () =
~constraint_constants ~choices ()

let compile ?self ?cache ?storables ?proof_cache ?disk_keys
?override_wrap_domain ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices () =
let self, cache_handle, proof_module, provers =
compile_promise ?self ?cache ?proof_cache ?disk_keys ?override_wrap_domain
~public_input ~auxiliary_typ ~branches ~max_proofs_verified ~name
~constraint_constants ~choices ()
compile_promise ?self ?cache ?storables ?proof_cache ?disk_keys
?override_wrap_domain ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices ()
in
let rec adjust_provers :
type a1 a2 a3 s1 s2_inner.
Expand Down
Loading

0 comments on commit fbd97fe

Please sign in to comment.