train generator w entropy
rtjoa committed Dec 27, 2023
1 parent 7cd0563 commit ae19e01
Showing 34 changed files with 10,695 additions and 5 deletions.
examples/qc/stlc/entropy.jl (170 additions, 0 deletions)
@@ -0,0 +1,170 @@
# Train an STLC generator to maximize the entropy of the constructor of a
# randomly chosen node in each generated term.
#
# Saves sampled terms before and after training.

using Dice
include("lib/dist.jl")
include("lib/util.jl")
include("lib/generator.jl")

############################
# Config
############################

# Specify generator and size parameters
METRIC = "entropy"
INIT_SIZE = 2 # size passed to top call of gen_expr
GEN_TYP_SIZE = 1 # size passed to all calls of gen_type

# Hyperparams
PARAMETERIZE_FLIP_GROUPS_BY_SZ = true # whether flips at the same code location
# but different sizes can have different
# probabilities
EPOCHS = 2000 # epochs to train for

LOG_TO_FILE = true

TAG = "v3_opt_meta_ad"

############################

# Corresponds to the "problem": the generator being trained and the objective.
# Data within a directory is plotted on the same graph.
OUT_DIR = "examples/qc/stlc/output/$(TAG)/$(METRIC)/sz=$(INIT_SIZE),tysz=$(GEN_TYP_SIZE)"

# Hyperparams
OUT_FILE_TAG = "param_by_sz=$(PARAMETERIZE_FLIP_GROUPS_BY_SZ),epochs=$(EPOCHS)"

############################
# Intro
############################

LOG_PATH = joinpath(OUT_DIR, "log_" * OUT_FILE_TAG * ".log")
LEARNING_CURVE_PATH = joinpath(OUT_DIR, "learning_curve_" * OUT_FILE_TAG * ".csv")

mkpath(OUT_DIR)
io = if LOG_TO_FILE
    open(LOG_PATH, "w")
else
    stdout
end

using Dates
t = now()
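# Print the config to both the log file and stdout (the Set avoids printing
# twice when io === stdout).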
for io′ in Set([io, stdout])
    println(io′, t)
    println(io′, "== Config ==")
    println(io′, "INIT_SIZE: $(INIT_SIZE)")
    println(io′, "GEN_TYP_SIZE: $(GEN_TYP_SIZE)")
    println(io′, "PARAMETERIZE_FLIP_GROUPS_BY_SZ: $(PARAMETERIZE_FLIP_GROUPS_BY_SZ)")
    println(io′, "EPOCHS: $(EPOCHS)")
    println(io′, "DistNat: $(DistNat)")
    println(io′, "TAG: $(TAG)")
    println(io′)
end
if LOG_TO_FILE
    println("Logging to $(LOG_PATH)")
    println()
end

var_vals = Valuation()
adnodes_of_interest = Dict{String, ADNode}()
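# Register a trainable weight: a fresh real-valued variable (initialized to 0)
# squashed through a sigmoid, so the corresponding flip probability starts at 0.5.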
function register_weight!(s)
    var = Var("$(s)_before_sigmoid")
    var_vals[var] = 0
    weight = sigmoid(var)
    adnodes_of_interest[s] = weight
    weight
end

println_flush(io, "Building (gen_expr(...)) computation graph...")
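# `e` is a distribution over optional STLC terms; it is None when generation fails.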
time_build = @elapsed begin
    e = gen_expr(
        DistNil(DistI{Typ}),
        gen_type(GEN_TYP_SIZE, PARAMETERIZE_FLIP_GROUPS_BY_SZ),
        INIT_SIZE,
        GEN_TYP_SIZE,
        PARAMETERIZE_FLIP_GROUPS_BY_SZ
    )
end
println(io, " $(time_build) seconds")
println(io)


############################
# Before
############################

println(io, "Initial adnodes_of_interest:")
vals = compute(var_vals, values(adnodes_of_interest))
show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)


println_flush(io, "Saving samples...")
time_sample_init = @elapsed with_concrete_ad_flips(var_vals, e) do
    save_samples(joinpath(OUT_DIR, "terms_before.txt"), e; io=io)
end
println(io, " $(time_sample_init) seconds")
println(io)

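# Numeric IDs for each Expr constructor, used as elements of the collected vector.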
to_id = Dict(
    "Var" => DistUInt32(1),
    "Boolean" => DistUInt32(2),
    "App" => DistUInt32(3),
    "Abs" => DistUInt32(4),
)

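# Collect the constructor ID of every node in the term into a DistVector
# (children first, then the node's own constructor).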
function collect_terminals(e)
    match(e, [
        "Var" => (i) -> DistVector([to_id["Var"]]),
        "Boolean" => (b) -> DistVector([to_id["Boolean"]]),
        "App" => (f, x) -> prob_append(prob_extend(collect_terminals(f), collect_terminals(x)), to_id["App"]),
        "Abs" => (ty, e′) -> prob_append(collect_terminals(e′), to_id["Abs"]),
    ])
end
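# Pick a random constructor occurrence from the generated term (None if
# generation failed).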
random_term = match(e, [
    "None" => () -> DistNone(DistUInt32),
    "Some" => e -> DistSome(choice(collect_terminals(e)))
])
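# Loss: negative entropy of the chosen constructor's distribution over the four
# IDs; minimizing it during training maximizes that entropy.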
loss = neg_entropy(random_term, Set([DistSome(i) for i in values(to_id)]))

initial_entropy = compute_mixed(var_vals, -loss)
println(io, "Initial entropy: $(initial_entropy)")
println(io)

############################
# Train
############################

println_flush(io, "Training...")
time_train = @elapsed learning_curve = train!(var_vals, loss; epochs=EPOCHS, learning_rate=0.003)
println(io, " $(time_train) seconds")
println(io)

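# Write the learning curve as TSV: one (epoch, loss) row per epoch.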
open(LEARNING_CURVE_PATH, "w") do file
    for (epoch, loss_val) in zip(0:EPOCHS, learning_curve)
        println(file, "$(epoch)\t$(loss_val)")
    end
end

############################
# After
############################

final_entropy = compute_mixed(var_vals, -loss)
println(io, "Final entropy: $(final_entropy)")
println(io)

println(io, "Learned adnodes_of_interest:")
vals = compute(var_vals, values(adnodes_of_interest))
show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)

println(io, "Saving samples...")
time_sample_final = @elapsed with_concrete_ad_flips(var_vals, e) do
    save_samples(joinpath(OUT_DIR, "terms_trained_" * OUT_FILE_TAG * ".txt"), e; io=io)
end
println(io, " $(time_sample_final) seconds")
println(io)
examples/qc/stlc/lib/generator.jl (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ end

function bind_opt(f, ma::DistI{Opt{T}})::DistI{<:Opt{<:Any}} where T
    match(ma, [
-       "None" => () -> DistNone(T)
+       "None" => () -> DistNone(T) # TODO: should be DistNone(return type of f)
        "Some" => f
    ])
end
@@ -0,0 +1,9 @@
num_apps probability
0 0.3784722222222223
1 0.10286458333333337
2 0.1511501736111112
3 0.14680989583333343
4 0.11425781250000003
5 0.06738281250000001
6 0.03125
7 0.007812500000000002
@@ -0,0 +1,9 @@
num_apps probability
0 0.16486577196496097
1 0.047180881871368555
2 0.1088898526593344
3 0.15080968983713253
4 0.17234621280449572
5 0.16286665359517002
6 0.10749048574248589
7 0.08555045152505174