Use TaskLocalValues #2075

Draft · wants to merge 2 commits into base: master
2 changes: 2 additions & 0 deletions .buildkite/pipeline.yml
@@ -253,8 +253,10 @@ steps:
using Pkg

println("--- :julia: Instantiating project")
Pkg.resolve()
Pkg.instantiate()
Pkg.activate("perf")
Pkg.resolve()
Pkg.instantiate()
push!(LOAD_PATH, @__DIR__)

11 changes: 6 additions & 5 deletions lib/cublas/CUBLAS.jl
@@ -19,7 +19,7 @@ import LLVM
using LLVM.Interop: assume

using CEnum: @cenum

using TaskLocalValues

# core library
include("libcublas.jl")
@@ -73,14 +73,15 @@ end
const idle_handles = HandleCache{CuContext,cublasHandle_t}()
const idle_xt_handles = HandleCache{Any,cublasXtHandle_t}()

const LibraryState = @NamedTuple{handle::cublasHandle_t, stream::CuStream, math_mode::CUDA.MathMode}
const CUBLAS_STATE =
TaskLocalValue{Dict{CuContext,LibraryState}}(()-> Dict{CuContext,LibraryState}())

function handle()
cuda = CUDA.active_state()

# every task maintains library state per device
LibraryState = @NamedTuple{handle::cublasHandle_t, stream::CuStream, math_mode::CUDA.MathMode}
states = get!(task_local_storage(), :CUBLAS) do
Dict{CuContext,LibraryState}()
end::Dict{CuContext,LibraryState}
states = CUBLAS_STATE[]

# get library state
@noinline function new_state(cuda)
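For readers unfamiliar with TaskLocalValues.jl, here is a minimal sketch of the pattern the CUBLAS hunk above adopts: a `TaskLocalValue` lazily constructs one value per task the first time that task reads it with `[]`, replacing the previous `get!(task_local_storage(), :CUBLAS) do ... end` idiom. The `Context` and `HandleState` types below are illustrative stand-ins, not CUDA.jl types.

```julia
using TaskLocalValues

# Illustrative stand-ins for CuContext and the library's LibraryState tuple.
const Context = Int
const HandleState = @NamedTuple{handle::Ptr{Cvoid}}

# One Dict per task, created lazily the first time that task reads STATE[].
const STATE = TaskLocalValue{Dict{Context,HandleState}}(() -> Dict{Context,HandleState}())

function handle_for(ctx::Context)
    states = STATE[]              # this task's dictionary
    state = get!(states, ctx) do  # per-context state within that task
        (; handle = C_NULL)
    end
    return state.handle
end
```

Compared with raw `task_local_storage`, the container is typed, so call sites no longer need the `::Dict{CuContext,LibraryState}` assertion that the old `get!` pattern carried.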
74 changes: 25 additions & 49 deletions lib/cudadrv/state.jl
@@ -65,33 +65,20 @@ function validate_task_local_state(state::TaskLocalState)
return state
end

# get or create the task local state, and make sure it's valid
function task_local_state!(args...)
tls = task_local_storage()
if haskey(tls, :CUDA)
validate_task_local_state(@inbounds(tls[:CUDA])::TaskLocalState)
else
# verify that CUDA.jl is functional. this doesn't belong here, but since we can't
# error during `__init__`, we do it here instead as this is the first function
# that's likely executed when using CUDA.jl
@assert functional(true)
const CUDA_STATE = TaskLocalValue{TaskLocalState}() do
# verify that CUDA.jl is functional. this doesn't belong here, but since we can't
# error during `__init__`, we do it here instead as this is the first function
# that's likely executed when using CUDA.jl
@assert functional(true)

tls[:CUDA] = TaskLocalState(args...)
end::TaskLocalState
return TaskLocalState()
end

# only get the task local state (it may be invalid!), or return nothing if uninitialized
function task_local_state()
tls = task_local_storage()
if haskey(tls, :CUDA)
@inbounds(tls[:CUDA])
else
nothing
end::Union{TaskLocalState,Nothing}
end
# get or create the task local state, and make sure it's valid
task_local_state() = validate_task_local_state(CUDA_STATE[])

@inline function prepare_cuda_state()
state = task_local_state!()
state = task_local_state()

# NOTE: current_context() is too slow to use here (taking a lock, accessing a dict)
# so we use the raw handle. is that safe though, when we reset the device?
@@ -109,7 +96,7 @@ end
# without querying task local storage multiple times
@inline function active_state()
# inline to remove unused state properties
state = task_local_state!()
state = task_local_state()
return (device=state.device, context=state.context, stream=stream(state),
math_mode=state.math_mode, math_precision=state.math_precision)
end
@@ -125,7 +112,7 @@ Get or create a CUDA context for the current thread (as opposed to
current thread).
"""
function context()
task_local_state!().context
task_local_state().context
end

"""
@@ -144,19 +131,12 @@ function context!(ctx::CuContext)
# NOTE: if we actually need to switch contexts, we eagerly activate it so that we can
# query its device (we normally only do so lazily in `prepare_cuda_state`)
state = task_local_state()
if state === nothing
old_ctx = nothing
old_ctx = state.context
if old_ctx != ctx
activate(ctx)
dev = current_device()
task_local_state!(dev, ctx)
else
old_ctx = state.context
if old_ctx != ctx
activate(ctx)
dev = current_device()
state.device = dev
state.context = ctx
end
state.device = dev
state.context = ctx
end

return old_ctx
@@ -169,7 +149,7 @@ end
try
f()
finally
if old_ctx !== nothing && old_ctx != ctx && isvalid(old_ctx)
if old_ctx != ctx && isvalid(old_ctx)
context!(old_ctx)
end
end
@@ -188,7 +168,7 @@ Get the CUDA device for the current thread, similar to how [`context()`](@ref) w
compared to [`current_context()`](@ref).
"""
function device()
task_local_state!().device
task_local_state().device
end

const __device_contexts = LazyInitialized{Vector{Union{Nothing,CuContext}}}()
@@ -286,12 +266,8 @@ function device!(dev::CuDevice, flags=nothing)
# switch contexts
ctx = context(dev)
state = task_local_state()
if state === nothing
task_local_state!(dev)
else
state.device = dev
state.context = ctx
end
state.device = dev
state.context = ctx
activate(ctx)

dev
@@ -349,7 +325,7 @@ deviceid(dev::CuDevice=device()) = Int(convert(CUdevice, dev))
## math mode

function math_mode!(mode::MathMode; precision=nothing)
state = task_local_state!()
state = task_local_state()

state.math_mode = mode
default_math_mode[] = mode
@@ -362,8 +338,8 @@ function math_mode!(mode::MathMode; precision=nothing)
return
end

math_mode() = task_local_state!().math_mode
math_precision() = task_local_state!().math_precision
math_mode() = task_local_state().math_mode
math_precision() = task_local_state().math_precision


## streams
@@ -373,7 +349,7 @@ math_precision() = task_local_state!().math_precision

Get the CUDA stream that should be used as the default one for the currently executing task.
"""
@inline function stream(state=task_local_state!())
@inline function stream(state=task_local_state())
# @inline so that it can be DCE'd when unused from active_state
devidx = deviceid(state.device)+1
@inbounds if state.streams[devidx] === nothing
@@ -396,14 +372,14 @@ end
end

function stream!(stream::CuStream)
state = task_local_state!()
state = task_local_state()
devidx = deviceid(state.device)+1
state.streams[devidx] = stream
return
end

function stream!(f::Function, stream::CuStream)
state = task_local_state!()
state = task_local_state()
devidx = deviceid(state.device)+1
old_stream = state.streams[devidx]
state.streams[devidx] = stream
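To make the state.jl changes easier to follow, here is a hedged sketch (illustrative types, not the actual `TaskLocalState`) of why the old `task_local_state()`/`task_local_state!()` split and the `state === nothing` branches in `context!` and `device!` can disappear: the `TaskLocalValue` constructor runs once per task on first access, so every reader observes an initialized state.

```julia
using TaskLocalValues

mutable struct State                 # stand-in for CUDA.TaskLocalState
    device::Int
    context::Int
end

const STATE = TaskLocalValue{State}() do
    State(0, 0)                      # runs once per task, on first read
end

state() = STATE[]                    # always initialized; no `nothing` case

# A setter no longer needs to distinguish "state exists" from "state missing":
function set_device!(dev::Int, ctx::Int)
    s = state()
    s.device = dev
    s.context = ctx
    return s
end
```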
10 changes: 6 additions & 4 deletions lib/cudnn/src/cuDNN.jl
@@ -13,6 +13,7 @@ using CUDA: CUstream, libraryPropertyType
using CUDA: retry_reclaim, isdebug, initialize_context

using CEnum: @cenum
using TaskLocalValues

if CUDA.local_toolkit
using CUDA_Runtime_Discovery
@@ -65,14 +66,15 @@ end
# cache for created, but unused handles
const idle_handles = HandleCache{CuContext,cudnnHandle_t}()

const LibraryState = @NamedTuple{handle::cudnnHandle_t, stream::CuStream}
const cuDNN_STATE =
TaskLocalValue{Dict{CuContext,LibraryState}}(()-> Dict{CuContext,LibraryState}())

function handle()
cuda = CUDA.active_state()

# every task maintains library state per device
LibraryState = @NamedTuple{handle::cudnnHandle_t, stream::CuStream}
states = get!(task_local_storage(), :cuDNN) do
Dict{CuContext,LibraryState}()
end::Dict{CuContext,LibraryState}
states = cuDNN_STATE[]

# get library state
@noinline function new_state(cuda)
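Finally, a small usage-level illustration (hypothetical values, not from the PR) of the property all three libraries rely on: each task gets its own copy of a `TaskLocalValue`, so handles and streams cached by one task are never observed by another.

```julia
using TaskLocalValues

const COUNTER = TaskLocalValue{Base.RefValue{Int}}(() -> Ref(0))

COUNTER[][] += 1                     # mutate this task's counter
t = Threads.@spawn begin
    COUNTER[][] += 10                # a fresh Ref(0) is created for the spawned task
    COUNTER[][]
end
fetch(t)                             # == 10
COUNTER[][]                          # still 1 in the parent task
```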