Moar threadsafe moar better #101

Merged
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
Manifest.toml
*.swp
112 changes: 66 additions & 46 deletions src/cluster.jl
@@ -99,10 +99,10 @@
del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
add_msgs::Array{Any,1}
@atomic gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
@atomic state::WorkerState
@gbaraldi (Member) commented on Jul 23, 2024:

if state is always read/written from inside a lock this doesn't need to be atomic as the lock should have the correct barriers.

The PR author (Collaborator) replied:

I don't think that's guaranteed? From a cursory grep through cluster.jl I see plenty of reads outside of a lock.
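To make the compromise concrete, here is a minimal sketch (editor's illustration, not code from this PR) of the pattern the change adopts: the field stays `@atomic` so reads taken outside any lock remain data-race free, while writes and the `Threads.Condition` notifications happen under the condition's lock so waiters cannot miss a state change. The `MiniWorker` type and helper names are hypothetical.

```julia
# Hypothetical MiniWorker type, for illustration only.
mutable struct MiniWorker
    @atomic state::Symbol            # readable without holding the lock
    c_state::Threads.Condition       # lock guards writes to `state` and notifications
end

MiniWorker() = MiniWorker(:created, Threads.Condition())

function set_state!(w::MiniWorker, s::Symbol)
    lock(w.c_state) do
        @atomic w.state = s          # write under the lock...
        notify(w.c_state; all=true)  # ...so this notify cannot race with a waiter
    end
end

function wait_until_not_created(w::MiniWorker)
    lock(w.c_state) do
        while (@atomic w.state) === :created
            wait(w.c_state)          # releases the lock while blocked
        end
    end
end

# A lock-free check elsewhere stays race-free because the field is atomic.
is_created(w::MiniWorker) = (@atomic w.state) === :created
```

Unlocked readers only ever see the old or the new value, never a torn write, which matches the point above that cluster.jl still reads `w.state` outside the lock in several places.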

c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
@@ -134,7 +134,7 @@
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
@@ -144,8 +144,10 @@
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
@atomic w.state = state
notify(w.c_state; all=true)
end
end

function check_worker_state(w::Worker)
@@ -161,15 +163,16 @@
else
w.ct_time = time()
if myid() > w.id
t = @async exec_conn_func(w)
t = Threads.@spawn Threads.threadpool() exec_conn_func(w)
else
# route request via node 1
t = @async remotecall_fetch((p,to_id) -> remotecall_fetch(exec_conn_func, p, to_id), 1, w.id, myid())
t = Threads.@spawn Threads.threadpool() remotecall_fetch((p,to_id) -> remotecall_fetch(exec_conn_func, p, to_id), 1, w.id, myid())
end
errormonitor(t)
wait_for_conn(w)
end
end
return nothing
end

exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)
@@ -191,9 +194,17 @@
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn Threads.threadpool() begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
end
nothing
end
@@ -247,7 +258,7 @@
else
sock = listen(interface, LPROC.bind_port)
end
errormonitor(@async while isopen(sock)
errormonitor(Threads.@spawn while isopen(sock)
client = accept(sock)
process_messages(client, client, true)
end)
@@ -279,7 +290,7 @@


function redirect_worker_output(ident, stream)
t = @async while !eof(stream)
t = Threads.@spawn while !eof(stream)
line = readline(stream)
if startswith(line, " From worker ")
# stdout's of "additional" workers started from an initial worker on a host are not available
@@ -318,7 +329,7 @@
leader = String[]
try
while ntries > 0
readtask = @async readline(io)
readtask = Threads.@spawn Threads.threadpool() readline(io)
yield()
while !istaskdone(readtask) && ((time_ns() - t0) < timeout)
sleep(0.05)
@@ -419,7 +430,7 @@

```julia
# On busy clusters, call `addprocs` asynchronously
t = @async addprocs(...)
t = Threads.@spawn addprocs(...)
```

```julia
@@ -485,20 +496,23 @@
# call manager's `launch` is a separate task. This allows the master
# process initiate the connection setup process as and when workers come
# online
t_launch = @async launch(manager, params, launched, launch_ntfy)
t_launch = Threads.@spawn Threads.threadpool() launch(manager, params, launched, launch_ntfy)

@sync begin
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
Threads.@spawn Threads.threadpool() begin
sleep(1)
notify(launch_ntfy)
end
wait(launch_ntfy)
end

if !isempty(launched)
wconfig = popfirst!(launched)
let wconfig=wconfig
@async setup_launched_worker(manager, wconfig, launched_q)
Threads.@spawn Threads.threadpool() setup_launched_worker(manager, wconfig, launched_q)
end
end
end
@@ -578,7 +592,7 @@
wconfig.port = port

let wconfig=wconfig
@async begin
Threads.@spawn Threads.threadpool() begin
pid = create_worker(manager, wconfig)
remote_do(redirect_output_from_additional_worker, frompid, pid, port)
push!(launched_q, pid)
@@ -645,7 +659,12 @@
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for wl to join
if jw.state === W_CREATED
lock(jw.c_state) do
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
@@ -668,7 +687,12 @@
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
lock(wl.c_state) do
if wl.state === W_CREATED
# wait for wl to join
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
@@ -727,23 +751,21 @@
end

function check_master_connect()
timeout = worker_timeout() * 1e9
# If we do not have at least process 1 connect to us within timeout
# we log an error and exit, unless we're running on valgrind
if ccall(:jl_running_on_valgrind,Cint,()) != 0
return
end
@async begin
start = time_ns()
while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
sleep(1.0)
end

if !haskey(map_pid_wrkr, 1)
print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
exit(1)
errormonitor(
Threads.@spawn begin
timeout = worker_timeout()
if timedwait(() -> haskey(map_pid_wrkr, 1), timeout) === :timed_out
print(stderr, "Master process (id 1) could not connect within $(timeout) seconds.\nexiting.\n")
exit(1)
end
end
end
)
end


@@ -1028,13 +1050,13 @@

pids = vcat(pids...)
if waitfor == 0
t = @async _rmprocs(pids, typemax(Int))
t = Threads.@spawn Threads.threadpool() _rmprocs(pids, typemax(Int))
yield()
return t
else
_rmprocs(pids, waitfor)
# return a dummy task object that user code can wait on.
return @async nothing
return Threads.@spawn Threads.threadpool() nothing
end
end

@@ -1217,7 +1239,7 @@
@assert myid() == 1
@sync begin
for pid in pids
@async interrupt(pid)
Threads.@spawn Threads.threadpool() interrupt(pid)
end
end
end
@@ -1288,18 +1310,16 @@

using Random: randstring

let inited = false
# do initialization that's only needed when there is more than 1 processor
global function init_multi()
if !inited
inited = true
push!(Base.package_callbacks, _require_callback)
atexit(terminate_all_workers)
init_bind_addr()
cluster_cookie(randstring(HDR_COOKIE_LEN))
end
return nothing
# do initialization that's only needed when there is more than 1 processor
const inited = Threads.Atomic{Bool}(false)
function init_multi()
if !Threads.atomic_cas!(inited, false, true)
push!(Base.package_callbacks, _require_callback)
atexit(terminate_all_workers)
init_bind_addr()
cluster_cookie(randstring(HDR_COOKIE_LEN))
end
return nothing
end

function init_parallel()
4 changes: 2 additions & 2 deletions src/macros.jl
@@ -230,7 +230,7 @@ function remotecall_eval(m::Module, procs, ex)
# execute locally last as we do not want local execution to block serialization
# of the request to remote nodes.
for _ in 1:run_locally
@async Core.eval(m, ex)
Threads.@spawn Threads.threadpool() Core.eval(m, ex)
end
end
nothing
@@ -275,7 +275,7 @@ function preduce(reducer, f, R)
end

function pfor(f, R)
t = @async @sync for c in splitrange(Int(firstindex(R)), Int(lastindex(R)), nworkers())
t = Threads.@spawn Threads.threadpool() @sync for c in splitrange(Int(firstindex(R)), Int(lastindex(R)), nworkers())
@spawnat :any f(R, first(c), last(c))
end
errormonitor(t)
4 changes: 2 additions & 2 deletions src/managers.jl
@@ -178,7 +178,7 @@
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
Threads.@spawn Threads.threadpool() try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
@@ -744,7 +744,7 @@
# First, try sending `exit()` to the remote over the usual control channels
remote_do(exit, pid)

timer_task = @async begin
timer_task = Threads.@spawn Threads.threadpool() begin
sleep(exit_timeout)

# Check to see if our child exited, and if not, send an actual kill signal
2 changes: 1 addition & 1 deletion src/messages.jl
@@ -200,7 +200,7 @@
end
catch e
bt = catch_backtrace()
@async showerror(stderr, e, bt)
Threads.@spawn showerror(stderr, e, bt)
end
end

14 changes: 7 additions & 7 deletions src/process_messages.jl
@@ -85,7 +85,7 @@
rv = RemoteValue(def_rv_channel())
(PGRP::ProcessGroup).refs[rid] = rv
push!(rv.clientset, rid.whence)
errormonitor(@async run_work_thunk(rv, thunk))
errormonitor(Threads.@spawn run_work_thunk(rv, thunk))
return rv
end
end
@@ -118,7 +118,7 @@

## message event handlers ##
function process_messages(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool=true)
errormonitor(@async process_tcp_streams(r_stream, w_stream, incoming))
errormonitor(Threads.@spawn process_tcp_streams(r_stream, w_stream, incoming))
end

function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool)
@@ -148,7 +148,7 @@
See also [`cluster_cookie`](@ref).
"""
function process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)
errormonitor(@async message_handler_loop(r_stream, w_stream, incoming))
errormonitor(Threads.@spawn message_handler_loop(r_stream, w_stream, incoming))
end

function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
@@ -283,7 +283,7 @@
schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version)
errormonitor(@async begin
errormonitor(Threads.@spawn begin
v = run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), false)
if isa(v, SyncTake)
try
@@ -299,15 +299,15 @@
end

function handle_msg(msg::CallWaitMsg, header, r_stream, w_stream, version)
errormonitor(@async begin
errormonitor(Threads.@spawn begin
rv = schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, header.notify_oid, fetch(rv.c))
nothing
end)
end

function handle_msg(msg::RemoteDoMsg, header, r_stream, w_stream, version)
errormonitor(@async run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), true))
errormonitor(Threads.@spawn run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), true))
end

function handle_msg(msg::ResultMsg, header, r_stream, w_stream, version)
@@ -350,7 +350,7 @@
# The constructor registers the object with a global registry.
Worker(rpid, ()->connect_to_peer(cluster_manager, rpid, wconfig))
else
@async connect_to_peer(cluster_manager, rpid, wconfig)
Threads.@spawn connect_to_peer(cluster_manager, rpid, wconfig)
end
end
end
4 changes: 2 additions & 2 deletions src/remotecall.jl
@@ -205,7 +205,7 @@ or to use a local [`Channel`](@ref) as a proxy:
```julia
p = 1
f = Future(p)
errormonitor(@async put!(f, remotecall_fetch(long_computation, p)))
errormonitor(Threads.@spawn put!(f, remotecall_fetch(long_computation, p)))
isready(f) # will not block
```
"""
@@ -322,7 +322,7 @@ function process_worker(rr)
msg = (remoteref_id(rr), myid())

# Needs to acquire a lock on the del_msg queue
T = Threads.@spawn begin
T = Threads.@spawn Threads.threadpool() begin
publish_del_msg!($w, $msg)
end
Base.errormonitor(T)