Example for the radon model:

import pymc as pm
import numpy as np
import pandas as pd
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    # TODO should be a CenteredNormal
    raw = pm.Normal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    # Should also be a CenteredNormal
    raw = pm.Normal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )
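# Compile the logp/dlogp function and extract its execution graph
f = model.logp_dlogp_function()
toposort = f._aesara_function.vm.fgraph.toposort()
dependencies = f._aesara_function.vm.fgraph.orderings()
dependencies = {k: set(v) for k, v in dependencies.items()}

# Emulate n_workers workers; each row of `evaluations` is one scheduling step
n_workers = 8
evaluations = [[None] * n_workers]
try:
    # `scheduler` is the scheduler generator discussed in this thread
    # (it is not defined in this snippet)
    sched = scheduler(toposort, dependencies)
    while True:
        step_i = []
        pad = 0
        # Report each worker's previous result and ask for its next task
        for i in range(n_workers):
            task = sched.send(evaluations[-1][i])
            if task is not None:
                step_i.append(task)
            else:
                pad += 1
        # Workers that received no task ask once more
        for i in range(pad):
            step_i.append(sched.send(None))
        evaluations.append(step_i)
except StopIteration:
    # The scheduler generator is exhausted: no tasks remain
    pass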
@ferrine IIUC this is only relevant for the C backend. The JAX and Numba backends compile a single thunk/JIT graph that's evaluated as a monolith. JAX is also already multi-threaded internally. Since the goal is to deprecate the C backend, I am not sure this is useful work. Did I misinterpret your idea?
I have been thinking a lot about execution schedulers. The implementation itself is a few lines of code and relies on a toposort with priorities (e.g. compute intensity).

Given that we can get thunks and their dependency graph, we can apply this kind of scheduling with multiple workers in different threads. The workers report their results, and the scheduler decides on the next task. The `awaiting` list makes it possible to keep long-running jobs on some workers while eagerly fetching new tasks from the compute graph.

Another use case for this scheduler is to emulate workers and get batches of nodes that could be computed independently, to see if we can fuse them.
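To make the idea concrete, here is a minimal sketch of how such a scheduler could be written as a Python generator. It is not the implementation referred to in this thread: the function names `scheduler` and `run_in_batches`, the `priority` argument, and the exact `send` protocol are assumptions made for illustration. It also assumes that `dependencies[node]` holds the nodes that must finish before `node` can run. Plain strings stand in for graph nodes so the example runs without a compiled model.

```python
import heapq


def scheduler(toposort, dependencies, priority=None):
    """Hypothetical sketch: yield ready tasks, receive finished ones via send().

    Yields None when nothing is ready but some tasks are still awaiting results.
    """
    if priority is None:
        priority = lambda node: 0
    order = {node: i for i, node in enumerate(toposort)}
    known = set(order)
    # unmet[node]: prerequisites of node that have not been reported finished
    unmet = {n: set(dependencies.get(n, ())) & known for n in toposort}
    dependents = {n: [] for n in toposort}
    for node, deps in unmet.items():
        for dep in deps:
            dependents[dep].append(node)
    # Max-priority first; ties are broken by position in the toposort
    ready = [(-priority(n), order[n], n) for n in toposort if not unmet[n]]
    heapq.heapify(ready)
    awaiting = set()  # tasks handed out whose results are still pending
    while ready or awaiting:
        if ready:
            _, _, task = heapq.heappop(ready)
            awaiting.add(task)
        else:
            task = None  # the asking worker has to idle for now
        finished = yield task
        if finished is not None:
            awaiting.discard(finished)
            for node in dependents[finished]:
                unmet[node].discard(finished)
                if not unmet[node]:
                    heapq.heappush(ready, (-priority(node), order[node], node))


def run_in_batches(toposort, dependencies, n_workers, priority=None):
    """Emulate n_workers workers and return the tasks started in each round."""
    sched = scheduler(toposort, dependencies, priority)
    batches = []
    previous = [None] * n_workers
    try:
        while True:
            # Each worker reports its last task and asks for the next one
            current = [sched.send(done) for done in previous]
            batches.append([t for t in current if t is not None])
            previous = current
    except StopIteration:
        pass  # every task has been handed out and reported back
    return batches


# Toy graph: "c" needs "a" and "b"; "d" needs "c"
toy_toposort = ["a", "b", "c", "d"]
toy_dependencies = {"c": {"a", "b"}, "d": {"c"}}
print(run_in_batches(toy_toposort, toy_dependencies, n_workers=2))
```

Each inner list holds tasks that were started in the same round, so none of them depends on another; these are the candidates for fusion. With real threads, each worker would call `send` with the task it just completed instead of going through the emulated rounds, and access to the generator would need to be serialized with a lock.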