Skip to content

ProbProg: Sample & Generate op #1236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,22 @@ function compile_mlir!(
),
"only_enzyme",
)
elseif optimize === :probprog
run_pass_pipeline!(
mod,
join(
[
"mark-func-memory-effects",
"enzyme-batch",
"probprog",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
],
',',
),
"probprog",
)
elseif optimize === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
132 changes: 132 additions & 0 deletions src/ProbProg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
module ProbProg

using ..Reactant: Reactant, XLA, MLIR, TracedUtils
using ReactantCore: ReactantCore

using Enzyme

@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
argprefix::Symbol = gensym("generatearg")
resprefix::Symbol = gensym("generateresult")
resargprefix::Symbol = gensym("generateresarg")

mlir_fn_res = TracedUtils.make_mlir_fn(
f,
args,
(),
string(f),
false;
args_in_result=:result_and_mutated,
argprefix,
resprefix,
resargprefix,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

batch_inputs = MLIR.IR.Value[]
for a in linear_args
idx, path = TracedUtils.get_argidx(a, argprefix)
if idx == 1 && fnwrap
TracedUtils.push_val!(batch_inputs, f, path[3:end])
else
if fnwrap
idx -= 1
end
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
end
end

gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname)

for (i, res) in enumerate(linear_results)
resv = MLIR.IR.result(gen_op, i)
for path in res.paths
isempty(path) && continue
if path[1] == resprefix
TracedUtils.set!(result, path[2:end], resv)
elseif path[1] == argprefix
idx = path[2]::Int
if idx == 1 && fnwrap
TracedUtils.set!(f, path[3:end], resv)
else
if fnwrap
idx -= 1
end
TracedUtils.set!(args[idx], path[3:end], resv)
end
end
end
end

return result
end

@noinline function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
argprefix::Symbol = gensym("samplearg")
resprefix::Symbol = gensym("sampleresult")
resargprefix::Symbol = gensym("sampleresarg")

mlir_fn_res = TracedUtils.make_mlir_fn(
f,
args,
(),
string(f),
false;
args_in_result=:result_and_mutated,
argprefix,
resprefix,
resargprefix,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

batch_inputs = MLIR.IR.Value[]
for a in linear_args
idx, path = TracedUtils.get_argidx(a, argprefix)
if idx == 1 && fnwrap
TracedUtils.push_val!(batch_inputs, f, path[3:end])
else
idx -= fnwrap ? 1 : 0
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
end
end

out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]

sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))

sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr)

for (i, res) in enumerate(linear_results)
resv = MLIR.IR.result(sample_op, i)

for path in res.paths
isempty(path) && continue
if path[1] == resprefix
TracedUtils.set!(result, path[2:end], resv)
elseif path[1] == argprefix
idx = path[2]::Int
if idx == 1 && fnwrap
TracedUtils.set!(f, path[3:end], resv)
else
if fnwrap
idx -= 1
end
TracedUtils.set!(args[idx], path[3:end], resv)
end
end
end
end

return result
end

end
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ include("stdlibs/Base.jl")

# Other Integrations
include("Enzyme.jl")
include("ProbProg.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

Expand Down
64 changes: 64 additions & 0 deletions test/probprog/generate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using Reactant, Test, Random, StableRNGs, Statistics
using Reactant: ProbProg

normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)

function generate_model(seed, μ, σ, shape)
function model(seed, μ, σ, shape)
rng = Random.default_rng()
Random.seed!(rng, seed)
s = ProbProg.sample!(normal, rng, μ, σ, shape)
t = ProbProg.sample!(normal, rng, s, σ, shape)
return t
end

return ProbProg.generate(model, seed, μ, σ, shape)
end

@testset "Generate" begin
@testset "normal_deterministic" begin
shape = (10000,)
seed1 = Reactant.to_rarray(UInt64[1, 4])
seed2 = Reactant.to_rarray(UInt64[1, 4])
μ1 = Reactant.ConcreteRArray(0.0)
μ2 = Reactant.ConcreteRArray(1000.0)
σ1 = Reactant.ConcreteRArray(1.0)
σ2 = Reactant.ConcreteRArray(1.0)

model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape)

@test Array(model_compiled(seed1, μ1, σ1, shape)) ≈
Array(model_compiled(seed1, μ1, σ1, shape))
@test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol =
0.05
@test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol =
0.05
@test !(all(
Array(model_compiled(seed1, μ1, σ1, shape)) .≈
Array(model_compiled(seed2, μ2, σ2, shape)),
))
end
@testset "normal_hlo" begin
shape = (10000,)
seed = Reactant.to_rarray(UInt64[1, 4])
μ = Reactant.ConcreteRArray(0.0)
σ = Reactant.ConcreteRArray(1.0)

before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape)
@test contains(repr(before), "enzyme.generate")
@test contains(repr(before), "enzyme.sample")

after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape)
@test !contains(repr(after), "enzyme.generate")
@test !contains(repr(after), "enzyme.sample")
end

@testset "normal_generate" begin
shape = (10000,)
seed = Reactant.to_rarray(UInt64[1, 4])
μ = Reactant.ConcreteRArray(0.0)
σ = Reactant.ConcreteRArray(1.0)
X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape))
@test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05
end
end
50 changes: 50 additions & 0 deletions test/probprog/sample.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Reactant, Test, Random, StableRNGs, Statistics
using Reactant: ProbProg

@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)

function sample1(seed, μ, σ, shape)
function model(seed, μ, σ, shape)
rng = Random.default_rng()
Random.seed!(rng, seed)
s = ProbProg.sample!(normal, rng, μ, σ, shape)
return s
end

return ProbProg.generate(model, seed, μ, σ, shape)
end

function sample2(seed, μ, σ, shape)
function model(seed, μ, σ, shape)
rng = Random.default_rng()
Random.seed!(rng, seed)
s = ProbProg.sample!(normal, rng, μ, σ, shape)
t = ProbProg.sample!(normal, rng, μ, σ, shape)
return t
end

return ProbProg.generate(model, seed, μ, σ, shape)
end

@testset "test" begin
@testset "sample_hlo" begin
shape = (10,)
seed = Reactant.to_rarray(UInt64[1, 4])
μ = Reactant.ConcreteRArray(0.0)
σ = Reactant.ConcreteRArray(1.0)
before = @code_hlo optimize = false sample2(seed, μ, σ, shape)
@test contains(repr(before), "enzyme.sample")
after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape)
@test !contains(repr(after), "enzyme.sample")
end

@testset "sample_normal" begin
shape = (10,)
seed = Reactant.to_rarray(UInt64[1, 4])
μ = Reactant.ConcreteRArray(0.0)
σ = Reactant.ConcreteRArray(1.0)
X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape))
Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape))
@test !all(X .≈ Y)
end
end
Loading