Skip to content

Commit

Permalink
fix resource exhaustion bug #190 (#219)
Browse files Browse the repository at this point in the history
* Disable multihreading on MLIR compiler

* test resource exhaustion bug #190

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Implement `Context` constructors with threading and registry options

* Refactor MLIR context initialization to streamline multithreading and registry handling

* fix `Context` constructor with registry

* Disable test due to long test time

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
mofeing and github-actions[bot] authored Nov 4, 2024
1 parent 92b11fb commit ede6274
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
6 changes: 2 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ function run_pass_pipeline!(mod, pass_pipeline)
end

function compile_mlir(f, args; kwargs...)
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
Expand Down Expand Up @@ -712,8 +711,7 @@ end

function compile_xla(f, args; client=nothing, optimize=true)
# register MLIR dialects
ctx = MLIR.IR.Context()
append!(Reactant.registry[]; context=ctx)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
Expand Down
5 changes: 5 additions & 0 deletions src/mlir/IR/Context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ function Context(f::Core.Function)
end
end

Context(threading::Bool) = Context(API.mlirContextCreateWithThreading(threading))
function Context(registry, threading)
return Context(API.mlirContextCreateWithRegistry(registry, threading))
end

Base.convert(::Core.Type{API.MlirContext}, c::Context) = c.context

# Global state
Expand Down
3 changes: 1 addition & 2 deletions test/bcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ end
end

function test()
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

MLIR.IR.context!(ctx) do
Expand Down
14 changes: 14 additions & 0 deletions test/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,18 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
@test y1 Float64.(a)
@test y2 Float32.(a)
end

# disabled due to long test time (core tests go from 2m to 7m just with this test)
# @testset "resource exhaustation bug (#190)" begin
# x = rand(2, 2)
# y = Reactant.to_rarray(x)
# @test try
# for _ in 1:10_000
# f = @compile sum(y)
# end
# true
# catch e
# false
# end
# end
end

0 comments on commit ede6274

Please sign in to comment.