unique curves
rtjoa committed Jan 3, 2025
1 parent ad52aa4 commit eee3019
Showing 3 changed files with 50 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -9,6 +9,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDD = "345a2cc7-28d8-58b2-abdf-cff77ea7d7f1"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438"
DirectedAcyclicGraphs = "1e6dae5e-d6e2-422d-9af3-452e7a3785ee"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
50 changes: 47 additions & 3 deletions qc/benchmarks/benchmarks.jl
@@ -4,6 +4,8 @@ include("lib/lib.jl")
using Plots
using Random
using Infiltrator
using DataStructures
using Debugger

ENV["GKSwstype"] = "100" # prevent plots from displaying

@@ -86,6 +88,7 @@ function run_benchmark(
end

if any(isinf(vals[loss]) || isnan(vals[loss]) for loss in losses)
println_loud(rs, "loss is NaN or Inf; stopping in epoch $(epoch)")
break
end

@@ -118,6 +121,18 @@ function run_benchmark(
end

if m isa FeatureSpecEntropyLossMgr
pres, posts = first(m.feature_unique_curve_history), last(m.feature_unique_curve_history)
name = "unique_curves_" * join(to_subpath(loss_config), "_")
open(joinpath(out_dir, "$(name).csv"), "w") do file
xs = 1:length(pres)
for (num_samples, pre, post) in zip(xs, pres, posts)
println(file, "$(num_samples)\t$(pre)\t$(post)")
end
plot(xs, pres, label="Initial", color=:blue, xlabel="Number of samples", ylabel="Count", title="STLC: Cumulative unique types during sampling", legend=:topright)
plot!(xs, posts, label="Trained", color=:red)
savefig(joinpath(out_dir, "$(name).svg"))
end
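# Sketch only, not part of this commit: the "$(name).csv" written above is a
# tab-separated table of (num_samples, initial_unique, trained_unique). Assuming
# the DelimitedFiles stdlib is loaded at top level, it can be read back for
# ad-hoc inspection like so:
unique_curve_table = DelimitedFiles.readdlm(joinpath(out_dir, "$(name).csv"), '\t')
initial_curve, trained_curve = unique_curve_table[:, 2], unique_curve_table[:, 3]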

d = []
for (k, cts) in m.feature_counts_history
ctor, args = k
@@ -419,8 +434,10 @@ mutable struct FeatureSpecEntropyLossMgr <: LossMgr
current_actual_loss::Union{Nothing,ADNode}
current_samples
feature_counts_history
feature_unique_curve_history
epochs_history # per resampling
num_resamples
FeatureSpecEntropyLossMgr(p, val, consider) = new(p, val, consider, [], nothing, nothing, nothing, Dict(), 0)
FeatureSpecEntropyLossMgr(p, val, consider) = new(p, val, consider, [], nothing, nothing, nothing, Dict(), [], [], 0)
end

function create_loss_manager(::RunState, p::FeatureSpecEntropy{T}, g::Generation) where T
@@ -432,13 +449,33 @@ function create_loss_manager(::RunState, p::FeatureSpecEntropy{T}, g::Generation
FeatureSpecEntropyLossMgr(p, g, consider)
end


function only_first_last!(v)
if length(v) > 2
a = first(v)
b = last(v)
empty!(v)
append!(v, [a, b])
end
end
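# Usage sketch, not part of this commit: only_first_last! trims a growing history
# down to its first and last entries, bounding memory across resamples.
history_demo = [[1, 2], [1, 2, 3], [1, 2, 3, 4]]
only_first_last!(history_demo)
@assert history_demo == [[1, 2], [1, 2, 3, 4]]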

function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer)
if (epoch - 1) % m.p.resampling_frequency == 0
sampler = sample_from_lang(rs, m.generation.prog)
a = ADComputer(rs.var_vals)
samples = [to_dist(sampler()) for _ in 1:m.p.samples_per_batch]

feature_counts = counter(map(m.p.feature, samples))
feature_unique_curve = []
# DefaultDict(0) returns 0 for unseen features, so no explicit haskey check is needed.
feature_counts = DefaultDict(0)
for s in samples
s_feature = m.p.feature(s)
feature_counts[s_feature] += 1
# Cumulative count of distinct feature values seen so far.
push!(feature_unique_curve, length(feature_counts))
end
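# Sketch only, not part of this commit: feature_unique_curve[i] is the number of
# distinct feature values among the first i samples, e.g. features [:a, :b, :a, :c]
# yield the curve [1, 2, 2, 3].
demo_seen = Set()
demo_curve = [length(push!(demo_seen, f)) for f in [:a, :b, :a, :c]]
@assert demo_curve == [1, 2, 2, 3]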

l = Dice.LogPrExpander(WMC(BDDCompiler([
prob_equals(m.generation.value,sample)
@@ -452,7 +489,10 @@ function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer

lpr_eq = LogPr(prob_equals(m.generation.value, sample))
lpr_eq = Dice.expand_logprs(l, lpr_eq)
empirical_feature_logpr = Dice.Constant(log(feature_counts[m.p.feature(sample)]/length(samples)))
ct = feature_counts[m.p.feature(sample)]
# `sample` is one of `samples`, so its feature was counted at least once.
@assert ct != 0
rat = ct/length(samples)
empirical_feature_logpr = Dice.Constant(log(rat))
if m.p.train_feature
[lpr_eq * empirical_feature_logpr, empirical_feature_logpr]
else
@@ -467,6 +507,7 @@ function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer

# loss = Dice.expand_logprs(l, loss) / length(samples)
loss = loss / length(samples)
actual_loss = actual_loss / length(samples)
m.current_loss = loss
m.current_actual_loss = actual_loss
m.current_samples = samples
@@ -479,6 +520,9 @@ function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer
for feature in keys(m.feature_counts_history)
push!(m.feature_counts_history[feature], feature_counts[feature])
end
push!(m.feature_unique_curve_history, feature_unique_curve)
only_first_last!(m.feature_unique_curve_history)
push!(m.epochs_history, epoch)
m.num_resamples += 1
end

4 changes: 2 additions & 2 deletions qc/benchmarks/tool.jl
@@ -27,8 +27,8 @@ if isempty(ARGS)
# l_p = [SpecEntropy{BST}(2,200,isBST)=>0.3]

# uniform type
# g_p = LangSiblingDerivedGenerator{STLC}(Main.Expr.t,Pair{Type,Integer}[Main.Expr.t=>5,Main.Typ.t=>2],2,3)
# l_p = [FeatureSpecEntropy{STLC}(2,200,wellTyped,typecheck_ft)=>0.3]
g_p = LangSiblingDerivedGenerator{STLC}(Main.Expr.t,Pair{Type,Integer}[Main.Expr.t=>5,Main.Typ.t=>2],2,3)
l_p = [FeatureSpecEntropy{STLC}(2,200,wellTyped,typecheck_ft,true)=>0.3]


push!(as, replace(string(g_p), " "=>""))