# utils.py
import torch
import numpy as np
import time
from argparse import ArgumentParser
from copy import deepcopy
import json
import random
import os
import sys
sys.path.append("../")
from causal_graphs.graph_utils import adj_matrix_to_edges
from causal_graphs.graph_visualization import visualize_graph
from causal_discovery.utils import get_device
from causal_discovery.enco import ENCO


def set_seed(seed):
"""
Sets the seed for all libraries used.
"""
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
    if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def get_basic_parser():
    """
    Returns an argument parser with the standard hyperparameter/experiment arguments.
    """
parser = ArgumentParser()
parser.add_argument('--num_epochs', type=int, default=30,
help='Number of epochs to run ENCO for.')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for the experiments.')
parser.add_argument('--cluster', action='store_true',
help='If True, no tqdm progress bars are used.')
parser.add_argument('--batch_size', type=int, default=128,
help='Batch size to use for distribution and graph fitting.')
parser.add_argument('--hidden_size', type=int, default=64,
help='Hidden size of the distribution fitting NNs.')
parser.add_argument('--use_flow_model', action='store_true',
help='If True, a Deep Sigmoidal Flow will be used as model if'
' the graph contains continuous data.')
parser.add_argument('--model_iters', type=int, default=1000,
help='Number of updates per distribution fitting stage.')
parser.add_argument('--graph_iters', type=int, default=100,
help='Number of updates per graph fitting stage.')
parser.add_argument('--lambda_sparse', type=float, default=0.004,
help='Sparsity regularizer in the graph fitting stage.')
parser.add_argument('--lr_model', type=float, default=5e-3,
help='Learning rate of distribution fitting NNs.')
parser.add_argument('--lr_gamma', type=float, default=2e-2,
help='Learning rate of gamma parameters in graph fitting.')
parser.add_argument('--lr_theta', type=float, default=1e-1,
help='Learning rate of theta parameters in graph fitting.')
parser.add_argument('--weight_decay', type=float, default=0.0,
help='Weight decay to use during distribution fitting.')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='Directory to save the experiment log to. If None, one will'
                             ' be created based on the current time.')
parser.add_argument('--GF_num_batches', type=int, default=1,
help='Number of batches to use in graph fitting gradient estimators.')
parser.add_argument('--GF_num_graphs', type=int, default=100,
help='Number of graph samples to use in the gradient estimators.')
parser.add_argument('--max_graph_stacking', type=int, default=200,
help='Number of graphs to evaluate in parallel. Reduce this to save memory.')
parser.add_argument('--use_theta_only_stage', action='store_true',
help='If True, gamma is frozen in every second graph fitting stage.'
' Recommended for large graphs with >=100 nodes.')
parser.add_argument('--theta_only_num_graphs', type=int, default=4,
help='Number of graph samples to use when gamma is frozen.')
parser.add_argument('--theta_only_iters', type=int, default=1000,
help='Number of updates per graph fitting stage when gamma is frozen.')
parser.add_argument('--save_model', action='store_true',
help='If True, the neural networks will be saved besides gamma and theta.')
parser.add_argument('--stop_early', action='store_true',
help='If True, ENCO stops running if it achieved perfect reconstruction in'
' all of the last 5 epochs.')
parser.add_argument('--sample_size_obs', type=int, default=5000,
help='Dataset size to use for observational data. If an exported graph is'
' given as input and sample_size_obs is smaller than the exported'
' observational dataset, the first sample_size_obs samples will be taken.')
parser.add_argument('--sample_size_inters', type=int, default=200,
help='Number of samples to use per intervention. If an exported graph is'
' given as input and sample_size_inters is smaller than the exported'
' interventional dataset, the first sample_size_inters samples will be taken.')
parser.add_argument('--max_inters', type=int, default=-1,
help='Number of variables to provide interventional data for. If smaller'
' than zero, interventions on all variables will be used.')
return parser
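

# Example usage (sketch, not part of the original module): an experiment script
# would typically extend this parser with its own graph/dataset arguments before
# parsing, along the lines of
#
#     parser = get_basic_parser()
#     parser.add_argument('--graph_type', type=str, default='random')  # hypothetical extra argument
#     args = parser.parse_args()
#     set_seed(args.seed)
#
# The '--graph_type' argument above is only illustrative and is not defined in this module.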


def test_graph(graph, args, checkpoint_dir, file_id):
    """
    Runs ENCO on a given graph for causal structure learning.

    Parameters
    ----------
    graph : CausalDAG
        The graph on which we want to perform causal structure learning.
    args : Namespace
        Parsed input arguments from the argument parser, including all
        hyperparameters.
    checkpoint_dir : str
        Directory to which all logs and the model should be saved.
    file_id : str
        Identifier of the graph/experiment instance. It is used for creating
        log filenames and for identifying the graph among other experiments in
        the same checkpoint directory.
    """
# Determine variables to exclude from the intervention set
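    # When no exclusion list is given, max_inters variables are randomly chosen to keep
    # their interventional data and the remaining ones are excluded below (e.g. a graph
    # with 10 variables and max_inters=3 ends up with 7 variables in exclude_inters).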
if args.max_inters < 0:
graph.exclude_inters = None
elif graph.exclude_inters is not None:
graph.exclude_inters = graph.exclude_inters[:-args.max_inters]
else:
exclude_inters = list(range(graph.num_vars))
random.seed(args.seed)
random.shuffle(exclude_inters)
exclude_inters = exclude_inters[:-args.max_inters]
graph.exclude_inters = exclude_inters
# Execute ENCO on graph
discovery_module = ENCO(graph=graph,
hidden_dims=[args.hidden_size],
use_flow_model=args.use_flow_model,
lr_model=args.lr_model,
weight_decay=args.weight_decay,
lr_gamma=args.lr_gamma,
lr_theta=args.lr_theta,
model_iters=args.model_iters,
graph_iters=args.graph_iters,
batch_size=args.batch_size,
GF_num_batches=args.GF_num_batches,
GF_num_graphs=args.GF_num_graphs,
lambda_sparse=args.lambda_sparse,
use_theta_only_stage=args.use_theta_only_stage,
theta_only_num_graphs=args.theta_only_num_graphs,
theta_only_iters=args.theta_only_iters,
max_graph_stacking=args.max_graph_stacking,
sample_size_obs=args.sample_size_obs,
sample_size_inters=args.sample_size_inters
)
discovery_module.to(get_device())
start_time = time.time()
discovery_module.discover_graph(num_epochs=args.num_epochs,
stop_early=args.stop_early)
duration = int(time.time() - start_time)
print("-> Finished training in %ih %imin %is" % (duration // 3600, (duration // 60) % 60, duration % 60))
# Save metrics in checkpoint folder
metrics = discovery_module.get_metrics()
with open(os.path.join(checkpoint_dir, "metrics_%s.json" % file_id), "w") as f:
json.dump(metrics, f, indent=4)
print('-'*50 + '\nFinal metrics:')
discovery_module.print_graph_statistics(m=metrics)
if graph.num_vars < 100:
metrics_acyclic = discovery_module.get_metrics(enforce_acyclic_graph=True)
with open(os.path.join(checkpoint_dir, "metrics_acyclic_%s.json" % file_id), "w") as f:
json.dump(metrics_acyclic, f, indent=4)
print('-'*50 + '\nFinal metrics (acyclic):')
discovery_module.print_graph_statistics(m=metrics_acyclic)
with open(os.path.join(checkpoint_dir, "metrics_full_log_%s.json" % file_id), "w") as f:
json.dump(discovery_module.metric_log, f, indent=4)
# Save predicted binary matrix
binary_matrix = discovery_module.get_binary_adjmatrix().detach().cpu().numpy()
    np.save(os.path.join(checkpoint_dir, 'binary_matrix_%s.npy' % file_id),
            binary_matrix.astype(bool))  # np.bool is deprecated in recent NumPy; use builtin bool
    if graph.num_vars < 100:
        acyclic_matrix = discovery_module.get_acyclic_adjmatrix().detach().cpu().numpy()
        np.save(os.path.join(checkpoint_dir, 'binary_acyclic_matrix_%s.npy' % file_id),
                acyclic_matrix.astype(bool))
    # Visualize the predicted graph. For large graphs, a visualization does not really help
if graph.num_vars < 40:
pred_graph = deepcopy(graph)
pred_graph.adj_matrix = binary_matrix
pred_graph.edges = adj_matrix_to_edges(pred_graph.adj_matrix)
figsize = max(3, pred_graph.num_vars / 1.5)
visualize_graph(pred_graph,
filename=os.path.join(checkpoint_dir, "graph_%s_prediction.pdf" % (file_id)),
figsize=(figsize, figsize),
layout="circular")
# Save parameters and model if wanted
state_dict = discovery_module.get_state_dict()
if not args.save_model: # The model can be expensive in memory
_ = state_dict.pop("model")
torch.save(state_dict,
os.path.join(checkpoint_dir, "state_dict_%s.tar" % file_id))