From daa0fc51f50e9ffbb8cf0be440618d8af92247cd Mon Sep 17 00:00:00 2001 From: Ryan Tjoa Date: Thu, 31 Oct 2024 21:12:44 -0400 Subject: [PATCH] benchmark training and sample from bdd --- Project.toml | 2 + qc/benchmarks/benchmarks.jl | 8 + qc/benchmarks/rbt_faster.jl | 152 ++++++++++++++++++ qc/benchmarks/stlc_faster_10samples.jl | 87 ++++++++++ qc/benchmarks/stlc_faster_200samples.jl | 85 +++++++++- .../stlc_faster_200samples_from_bdd.jl | 100 ++++++++++++ src/dist/inductive/inductive.jl | 29 ++++ 7 files changed, 458 insertions(+), 5 deletions(-) create mode 100644 qc/benchmarks/rbt_faster.jl create mode 100644 qc/benchmarks/stlc_faster_10samples.jl create mode 100644 qc/benchmarks/stlc_faster_200samples_from_bdd.jl diff --git a/Project.toml b/Project.toml index 215f69e5..3d2f7052 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] ADEV = "91c67158-5de4-465b-a572-6ca3a628f939" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDD = "345a2cc7-28d8-58b2-abdf-cff77ea7d7f1" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -18,6 +19,7 @@ Jive = "ba5e3d4b-8524-549f-bc71-e76ad9e9deed" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/qc/benchmarks/benchmarks.jl b/qc/benchmarks/benchmarks.jl index 6b222f4f..d4a62e8d 100644 --- a/qc/benchmarks/benchmarks.jl +++ b/qc/benchmarks/benchmarks.jl @@ -348,6 +348,14 @@ function wellTyped(e::OptExpr.t) None() -> false, ] end +function wellTyped(e::Expr.t) + @assert isdeterministic(e) + @match typecheck(e) [ + Some(_) -> true, + None() -> false, + ] +end + ################################## # Sampling STLC 
entropy loss diff --git a/qc/benchmarks/rbt_faster.jl b/qc/benchmarks/rbt_faster.jl new file mode 100644 index 00000000..1b673ad9 --- /dev/null +++ b/qc/benchmarks/rbt_faster.jl @@ -0,0 +1,152 @@ +using Revise +using Dice +include("benchmarks.jl") + +generation_params = LangSiblingDerivedGenerator{RBT}( + root_ty=ColorKVTree.t, + ty_sizes=[ColorKVTree.t=>4, Color.t=>0], + stack_size=2, + intwidth=3, +) + +SEED = 0 +out_dir="/tmp" +log_path="/dev/null" +rs = RunState(Valuation(), Dict{String,ADNode}(), open(log_path, "w"), out_dir, MersenneTwister(SEED), nothing,generation_params) + +generation::Generation = generate(rs, generation_params) + +g::Dist = generation.value + +# Assignments +# rs.var_vals + +# Distribution of constructors of root node: +pr_mixed(rs.var_vals)(g.union.which) + +# Sample some tree until it's valid (TODO: make this faster) +a = ADComputer(rs.var_vals) +isRBT(t) = satisfies_bookkeeping_invariant(t) && satisfies_balance_invariant(t) && satisfies_order_invariant(t) +using BenchmarkTools + +@benchmark begin + samples = [] + while length(samples) < 200 + some_tree = sample_as_dist(rs.rng, a, g) + if isRBT(some_tree) + push!(samples, some_tree) + end + end +end + +# one sample +# BenchmarkTools.Trial: 1683 samples with 1 evaluation. +# Range (min … max): 1.789 ms … 29.207 ms ┊ GC (min … max): 0.00% … 77.88% +# Time (median): 2.012 ms ┊ GC (median): 0.00% +# Time (mean ± σ): 2.895 ms ± 2.100 ms ┊ GC (mean ± σ): 4.64% ± 7.26% + +# █▇▅▃ ▃▅▃▂▃▁ ▁▂▂▁ +# █████▆▄▁▁▁██████▇▆▁▄▆████▇▇▇▅▇▄▇▇▅▆▆▅▅▇▄▄▅▅▅▆▄▆▄▁▁▄▄▅▁▄▄▁▄ █ +# 1.79 ms Histogram: log(frequency) by time 8.87 ms < + +# Memory estimate: 759.81 KiB, allocs estimate: 19182. + +# 200 samples +# BenchmarkTools.Trial: 9 samples with 1 evaluation. 
+# Range (min … max): 551.427 ms … 637.939 ms ┊ GC (min … max): 3.72% … 6.07%
+# Time (median): 571.511 ms ┊ GC (median): 3.63%
+# Time (mean ± σ): 577.534 ms ± 29.908 ms ┊ GC (mean ± σ): 4.56% ± 1.59%
+
+# █ ▃
+# █▁▁▁▁▇▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
+# 551 ms Histogram: frequency by time 638 ms <
+
+# Memory estimate: 207.46 MiB, allocs estimate: 5280362.
+
+# (`some_tree` is inspected further down, once the sampling loop has defined it)
+
+# .551 * 1000 / 60 ~= 9 minutes on sampling
+
+
+# Every other epoch, we spend 1/2 a second taking ~300 samples in order to get
+# exactly 200 samples that meet the spec.
+
+# "smart conditional sampling" saves at most 2/9 of runtime for RBT
+
+# time per epoch: ~.25
+
+retries = Ref(0)  # Ref, not Int: `retries += 1` in a top-level loop only works under REPL soft scope
+samples = []
+while length(samples) < 200
+    retries[] += 1
+    global some_tree = sample_as_dist(rs.rng, a, g)  # global: read again after the loop
+    if isRBT(some_tree)
+        push!(samples, some_tree)
+    end
+end
+
+retries[] # 321 samples taken
+some_tree  # the last accepted sample (the loop exits right after a push!)
+
+l = Dice.LogPrExpander(WMC(BDDCompiler([
+    prob_equals(g,sample)
+    for sample in samples
+])))
+
+num_meeting = 0
+@time begin
+    loss, actual_loss = sum(
+        begin
+            lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample)))
+            [lpr_eq * compute(a, lpr_eq), lpr_eq]
+        end
+        for sample in samples
+    )
+end
+# 1.74s on first run, ~.5 seconds on later runs
+
+length(l.cache) # 935
+
+# 0.165 seconds first time
+
+
+@benchmark vals, derivs = differentiate(
+    rs.var_vals,
+    Derivs([loss => 1.])
+)
+
+# BenchmarkTools.Trial: 1867 samples with 1 evaluation.
+# Range (min … max): 2.441 ms … 23.635 ms ┊ GC (min … max): 0.00% … 88.88%
+# Time (median): 2.544 ms ┊ GC (median): 0.00%
+# Time (mean ± σ): 2.634 ms ± 1.110 ms ┊ GC (mean ± σ): 2.54% ± 5.30%
+
+# ▁▁▁▅▆▆▆▇█▇▅▇▂▂
+# ▃▄▆████████████████▇▇▆▆▆▅▅▄▄▄▄▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▃▃▃▂▂▂▁▂▂▂▂ ▄
+# 2.44 ms Histogram: frequency by time 2.9 ms <
+
+# Memory estimate: 635.62 KiB, allocs estimate: 19618.
+ +ct = [0] +Dice.foreach_down(loss) do _ + ct[1] += 1 +end +ct # 1334 + +p_eq_g = prob_equals(some_tree, g) +to_maximize::Dice.ADNode = LogPr(p_eq_g) +using ProfileView + +pr_mixed(rs.var_vals)(p_eq_g) + +l = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([to_maximize])))) +to_maximize_expanded = Dice.expand_logprs(l, to_maximize) + +using ProfileView + +ProfileView.@profview begin + vals, derivs = Dice.differentiate( + rs.var_vals, + Derivs(to_maximize_expanded => 1.) + ) +end + diff --git a/qc/benchmarks/stlc_faster_10samples.jl b/qc/benchmarks/stlc_faster_10samples.jl new file mode 100644 index 00000000..91918f15 --- /dev/null +++ b/qc/benchmarks/stlc_faster_10samples.jl @@ -0,0 +1,87 @@ +using Revise +using Dice +using BenchmarkTools +using ProfileView + +include("benchmarks.jl") + +generation_params = LangSiblingDerivedGenerator{STLC}( + root_ty=Expr.t, + ty_sizes=[Expr.t=>5, Typ.t=>2], + stack_size=2, + intwidth=3, +) + +SEED = 0 +out_dir="/tmp" +log_path="/dev/null" +rs = RunState(Valuation(), Dict{String,ADNode}(), open(log_path, "w"), out_dir, MersenneTwister(SEED), nothing,generation_params) + +generation::Generation = generate(rs, generation_params) + +g::Dist = generation.value + +# Sample some tree until it's valid (TODO: make this faster) +a = ADComputer(rs.var_vals) + +NUM_SAMPLES = 10 + +function wellTyped(e) + @assert isdeterministic(e) + @match typecheck(e) [ + Some(_) -> true, + None() -> false, + ] +end + +retries = Ref(0) +#== @benchmark ==# @time begin + samples = [] + while length(samples) < NUM_SAMPLES + retries[] += 1 + s = sample_as_dist(rs.rng, a, g) + if wellTyped(s) + push!(samples, s) + end + end +end +# Single result which took 26.426 s (3.00% GC) to evaluate, (7s, 26s, 30s, 40s) +# with a memory estimate of 388.02 MiB, over 8512429 allocations. 
+retries[] # 30 + +l = Dice.LogPrExpander(WMC(BDDCompiler([ + prob_equals(g, sample) + for sample in samples +]))) +@time begin + loss, actual_loss = sum( + begin + lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample))) + [lpr_eq * compute(a, lpr_eq), lpr_eq] + end + for sample in samples + ) +end +# 5.3s first run, 1.4s rest + +length(l.cache) # 331 + +@benchmark vals, derivs = differentiate( + rs.var_vals, + Derivs([loss => 1.]) +) +# BenchmarkTools.Trial: 1060 samples with 1 evaluation. +# Range (min … max): 2.029 ms … 137.030 ms ┊ GC (min … max): 0.00% … 98.14% +# Time (median): 2.879 ms ┊ GC (median): 0.00% +# Time (mean ± σ): 4.377 ms ± 6.119 ms ┊ GC (mean ± σ): 4.36% ± 4.07% + +# ██▇▆▅▃▃▂▃▁▁▂▂ ▁ +# ██████████████████▅▇▆▄▆▄▆▇▄▆▄▅▁▆▁▅▇▁▄▄▁▁▁▄▁▁▅▄▆▁▄▄▁▁▁▄▁▁▁▅▅ █ +# 2.03 ms Histogram: log(frequency) by time 22.6 ms < + +# Memory estimate: 292.17 KiB, allocs estimate: 8034. + +ct = Ref(0) +Dice.foreach_down(loss) do _ ct[] += 1 end +ct[] # 350 + diff --git a/qc/benchmarks/stlc_faster_200samples.jl b/qc/benchmarks/stlc_faster_200samples.jl index 8bea9e10..b8fe1fb0 100644 --- a/qc/benchmarks/stlc_faster_200samples.jl +++ b/qc/benchmarks/stlc_faster_200samples.jl @@ -1,8 +1,21 @@ +# We found sampling from the BDD, like sampling from the computation graph, also took ~160s for 200 well-typed samples. 
+# OTHER TIMINGS IN THIS FILE ARE WRONG
+# mainly we care about `sample_one_as_dist_compile` in this file
+
 using Revise
 using Dice
 using BenchmarkTools
 using ProfileView
 
+function comp_graph_size(roots)
+    cmp_graph_ct = Ref(0)
+    Dice.foreach_down(roots) do _
+        cmp_graph_ct[] += 1
+    end
+    cmp_graph_ct[] # 2040
+end
+
+
 include("benchmarks.jl")
 
 generation_params = LangSiblingDerivedGenerator{STLC}(
@@ -26,6 +39,42 @@ a = ADComputer(rs.var_vals)
 
 NUM_SAMPLES = 200
 
+function sample_one_as_dist_compile(c::BDDCompiler, a::ADComputer, d::Dist, roots; rng=rs.rng)
+    # Draw one joint sample of `roots` by walking their BDDs; both memo tables below live for this single draw.
+    bdd_node_to_tf = Dict{CuddNode,Bool}()
+    level_to_tf = Dict{Integer, Bool}()
+    bdd_node_to_tf[Dice.constant(c.mgr, true)] = true
+    bdd_node_to_tf[Dice.constant(c.mgr, false)] = false
+
+    function sample_level(c, level::Integer)
+        get!(level_to_tf, level) do  # memoized: each level gets one outcome per draw
+            prob = compute(a, c.level_to_flip[level].prob)
+            rand(rng) < prob  # `rng` defaults to the script's seeded rs.rng (resolved at call time)
+        end
+    end
+
+    function sample_one(c, bdd_node_to_tf, x::AnyBool)
+        sample_one(c, bdd_node_to_tf, compile(c, x))
+    end
+
+    function sample_one(c, bdd_node_to_tf, x::CuddNode)
+        get!(bdd_node_to_tf, x) do  # terminals were pre-seeded above, so the recursion stops there
+            if sample_level(c, Dice.level(x))
+                sample_one(c, bdd_node_to_tf, Dice.high(x))
+            else
+                sample_one(c, bdd_node_to_tf, Dice.low(x))
+            end
+        end
+    end
+
+    bits = Dict()
+    for root in roots
+        bits[root] = sample_one(c, bdd_node_to_tf, root)
+    end
+    Dice.frombits_as_dist(d, bits)
+end
+
+
 function wellTyped(e)
     @assert isdeterministic(e)
     @match typecheck(e) [
@@ -37,22 +86,48 @@ end
 retries = Ref(0)
 #== @benchmark ==# @time begin
     samples = []
+    d = g
+    roots = Dice.tobits(d)
+    c = BDDCompiler(roots)
+    a = Dice.ADComputer(rs.var_vals)
     while length(samples) < NUM_SAMPLES
         retries[] += 1
-        s = sample_as_dist(rs.rng, a, g)
+        s = sample_one_as_dist_compile(c, a, d, roots)
         if wellTyped(s)
             push!(samples, s)
         end
     end
 end
-# 174s, 155s
+# 174s, 155s, 281
 retries[] # 607, 556
 
-@time l = Dice.LogPrExpander(WMC(BDDCompiler([
+@time eqs = [
     prob_equals(g, sample)
     for sample in samples
-])))
-#
32s, 36s +] +# 27 sec + +comp_graph_size(eqs) # 2040 +comp_graph_size(Dice.tobits(g)) # 16825 + +# @benchmark prob_equals(g, samples[1]) +# BenchmarkTools.Trial: 24 samples with 1 evaluation. +# Range (min … max): 104.068 ms … 592.294 ms ┊ GC (min … max): 0.00% … 0.00% +# Time (median): 168.251 ms ┊ GC (median): 0.00% +# Time (mean ± σ): 186.669 ms ± 97.022 ms ┊ GC (mean ± σ): 0.00% ± 0.00% +# ▃▃█ ▃▃▃ ▃▃ +# ▇███▁▁███▇██▁▇▇▁▁▇▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁ +# 104 ms Histogram: frequency by time 592 ms < +# Memory estimate: 6.51 MiB, allocs estimate: 235681. + +@time c = BDDCompiler(eqs) +# 0.28553 s, 0.23, 0.1 + +@time w = WMC(c) +# instant + +@time l = Dice.LogPrExpander(w) +# instant @time loss, actual_loss = sum( begin diff --git a/qc/benchmarks/stlc_faster_200samples_from_bdd.jl b/qc/benchmarks/stlc_faster_200samples_from_bdd.jl new file mode 100644 index 00000000..24f247ab --- /dev/null +++ b/qc/benchmarks/stlc_faster_200samples_from_bdd.jl @@ -0,0 +1,100 @@ +# We found sampling from the BDD, like sampling from the computation graph, also took ~160s for 200 well-typed samples. 
+# OTHER TIMINGS IN THIS FILE ARE WRONG
+# mainly we care about `sample_one_as_dist_compile` in this file
+
+using Revise
+using Dice
+using BenchmarkTools
+using ProfileView
+
+function comp_graph_size(roots)
+    cmp_graph_ct = Ref(0)
+    Dice.foreach_down(roots) do _
+        cmp_graph_ct[] += 1
+    end
+    cmp_graph_ct[] # 2040
+end
+
+
+include("benchmarks.jl")
+
+generation_params = LangSiblingDerivedGenerator{STLC}(
+    root_ty=Expr.t,
+    ty_sizes=[Expr.t=>5, Typ.t=>2],
+    stack_size=2,
+    intwidth=3,
+)
+
+SEED = 0
+out_dir="/tmp"
+log_path="/dev/null"
+rs = RunState(Valuation(), Dict{String,ADNode}(), open(log_path, "w"), out_dir, MersenneTwister(SEED), nothing,generation_params)
+
+generation::Generation = generate(rs, generation_params)
+
+g::Dist = generation.value
+
+# Sample some tree until it's valid (TODO: make this faster)
+a = ADComputer(rs.var_vals)
+
+NUM_SAMPLES = 200
+
+function sample_one_as_dist_compile(c::BDDCompiler, a::ADComputer, d::Dist, roots; rng=rs.rng)
+    # Draw one joint sample of `roots` by walking their BDDs; both memo tables below live for this single draw.
+    bdd_node_to_tf = Dict{CuddNode,Bool}()
+    level_to_tf = Dict{Integer, Bool}()
+    bdd_node_to_tf[Dice.constant(c.mgr, true)] = true
+    bdd_node_to_tf[Dice.constant(c.mgr, false)] = false
+
+    function sample_level(c, level::Integer)
+        get!(level_to_tf, level) do  # memoized: each level gets one outcome per draw
+            prob = compute(a, c.level_to_flip[level].prob)
+            rand(rng) < prob  # `rng` defaults to the script's seeded rs.rng (resolved at call time)
+        end
+    end
+
+    function sample_one(c, bdd_node_to_tf, x::AnyBool)
+        sample_one(c, bdd_node_to_tf, compile(c, x))
+    end
+
+    function sample_one(c, bdd_node_to_tf, x::CuddNode)
+        get!(bdd_node_to_tf, x) do  # terminals were pre-seeded above, so the recursion stops there
+            if sample_level(c, Dice.level(x))
+                sample_one(c, bdd_node_to_tf, Dice.high(x))
+            else
+                sample_one(c, bdd_node_to_tf, Dice.low(x))
+            end
+        end
+    end
+
+    bits = Dict()
+    for root in roots
+        bits[root] = sample_one(c, bdd_node_to_tf, root)
+    end
+    Dice.frombits_as_dist(d, bits)
+end
+
+
+function wellTyped(e)
+    @assert isdeterministic(e)
+    @match typecheck(e) [
+        Some(_) -> true,
+        None() -> false,
+    ]
+end
+
+retries = Ref(0)
+#== @benchmark ==# @time begin
+    samples =
[] + d = g + roots = Dice.tobits(d) + c = BDDCompiler(roots) + a = Dice.ADComputer(rs.var_vals) + while length(samples) < NUM_SAMPLES + retries[] += 1 + s = sample_one_as_dist_compile(c, a, d, roots) + if wellTyped(s) + push!(samples, s) + end + end +end \ No newline at end of file diff --git a/src/dist/inductive/inductive.jl b/src/dist/inductive/inductive.jl index 570dec69..32a42783 100644 --- a/src/dist/inductive/inductive.jl +++ b/src/dist/inductive/inductive.jl @@ -72,6 +72,35 @@ function Base.match(x::DistTaggedUnion, branches::Vector{Function}) res end +# type t = Left of int | Right of string + +# a = if f1 +# Left 5 +# else +# Right "a" + +# b = if f2 +# Right "b" +# else +# Left 5 + +# a = DistTaggedUnion( +# which=if f1 Left else Right, +# dists=[ +# [if f1 5 else ???] +# [if f1 ??? else "a"] +# ] +# ) + +# Cons, [????, [0, ...]] +# Nil, [[] , ????] + +# i <- [Nil, Cons] +# a <- [????, [0, ...]] +# b <- [[] , ????] + + + # Note: this requires that the "which" index of both unions are equal # (Left 1 != Right 1) function prob_equals(x::DistTaggedUnion, y::DistTaggedUnion)