Skip to content

Commit

Permalink
rbt table, or something
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed Feb 11, 2025
1 parent 992f638 commit 0bb1bc4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 33 deletions.
11 changes: 8 additions & 3 deletions qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ end

alwaysTrue(t) = true
isRBT(t) = satisfies_bookkeeping_invariant(t) && satisfies_balance_invariant(t) && satisfies_order_invariant(t)
isRBTdist(t) = satisfies_bookkeeping_invariant(t) & satisfies_balance_invariant(t) & satisfies_order_invariant(t)
isBST(t) = satisfies_order_invariant(t)
function wellTyped(e::OptExpr.t)
@assert isdeterministic(e)
Expand Down Expand Up @@ -689,7 +690,7 @@ function produce_loss(rs::RunState, m::SpecEntropyLossMgr, epoch::Integer)
num_meeting += 1
lpr_eq = LogPr(prob_equals(m.generation.value, sample))
lpr_eq = Dice.expand_logprs(l, lpr_eq)
[lpr_eq * compute(a, lpr_eq), Dice.Constant(lpr_eq)]
[lpr_eq * compute(a, lpr_eq), lpr_eq]
else
[Dice.Constant(0), Dice.Constant(0)]
end
Expand All @@ -698,6 +699,7 @@ function produce_loss(rs::RunState, m::SpecEntropyLossMgr, epoch::Integer)
push!(m.num_meeting, num_meeting / length(samples))

loss = Dice.expand_logprs(l, loss) / length(samples)
actual_loss = Dice.expand_logprs(l, actual_loss) / length(samples)
m.current_loss = loss
m.current_actual_loss = actual_loss
m.current_samples = samples
Expand Down Expand Up @@ -1031,7 +1033,7 @@ end
struct SatisfyPropertyLoss{T} <: LossConfig{T}
property::Function
end
to_subpath(p::SatisfyPropertyLoss) = [nameof(p.property)]
to_subpath(p::SatisfyPropertyLoss) = ["$(p.property)"]
function create_loss_manager(rs::RunState, p::SatisfyPropertyLoss, generation)
println_flush(rs.io, "Building computation graph for $(p)...")
time_build_loss = @elapsed begin
Expand Down Expand Up @@ -1094,4 +1096,7 @@ function metric_loss(metric::Dist, ::Target4321)
BoolToMax(prob_equals(metric, DistUInt32(2)), weight=.2),
BoolToMax(prob_equals(metric, DistUInt32(3)), weight=.1),
])
end
end


always_true(_) = true
79 changes: 49 additions & 30 deletions qc/benchmarks/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,35 @@ GENERATION_PARAMS_LIST = [
# expr_size=5,
# typ_size=2,
# ),
LangSiblingDerivedGenerator{STLC}(
root_ty=Expr.t,
ty_sizes=[Expr.t=>5, Typ.t=>2],
stack_size=2,
intwidth=3,
)
# LangSiblingDerivedGenerator{RBT}(
# root_ty=ColorKVTree.t,
# ty_sizes=[ColorKVTree.t=>4, Color.t=>0],
# LangSiblingDerivedGenerator{STLC}(
# root_ty=Expr.t,
# ty_sizes=[Expr.t=>5, Typ.t=>2],
# stack_size=2,
# intwidth=3,
# ),
# )
LangSiblingDerivedGenerator{RBT}(
root_ty=ColorKVTree.t,
ty_sizes=[ColorKVTree.t=>4, Color.t=>0],
stack_size=2,
intwidth=3,
),
# LangSiblingDerivedGenerator{BST}(
# root_ty=KVTree.t,
# ty_sizes=[KVTree.t=>4],
# stack_size=2,
# intwidth=3,
# ),
]
LR_LIST = [0.3]
# LR_LIST = [0.03, 0.1, 0.3]

# SAMPLES_PER_BATCH_LIST = [200]
SAMPLES_PER_BATCH_LIST = [2000]
RESAMPLING_FREQUENCY_LIST = [1]
EPOCHS_LIST = [2000, 10000]
# LR_LIST = [0.3]
LR_LIST = [0.03, 0.1, 0.3]

BOUND_LIST = [0.]
SAMPLES_PER_BATCH_LIST = [200]
# SAMPLES_PER_BATCH_LIST = [2000]
RESAMPLING_FREQUENCY_LIST = [2]
EPOCHS_LIST = [2000]
BOUND_LIST = [0., 0.1]

PROPERTY_LIST = [nothing]
PROPERTY_LIST = [isRBT, always_true]

# TRAIN_FEATURE_LIST = [false, true]
TRAIN_FEATURE_LIST = [true]
Expand All @@ -53,21 +52,40 @@ println(n_runs)
@show BOUND_LIST
println()

LOSS_CONFIG_WEIGHT_PAIRS_LIST = collect(Iterators.flatten([

LOSS_CONFIG_WEIGHT_PAIRS_LIST = []

append!(LOSS_CONFIG_WEIGHT_PAIRS_LIST,
(
[
FeatureSpecEntropy{STLC}(resampling_frequency,samples_per_batch,wellTyped,typecheck_ft,train_feature) => lr,
# WeightedSpecEntropy{STLC}(resampling_frequency,samples_per_batch,wellTyped,inv_size) => lr,
# MLELossConfig{STLC}(num_apps, Uniform()) => lr,
# MLELossConfig{STLC}(size, Uniform()) => lr,
]
[SpecEntropy{RBT}(resampling_frequency,samples_per_batch,property) => lr]
for lr in LR_LIST
for property in PROPERTY_LIST
for resampling_frequency in RESAMPLING_FREQUENCY_LIST
for samples_per_batch in SAMPLES_PER_BATCH_LIST
for train_feature in TRAIN_FEATURE_LIST
),
]))
)
append!(LOSS_CONFIG_WEIGHT_PAIRS_LIST,
(
[SatisfyPropertyLoss{RBT}(isRBTdist) => lr]
for lr in LR_LIST
),
)

# LOSS_CONFIG_WEIGHT_PAIRS_LIST = collect(Iterators.flatten([
# (
# [
# FeatureSpecEntropy{STLC}(resampling_frequency,samples_per_batch,wellTyped,typecheck_ft,train_feature) => lr,
# # WeightedSpecEntropy{STLC}(resampling_frequency,samples_per_batch,wellTyped,inv_size) => lr,
# # MLELossConfig{STLC}(num_apps, Uniform()) => lr,
# # MLELossConfig{STLC}(size, Uniform()) => lr,
# ]
# for lr in LR_LIST
# for property in PROPERTY_LIST
# for resampling_frequency in RESAMPLING_FREQUENCY_LIST
# for samples_per_batch in SAMPLES_PER_BATCH_LIST
# for train_feature in TRAIN_FEATURE_LIST
# ),
# ]))

# LOSS_CONFIG_WEIGHT_PAIRS_LIST = begin
# lr = 0.03
Expand Down Expand Up @@ -115,15 +133,16 @@ TOOL_PATH = "qc/benchmarks/tool.jl"
p_s = replace(string(p), " "=>"")
s = "julia --project $(TOOL_PATH) $(flags) $(p_s) $(lcws_s) $(epochs) $(bound)"
cmd = Cmd(Cmd(convert(Vector{String}, split(s))), ignorestatus=true)
println(s)
s_hum = "julia --project $(TOOL_PATH) $(flags) \"$(p_s)\" \"$(lcws_s)\" $(epochs) $(bound)"
println(s_hum)
out = IOBuffer()
@async begin
proc = run(pipeline(cmd; stdout=out, stderr=stdout),)
if proc.exitcode != 0
println()
println(proc.exitcode)
so = String(take!(out))
println("FAILED: $(s)\nSTDOUT ===\n$(so)\n\n")
println("FAILED:\n$(s_hum)\nSTDOUT ===\n$(so)\n\n")
end
end
end
Expand Down
1 change: 1 addition & 0 deletions qc/benchmarks/tool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ TAG = "v109_unif_ty"
TAG = "v1110_weighted_se"
TAG = "v112_prettier_unif"
TAG = "v113_prettier_unif"
TAG = "v114_rbt_table"
OUT_TOP_DIR = joinpath(@__DIR__, "../../../tuning-output")

args = ARGS
Expand Down

0 comments on commit 0bb1bc4

Please sign in to comment.