Skip to content

Commit

Permalink
Include primal time in benchmark table (#449)
Browse files Browse the repository at this point in the history
* Include primal time in table

* Formatting

* Undo redundant change

* Fix to 3sf
  • Loading branch information
willtebbutt authored Jan 27, 2025
1 parent e7ce4ef commit 8d91ab9
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,19 @@ function plot_ratio_histogram!(df::DataFrame)
return histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
end

fix_sig_fig(t) = string.(round(t; sigdigits=3))

function format_time(t::Float64)
t < 1e-6 && return fix_sig_fig(t * 1e9) * " ns"
t < 1e-3 && return fix_sig_fig(t * 1e6) * " μs"
t < 1 && return fix_sig_fig(t * 1e3) * " ms"
return fix_sig_fig(t) * " s"
end

function create_inter_ad_benchmarks()
results = benchmark_inter_framework_rules()
tools = [:Mooncake, :Zygote, :ReverseDiff, :Enzyme]
df = DataFrame(results)[:, [:tag, tools...]]
df = DataFrame(results)[:, [:tag, :primal_time, tools...]]

# Plot graph of results.
plt = plot(; yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
Expand All @@ -343,8 +352,9 @@ function create_inter_ad_benchmarks()
Plots.savefig(plt, "bench/benchmark_results.png")

# Write table of results.
formatted_cols = map(t -> t => string.(round.(df[:, t]; sigdigits=3)), tools)
df_formatted = DataFrame(:Label => df.tag, formatted_cols...)
formatted_ts = format_time.(df.primal_time)
formatted_cols = map(t -> t => fix_sig_fig.(df[:, t]), tools)
df_formatted = DataFrame(:Label => df.tag, :Primal => formatted_ts, formatted_cols...)
return open(
io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true
)
Expand Down

0 comments on commit 8d91ab9

Please sign in to comment.