This README covers a running example of training an AVICI model for causal discovery on a custom data-generating process. The components we provide here can be used as a starting point for a new project. As an illustrative example, we implement AVICI trained on SCMs with sinusoidal functions and random tree graphs.
This folder contains the following three files:
The first file, `func.py`, contains all custom classes that make up the generative model of our domain. All functions should be written using standard `numpy` and not `jax.numpy`. This is both faster and avoids conflicting resource usage between the CPU workers that continually update the training data buffers and the hardware accelerators used by `jax` for the actual network training.
The provided `func.py` implements two example classes, for sampling random trees and SCMs with sinusoidal functions, respectively, which are not already implemented in `avici.synthetic`.
Each custom data-generating process must subclass one of the following two abstract base classes and implement the `__call__` function with the correct signature:
Subclasses of `GraphModel` implement functionality for sampling training graphs and can be used to define (part of) the causal graph distribution p(G). Each child class has to implement `__call__` accepting two arguments:

- `rng` (`np.random.Generator`) – numpy pseudorandom number generator
- `n_vars` (`int`) – number of nodes in the graph

Returns:

- `ndarray` – binary adjacency matrix of shape `[n_vars, n_vars]`
Example:

```python
import numpy as onp
from avici.synthetic import GraphModel

class DummyGraph(GraphModel):
    def __call__(self, rng, n_vars):
        # return an empty graph (all-zeros binary adjacency matrix)
        return onp.zeros((n_vars, n_vars))
```
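The dummy above always returns an empty graph. A slightly less trivial sketch (the class name and edge probability are our own illustrative choices, not part of `avici.synthetic`) samples a random DAG by drawing edges below the diagonal and permuting the node order:

```python
import numpy as onp
from avici.synthetic import GraphModel

class RandomLowerTriangularGraph(GraphModel):
    """Hypothetical example: random DAG via a permuted, strictly lower-triangular edge pattern."""

    def __call__(self, rng, n_vars):
        # sample edges only below the diagonal, which guarantees acyclicity
        mat = onp.tril(rng.binomial(1, 0.3, size=(n_vars, n_vars)), k=-1)
        # randomly permute the node labels so the topological order is not fixed
        perm = rng.permutation(n_vars)
        return mat[onp.ix_(perm, perm)]
```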
Subclasses of `MechanismModel` implement functionality for sampling observational and interventional data given a causal graph. These classes can be used to define (part of) the data-generating distribution p(D | G). Each child class has to implement `__call__` accepting four arguments:

- `rng` (`np.random.Generator`) – numpy pseudorandom number generator
- `g` (`ndarray`) – binary adjacency matrix of shape `[n_vars, n_vars]` as generated by a `GraphModel` subclass
- `n_observations_obs` (`int`) – number of observational data points to be sampled
- `n_observations_int` (`int`) – number of interventional data points to be sampled
Returns:

- `avici.synthetic.Data` – namedtuple containing `x_obs`, `x_int`, and the boolean `is_count_data`. The data matrices `x_obs` and `x_int` must have shapes `[n_observations_obs, n_vars, 2]` and `[n_observations_int, n_vars, 2]`, respectively. The first value in the last axis (i.e., `x_int[..., 0]`) contains the values, and the second value in the last axis (i.e., `x_int[..., 1]`) contains either 0 or 1, indicating which nodes were intervened upon in which observations. Accordingly, `x_obs[..., 1]` contains only zeros, as it always holds observational data. `is_count_data` is used to determine how the data is standardized. (The default for all real-valued data should be `False`, which implies the usual z-standardization.)
Example:

```python
import numpy as onp
from avici.synthetic import MechanismModel, Data

class DummyMechanism(MechanismModel):
    def __call__(self, rng, g, n_observations_obs, n_observations_int):
        n_vars = g.shape[-1]
        # return all-zero data; x_*[..., 0] holds values, x_*[..., 1] the intervention indicators
        return Data(
            x_obs=onp.zeros((n_observations_obs, n_vars, 2)),
            x_int=onp.zeros((n_observations_int, n_vars, 2)),
            is_count_data=False,
        )
```
Both `GraphModel` and `MechanismModel` subclasses can be initialized with and store an arbitrary number of arguments for later use inside `__call__`, such as function parameters or other sampling functions. For `MechanismModel`, this is also where additional details on the interventions ought to be specified, e.g., how many nodes are intervened upon and in what fashion.
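As an illustration, the hypothetical sketch below stores a noise scale and the number of intervened nodes in `__init__` and uses them in `__call__`, also showing how the intervention indicators in `x_int[..., 1]` are set. The class name and arguments are our own illustrative choices, not part of `avici.synthetic`:

```python
import numpy as onp
from avici.synthetic import MechanismModel, Data

class IndependentGaussianMechanism(MechanismModel):
    """Hypothetical example: independent Gaussian nodes with random node interventions.

    Note: this toy sketch ignores the graph g; a real mechanism would propagate
    values through g, e.g., in topological order.
    """

    def __init__(self, noise_scale=1.0, n_interv_vars=1):
        # arguments passed at initialization are stored for later use in __call__
        self.noise_scale = noise_scale
        self.n_interv_vars = n_interv_vars

    def __call__(self, rng, g, n_observations_obs, n_observations_int):
        n_vars = g.shape[-1]

        # observational data: values in [..., 0]; intervention indicators in [..., 1] stay zero
        x_obs = onp.zeros((n_observations_obs, n_vars, 2))
        x_obs[..., 0] = rng.normal(0.0, self.noise_scale, size=(n_observations_obs, n_vars))

        # interventional data: clamp randomly chosen nodes to zero and flag them in [..., 1]
        x_int = onp.zeros((n_observations_int, n_vars, 2))
        x_int[..., 0] = rng.normal(0.0, self.noise_scale, size=(n_observations_int, n_vars))
        for i in range(n_observations_int):
            targets = rng.choice(n_vars, size=min(self.n_interv_vars, n_vars), replace=False)
            x_int[i, targets, 0] = 0.0  # intervened (clamped) value
            x_int[i, targets, 1] = 1.0  # intervention indicator

        return Data(x_obs=x_obs, x_int=x_int, is_count_data=False)
```

In the configuration file, such arguments (here `noise_scale` and `n_interv_vars`) would be specified alongside the `__class__` key, just like the arguments of the built-in classes.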
The second file, `domain.yaml`, is the configuration file that defines the distribution over datasets our structure learning model is trained on. The file has to be structured in the following way:
```yaml
---
train_n_vars: [5, 10]
test_n_vars: [20]
test_n_datasets: 10
additional_modules:
  - "./func.py"
data:
  - n_observations_obs: 300
    n_observations_int: 100
    graph:
      - __class__: ErdosRenyi
        edges_per_var: [ 1.0, 2.0, 3.0 ]
    mechanism:
      - __class__: LinearAdditive
        param:
          - __class__: SignedUniform
            low: 1.0
            high: 3.0
        bias: ...
        noise: ...
        noise_scale: ...
        n_interv_vars: ...
        interv_dist: ...
  - ...
```
The top-level keywords specify the following:

- `train_n_vars` – list of integers specifying the numbers of variables in the causal graphs and datasets during training
- `test_n_vars` – list of integers specifying the numbers of variables used for validation
- `test_n_datasets` – number of validation datasets
- `additional_modules` – list of paths (relative or absolute) to files defining additional data-generating processes (e.g., our `func.py` file)
- `data` – nested combination of dicts and lists specifying the full data-generating distribution
The `data` entry specifies the distribution over training datasets. During training, we continually generate fresh data for the data buffers of the different numbers of variables according to this distribution. The configuration of the `data` field must satisfy the following invariants:
- If any (nested) part of the `data` tree is a list, one configuration of it is selected uniformly at random in each new sample. For example, in the above configuration, all graphs are Erdos-Renyi, in which the expected number of edges per node is either 1, 2, or 3, selected randomly for each new dataset. Internally, the nested dict of lists is expanded into a single list of all possible combinations of dicts, so be careful not to specify too many combinations (>1000).

- Each (list) element in the top level of `data` needs to specify `graph`, `mechanism`, `n_observations_obs`, and `n_observations_int` (satisfying the `avici.synthetic.SyntheticSpec` signature). The integers `n_observations_obs` and `n_observations_int` specify the number of data points generated for each dataset. At training time, these observations are subbatched further depending on the optimization parameters.

- Each (list) element in the top level of `data.graph` needs to define a `GraphModel` subclass, and each (list) element of `data.mechanism` a `MechanismModel` subclass. The class name is specified via the `__class__` key. All other arguments the class expects at initialization time (via `__init__`) are specified alongside. Please refer to the signature of `avici.synthetic.LinearAdditive` to verify this in the above example. The class arguments may be (lists of) classes themselves, defined recursively in the same way. For example, `avici.synthetic.Distribution` subclasses specify how the weights and noise of the linear SCM `LinearAdditive` are sampled. Likewise, `avici.synthetic.NoiseModel` subclasses specify the noise scale in the SCM.

- All classes not available inside `avici.synthetic` need to be defined in other files and specified via their path in the `additional_modules` field. When specified this way, they can be used in the configuration exactly like all other members already provided in `avici.synthetic` (see the snippet below).
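For instance, assuming `func.py` defines the two classes described above under the hypothetical names `RandomTreeGraph` and `SinusoidalMechanism` (the actual names depend on your implementation), they could be referenced exactly like the built-in classes:

```yaml
additional_modules:
  - "./func.py"
data:
  - n_observations_obs: 300
    n_observations_int: 100
    graph:
      - __class__: RandomTreeGraph      # hypothetical custom class from ./func.py
    mechanism:
      - __class__: SinusoidalMechanism  # hypothetical custom class from ./func.py
```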
The easiest way of understanding how `domain.yaml` is configured is to look at a few examples. The following configurations define the training distributions of the models trained in Lorch et al. (2022), whose checkpoints are available for download via `avici.load_pretrained`: `linear.yaml`, `rff.yaml`, `gene.yaml`. These config files directly correspond to the tables given in Appendix A of the paper.
Currently, we provide the following data-generating processes in `avici.synthetic`:

- `GraphModel` subclasses:
  - `ErdosRenyi`
  - `ScaleFree`
  - `ScaleFreeTranspose`
  - `WattsStrogatz`
  - `SBM`
  - `GRG`
  - `Yeast`
  - `Ecoli`
- `MechanismModel` subclasses:
  - `LinearAdditive`
  - `RFFAdditive`
  - `GRNSergio`
- `Distribution` subclasses:
  - `Gaussian`
  - `Laplace`
  - `Cauchy`
  - `Uniform`
  - `SignedUniform`
  - `RandInt`
  - `Beta`
- `NoiseModel` subclasses:
  - `SimpleNoise`
  - `HeteroscedasticRFFNoise`

These classes can be used in a `domain.yaml` configuration out-of-the-box and without further specifications.
The third file, `train.py`, is the main training script. Our provided script automatically performs multi-device training. Hence, if you run this script on a machine with multiple GPUs, all accelerators will be used directly via `jax.pmap` and corresponding functions.
Given our above domain configuration, we can train a first (small) model to check the script by changing directory to `example-custom/` and running

```bash
python train.py --config "./domain.yaml"
```

where `--config` specifies an (absolute or relative) path to our YAML domain configuration.
The above call uses `--smoke_test true` by default, which sets the network and training parameters to small dummy values for testing. For further information about the other command line arguments, run `python train.py --help`.
To train a large model with the same hyperparameters as Lorch et al. (2022), run

```bash
python train.py --config "./domain.yaml" --smoke_test false
```
Each different `n_vars` in the training data distribution requires a separate `jax.jit` compilation. Therefore, it is normal that the first few steps of training take relatively long. After each `n_vars` has been seen once, update steps are fast (~0.5 sec/step for the full model on Quadro RTX 6000 GPUs).
The script automatically generates checkpoints, which can be used both for continuing training and for downstream predictions. By default, the checkpoints are stored in `./checkpoints/`. To restart training from a checkpoint, simply re-run `train.py` with the same checkpoint directory, and the code will automatically detect the most recent checkpoint.
Analogous to the pretrained checkpoints we provide for automatic download, the checkpoints created during training with this script can be loaded using the `avici.load_pretrained` function:

```python
import avici

model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)
```
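Once loaded, the model can be used for prediction just like the pretrained checkpoints. A minimal sketch, assuming the usual AVICI prediction interface and a placeholder data matrix `x` of shape `[n_observations, n_vars]`:

```python
import numpy as onp
import avici

model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)

# x: [n, d] data matrix with one row per observation and one column per variable (placeholder data)
x = onp.random.randn(100, 5)

# g_prob: [d, d] matrix of predicted edge probabilities of the causal graph
g_prob = model(x=x)
```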