Skip to content

Commit f5715f7

Browse files
committed
Invalidate Distributed.create_worker to execute a custom expression on initialization
1 parent 4029cd5 commit f5715f7

File tree

6 files changed

+215
-67
lines changed

6 files changed

+215
-67
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ ClimaCalibrate.submit_pbs_job
5252
ClimaCalibrate.initialize
5353
ClimaCalibrate.save_G_ensemble
5454
ClimaCalibrate.update_ensemble
55+
ClimaCalibrate.update_ensemble!
5556
ClimaCalibrate.ExperimentConfig
5657
ClimaCalibrate.get_prior
5758
ClimaCalibrate.get_param_dict

src/workers.jl

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using Logging
33

44
export SlurmManager, PBSManager, set_worker_loggers
55

6+
worker_timeout() = parse(Float64, get(ENV, "JULIA_WORKER_TIMEOUT", "300.0"))
7+
68
get_worker_pool() = workers() == [1] ? WorkerPool() : default_worker_pool()
79

810
function run_worker_iteration(
@@ -21,7 +23,7 @@ function run_worker_iteration(
2123
remotecall_wait(forward_model, w, iter, m)
2224
end
2325
end
24-
26+
isempty(all_known_workers.workers) && @info "No workers currently available"
2527
@sync while !isempty(work_to_do)
2628
# Add new workers to worker_pool
2729
all_workers = get_worker_pool()
@@ -40,7 +42,7 @@ function run_worker_iteration(
4042
push!(worker_pool, worker)
4143
end
4244
else
43-
println("no workers available")
45+
@debug "no workers available"
4446
sleep(10) # Wait for workers to become available
4547
end
4648
end
@@ -100,7 +102,6 @@ function Distributed.manage(
100102
)
101103
if op == :register
102104
set_worker_logger(id)
103-
evaluate_initial_expression(id, manager.expr)
104105
end
105106
end
106107

@@ -478,3 +479,148 @@ function set_worker_loggers(workers = workers())
478479
end
479480
end
480481
end
482+
483+
# Copied from Distributed.jl in order to evaluate the manager's expression on worker initialization
484+
function Distributed.create_worker(
485+
manager::Union{SlurmManager, PBSManager},
486+
wconfig,
487+
)
488+
# only node 1 can add new nodes, since nobody else has the full list of address:port
489+
@assert Distributed.LPROC.id == 1
490+
timeout = worker_timeout()
491+
492+
# initiate a connect. Does not wait for connection completion in case of TCP.
493+
w = Distributed.Worker()
494+
local r_s, w_s
495+
try
496+
(r_s, w_s) = Distributed.connect(manager, w.id, wconfig)
497+
catch ex
498+
try
499+
Distributed.deregister_worker(w.id)
500+
kill(manager, w.id, wconfig)
501+
finally
502+
rethrow(ex)
503+
end
504+
end
505+
506+
w = Distributed.Worker(w.id, r_s, w_s, manager; config = wconfig)
507+
# install a finalizer to perform cleanup if necessary
508+
finalizer(w) do w
509+
if myid() == 1
510+
Distributed.manage(w.manager, w.id, w.config, :finalize)
511+
end
512+
end
513+
514+
# set when the new worker has finished connections with all other workers
515+
ntfy_oid = Distributed.RRID()
516+
rr_ntfy_join = Distributed.lookup_ref(ntfy_oid)
517+
rr_ntfy_join.waitingfor = myid()
518+
519+
# Start a new task to handle inbound messages from connected worker in master.
520+
# Also calls `wait_connected` on TCP streams.
521+
Distributed.process_messages(w.r_stream, w.w_stream, false)
522+
523+
# send address information of all workers to the new worker.
524+
# Cluster managers set the address of each worker in `WorkerConfig.connect_at`.
525+
# A new worker uses this to setup an all-to-all network if topology :all_to_all is specified.
526+
# Workers with higher pids connect to workers with lower pids. Except process 1 (master) which
527+
# initiates connections to all workers.
528+
529+
# Connection Setup Protocol:
530+
# - Master sends 16-byte cookie followed by 16-byte version string and a JoinPGRP message to all workers
531+
# - On each worker
532+
# - Worker responds with a 16-byte version followed by a JoinCompleteMsg
533+
# - Connects to all workers less than its pid. Sends the cookie, version and an IdentifySocket message
534+
# - Workers with incoming connection requests write back their Version and an IdentifySocketAckMsg message
535+
# - On master, receiving a JoinCompleteMsg triggers rr_ntfy_join (signifies that worker setup is complete)
536+
537+
join_list = []
538+
if Distributed.PGRP.topology === :all_to_all
539+
# need to wait for lower worker pids to have completed connecting, since the numerical value
540+
# of pids is relevant to the connection process, i.e., higher pids connect to lower pids and they
541+
# require the value of config.connect_at which is set only upon connection completion
542+
for jw in Distributed.PGRP.workers
543+
if (jw.id != 1) && (jw.id < w.id)
544+
# wait for wl to join
545+
# We should access this atomically using (@atomic jw.state)
546+
# but this is only recently supported
547+
if jw.state === Distributed.W_CREATED
548+
lock(jw.c_state) do
549+
wait(jw.c_state)
550+
end
551+
end
552+
push!(join_list, jw)
553+
end
554+
end
555+
556+
elseif Distributed.PGRP.topology === :custom
557+
# wait for requested workers to be up before connecting to them.
558+
filterfunc(x) =
559+
(x.id != 1) &&
560+
isdefined(x, :config) &&
561+
(
562+
notnothing(x.config.ident) in
563+
something(wconfig.connect_idents, [])
564+
)
565+
566+
wlist = filter(filterfunc, Distributed.PGRP.workers)
567+
waittime = 0
568+
while wconfig.connect_idents !== nothing &&
569+
length(wlist) < length(wconfig.connect_idents)
570+
if waittime >= timeout
571+
error("peer workers did not connect within $timeout seconds")
572+
end
573+
sleep(1.0)
574+
waittime += 1
575+
wlist = filter(filterfunc, Distributed.PGRP.workers)
576+
end
577+
578+
for wl in wlist
579+
lock(wl.c_state) do
580+
if (@atomic wl.state) === Distributed.W_CREATED
581+
# wait for wl to join
582+
wait(wl.c_state)
583+
end
584+
end
585+
push!(join_list, wl)
586+
end
587+
end
588+
589+
all_locs = Base.mapany(
590+
x ->
591+
isa(x, Distributed.Worker) ?
592+
(something(x.config.connect_at, ()), x.id) : ((), x.id, true),
593+
join_list,
594+
)
595+
Distributed.send_connection_hdr(w, true)
596+
enable_threaded_blas = something(wconfig.enable_threaded_blas, false)
597+
598+
join_message = Distributed.JoinPGRPMsg(
599+
w.id,
600+
all_locs,
601+
Distributed.PGRP.topology,
602+
enable_threaded_blas,
603+
Distributed.isclusterlazy(),
604+
)
605+
Distributed.send_msg_now(
606+
w,
607+
Distributed.MsgHeader(Distributed.RRID(0, 0), ntfy_oid),
608+
join_message,
609+
)
610+
611+
# Ensure the initial expression is evaluated before any other code
612+
@info "Evaluating initial expression on worker $(w.id)"
613+
evaluate_initial_expression(w.id, manager.expr)
614+
615+
@async Distributed.manage(w.manager, w.id, w.config, :register)
616+
617+
# wait for rr_ntfy_join with timeout
618+
if timedwait(() -> isready(rr_ntfy_join), timeout) === :timed_out
619+
error("worker did not connect within $timeout seconds")
620+
end
621+
lock(Distributed.client_refs) do
622+
delete!(Distributed.PGRP.refs, ntfy_oid)
623+
end
624+
625+
return w.id
626+
end

test/hpc_backend.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,6 @@ if get_backend() == DerechoBackend
2424
hpc_kwargs[:gpus_per_task] = 1
2525
end
2626

27-
@testset "Restarts" begin
28-
initialize(ensemble_size, observation, variance, prior, output_dir)
29-
30-
last_iter = ClimaCalibrate.last_completed_iteration(output_dir)
31-
@test last_iter == -1
32-
ClimaCalibrate.run_worker_iteration(
33-
last_iter + 1,
34-
ensemble_size,
35-
output_dir,
36-
)
37-
G_ensemble = observation_map(last_iter + 1)
38-
save_G_ensemble(output_dir, last_iter + 1, G_ensemble)
39-
update_ensemble(output_dir, last_iter + 1, prior)
40-
41-
@test ClimaCalibrate.last_completed_iteration(output_dir) == 0
42-
end
43-
4427
eki = calibrate(experiment_config; model_interface, hpc_kwargs, verbose = true)
4528

4629
@test ClimaCalibrate.last_completed_iteration(output_dir) == n_iterations - 1

test/pbs_manager_unit_tests.jl

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Test, ClimaCalibrate, Distributed, Logging
2323
@test workers() == [1]
2424
end
2525

26-
@testset "PBSManager - multiple processes" begin
26+
@testset "Test PBSManager multiple tasks, output file" begin
2727
out_file = "pbs_unit_test.out"
2828
p = addprocs(
2929
PBSManager(2),
@@ -37,34 +37,6 @@ end
3737
@test workers() == p
3838
@test remotecall_fetch(+, p[1], 1, 1) == 2
3939

40-
@everywhere using ClimaCalibrate
41-
# Test function with no arguments
42-
p = workers()
43-
@test ClimaCalibrate.map_remotecall_fetch(myid) == p
44-
45-
# single argument
46-
x = rand(5)
47-
@test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
48-
49-
# multiple arguments
50-
@test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
51-
52-
# Test specified workers list
53-
@test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
54-
2
55-
56-
# Test with more complex data structure
57-
d = Dict("a" => 1, "b" => 2)
58-
@test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
59-
60-
loggers = ClimaCalibrate.set_worker_loggers()
61-
@test length(loggers) == length(p)
62-
@test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
63-
64-
rmprocs(p)
65-
@test nprocs() == 1
66-
@test workers() == [1]
67-
6840
@test isfile(out_file)
6941
rm(out_file)
7042
end

test/slurm_manager_unit_tests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,49 @@ using Test, ClimaCalibrate, Distributed, Logging
2727
# Test incorrect generic arguments
2828
@test_throws TaskFailedException p = addprocs(SlurmManager(1), time = "w")
2929
end
30+
31+
@testset "SlurmManager Initialization Expressions" begin
32+
p = addprocs(SlurmManager(1; expr = :(@info "test")))
33+
rmprocs(p)
34+
test_logger = TestLogger()
35+
with_logger(test_logger) do
36+
p = addprocs(SlurmManager(1; expr = :(w + 2)))
37+
rmprocs(p)
38+
end
39+
@test test_logger.logs[end].message == "Initial worker expression errored:"
40+
end
41+
42+
@testset "Test remotecall utilities" begin
43+
p = addprocs(SlurmManager(2))
44+
@test nprocs() == length(p) + 1
45+
@test workers() == p
46+
@test remotecall_fetch(+, p[1], 1, 1) == 2
47+
48+
@everywhere using ClimaCalibrate
49+
# Test function with no arguments
50+
p = workers()
51+
@test ClimaCalibrate.map_remotecall_fetch(myid) == p
52+
53+
# single argument
54+
x = rand(5)
55+
@test ClimaCalibrate.map_remotecall_fetch(identity, x) == fill(x, length(p))
56+
57+
# multiple arguments
58+
@test ClimaCalibrate.map_remotecall_fetch(+, 2, 3) == fill(5, length(p))
59+
60+
# Test specified workers list
61+
@test length(ClimaCalibrate.map_remotecall_fetch(myid; workers = p[1:2])) ==
62+
2
63+
64+
# Test with more complex data structure
65+
d = Dict("a" => 1, "b" => 2)
66+
@test ClimaCalibrate.map_remotecall_fetch(identity, d) == fill(d, length(p))
67+
68+
loggers = ClimaCalibrate.set_worker_loggers()
69+
@test length(loggers) == length(p)
70+
@test typeof(loggers) == Vector{Base.CoreLogging.SimpleLogger}
71+
72+
rmprocs(p)
73+
@test nprocs() == 1
74+
@test workers() == [1]
75+
end

test/worker_backend.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ include(
88
"utils.jl",
99
),
1010
)
11-
1211
# Expression to run on worker initialization, used instead of @everywhere
1312
expr = quote
13+
using ClimaCalibrate
1414
include(
1515
joinpath(
1616
pkgdir(ClimaCalibrate),
@@ -36,23 +36,6 @@ if nworkers() == 1
3636
end
3737

3838

39-
# @testset "Restarts" begin
40-
# initialize(ensemble_size, observation, variance, prior, output_dir)
41-
42-
# last_iter = ClimaCalibrate.last_completed_iteration(output_dir)
43-
# @test last_iter == -1
44-
# ClimaCalibrate.run_worker_iteration(
45-
# last_iter + 1,
46-
# ensemble_size,
47-
# output_dir,
48-
# )
49-
# G_ensemble = observation_map(last_iter + 1)
50-
# save_G_ensemble(output_dir, last_iter + 1, G_ensemble)
51-
# update_ensemble(output_dir, last_iter + 1, prior)
52-
53-
# @test ClimaCalibrate.last_completed_iteration(output_dir) == 0
54-
# end
55-
5639
eki = calibrate(
5740
WorkerBackend,
5841
ensemble_size,
@@ -78,3 +61,20 @@ convergence_plot(
7861
)
7962

8063
g_vs_iter_plot(eki)
64+
65+
@testset "Restarts" begin
66+
initialize(ensemble_size, observation, variance, prior, output_dir)
67+
68+
last_iter = ClimaCalibrate.last_completed_iteration(output_dir)
69+
@test last_iter == n_iterations - 1
70+
ClimaCalibrate.run_worker_iteration(
71+
last_iter + 1,
72+
ensemble_size,
73+
output_dir,
74+
)
75+
G_ensemble = observation_map(last_iter + 1)
76+
save_G_ensemble(output_dir, last_iter + 1, G_ensemble)
77+
update_ensemble(output_dir, last_iter + 1, prior)
78+
79+
@test ClimaCalibrate.last_completed_iteration(output_dir) == n_iterations
80+
end

0 commit comments

Comments (0)