# Copyright (C) 2013 Matthew C. Zwier and Lillian T. Chong
#
# This file is part of WESTPA.
#
# WESTPA is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# WESTPA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with WESTPA. If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function, division; __metaclass__ = type
import logging

import numpy as np
import scipy.sparse as sp
from scipy.sparse import csgraph
import h5py
from collections import Counter

import westpa
from west.data_manager import weight_dtype, n_iter_dtype
from westtools import (WESTTool, WESTDataReader, IterRangeSelection,
                       ProgressIndicatorComponent)
from westpa import h5io
from westtools.dtypes import iter_block_ci_dtype as ci_dtype

log = logging.getLogger('westtools.w_postanalysis_reweight')

def normalize(m):
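    '''Row-normalize the matrix ``m`` so that each nonzero row sums to one.

    Rows whose sum is zero (bins with no recorded flux) are left as all-zero
    rows rather than producing NaNs.
    '''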
    nm = m.copy()
    row_sum = m.sum(1)
    ii = np.nonzero(row_sum)[0]
    nm[ii, :] = m[ii, :] / row_sum[ii][:, np.newaxis]

    return nm

def steadystate_solve(K):
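    '''Estimate steady-state bin probabilities from the transition matrix ``K``.

    ``K`` is restricted to its largest strongly connected component (removing
    source/sink states), re-normalized, and the eigenvector of the transpose
    corresponding to the largest real eigenvalue is taken as the stationary
    distribution, provided that eigenvalue is close to 1. Otherwise a warning
    is printed and the normalized diagonal of ``K`` is returned as a fallback.
    '''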
    # Reformulate K to remove sink/source states
    n_components, component_assignments = csgraph.connected_components(K, connection="strong")
    largest_component = Counter(component_assignments).most_common(1)[0][0]
    components = np.where(component_assignments == largest_component)[0]

    ii = np.ix_(components, components)
    K_mod = K[ii]
    K_mod = normalize(K_mod)

    eigvals, eigvecs = np.linalg.eig(K_mod.T)
    eigvals = np.real(eigvals)
    eigvecs = np.real(eigvecs)

    maxi = np.argmax(eigvals)

    if not np.allclose(np.abs(eigvals[maxi]), 1.0):
        print('WARNING: Steady-state undetermined for current iteration')
        bin_prob = K.diagonal().copy()
        bin_prob = bin_prob / np.sum(bin_prob)
        return bin_prob

    sub_bin_prob = eigvecs[:, maxi] / np.sum(eigvecs[:, maxi])

    bin_prob = np.zeros(K.shape[0])
    bin_prob[components] = sub_bin_prob

    return bin_prob

def accumulate_statistics(h5file, start_iter, stop_iter, nbins, total_fluxes=None, total_obs=None):
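    '''Sum the per-iteration sparse flux and transition-count records stored in
    ``h5file`` (as produced by w_postanalysis_matrix) over iterations
    [start_iter, stop_iter) into dense ``nbins x nbins`` arrays. If
    ``total_fluxes`` and ``total_obs`` are supplied they are updated in place,
    allowing streaming accumulation across successive calls. The summed bin
    populations over the same range are also returned.
    '''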
    if total_fluxes is None:
        assert total_obs is None
        total_fluxes = np.zeros((nbins, nbins), weight_dtype)
        total_obs = np.zeros((nbins, nbins), np.int64)

    rows = []
    cols = []
    obs = []
    flux = []
    for iiter in xrange(start_iter, stop_iter):
        iter_grp = h5file['iterations']['iter_{:08d}'.format(iiter)]
        rows.append(iter_grp['rows'][...])
        cols.append(iter_grp['cols'][...])
        obs.append(iter_grp['obs'][...])
        flux.append(iter_grp['flux'][...])

    rows, cols, obs, flux = map(np.hstack, [rows, cols, obs, flux])

    total_fluxes += sp.coo_matrix((flux, (rows, cols)), shape=(nbins, nbins)).todense()
    total_obs += sp.coo_matrix((obs, (rows, cols)), shape=(nbins, nbins)).todense()

    total_pop = np.sum(h5file['bin_populations'][start_iter:stop_iter, :], axis=0)

    return total_fluxes, total_obs, total_pop

def reweight(h5file, start, stop, nstates, nbins, state_labels, state_map, nfbins, obs_threshold=1, total_fluxes=None, total_obs=None):
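    '''Perform a single reweighting estimate over iterations [start, stop).

    Accumulated fluxes for transitions observed fewer than ``obs_threshold``
    times are zeroed out; the remaining flux matrix is row-normalized into a
    transition matrix whose steady-state solution gives reweighted bin
    probabilities. These are aggregated by last-visited state (color) and by
    current state, and used to compute state-to-state fluxes.
    '''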
    total_fluxes, total_obs, total_pop = accumulate_statistics(h5file, start, stop, nfbins, total_fluxes, total_obs)

    flux_matrix = total_fluxes.copy()
    flux_matrix[total_obs < obs_threshold] = 0.0
    transition_matrix = normalize(flux_matrix)

    rw_bin_probs = steadystate_solve(transition_matrix)

    bin_last_state_map = np.tile(np.arange(nstates, dtype=np.int), nbins)
    bin_state_map = np.repeat(state_map[:-1], nstates)

    rw_color_probs = np.bincount(bin_last_state_map, weights=rw_bin_probs)
    rw_state_probs = np.bincount(bin_state_map, weights=rw_bin_probs)

    rw_bin_transition_matrix = transition_matrix

    ii = np.nonzero(transition_matrix)

    rw_state_flux = calc_state_flux(rw_bin_transition_matrix[ii], ii[0], ii[1], rw_bin_probs,
                                    bin_last_state_map, bin_state_map, nstates)

    return rw_state_flux, rw_color_probs, rw_state_probs, rw_bin_probs, rw_bin_transition_matrix

def calc_state_flux(trans_matrix, index1, index2, bin_probs, bin_last_state_map, bin_state_map, nstates):
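    '''Aggregate nonzero bin-to-bin transition probabilities into state-to-state
    fluxes: each transition contributes its probability times the steady-state
    probability of its source bin, credited to (last state of the source bin,
    state of the destination bin). Transitions into bins not assigned to any
    state (state index equal to ``nstates``) are skipped.
    '''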
    state_flux = np.zeros((nstates, nstates), np.float64)

    n_trans = index1.shape[0]
    for k in xrange(n_trans):
        ii = bin_last_state_map[index1[k]]
        jj = bin_state_map[index2[k]]

        if jj != nstates:
            state_flux[ii, jj] += trans_matrix[k] * bin_probs[index1[k]]

    return state_flux

class WPostAnalysisReweightTool(WESTTool):
    prog = 'w_postanalysis_reweight'
    description = '''\
Calculate average rates from weighted ensemble data using the postanalysis
reweighting scheme. Bin assignments (usually "assignments.h5") and pre-calculated
iteration flux matrices (usually "flux_matrices.h5") must have been generated
previously using w_postanalysis_matrix.py (see "w_assign --help" and
"w_kinetics --help" for information on generating these files).

-----------------------------------------------------------------------------
Output format
-----------------------------------------------------------------------------

The output file (-o/--output, usually "kinrw.h5") contains the following
datasets:

  /state_prob_evolution [window,state]
    The reweighted state populations based on windows

  /color_prob_evolution [window,state]
    The reweighted populations last assigned to each state based on windows

  /bin_prob_evolution [window,bin]
    The reweighted populations of each bin based on windows. Bins contain
    one color each, so to recover the original un-colored spatial bins,
    one must sum over all states.

  /conditional_flux_evolution [window,state,state]
    (Structured -- see below). State-to-state fluxes based on windows of
    varying width.

The structure of the final dataset is as follows:

  iter_start
    (Integer) Iteration at which the averaging window begins (inclusive).

  iter_stop
    (Integer) Iteration at which the averaging window ends (exclusive).

  expected
    (Floating-point) Expected (mean) value of the rate as evaluated within
    this window, in units of inverse tau.

-----------------------------------------------------------------------------
Command-line options
-----------------------------------------------------------------------------
'''
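    # Illustrative sketch (not part of the tool): once this tool has been run, the
    # datasets described above can be read back with h5py, for example:
    #
    #   import h5py
    #   with h5py.File('kinrw.h5', 'r') as f:
    #       cfe = f['conditional_flux_evolution'][...]   # structured array, shape (window, state, state)
    #       print(cfe['expected'][-1, 0, 1])             # state 0 -> 1 rate estimate (1/tau), last window
    #       print(f['state_prob_evolution'][-1])         # reweighted state populations, last window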
    def __init__(self):
        super(WPostAnalysisReweightTool, self).__init__()

        self.data_reader = WESTDataReader()
        self.iter_range = IterRangeSelection()
        self.progress = ProgressIndicatorComponent()

        self.output_filename = None
        self.kinetics_filename = None
        self.assignments_filename = None

        self.output_file = None
        self.assignments_file = None
        self.kinetics_file = None

        self.evolution_mode = None

    def add_args(self, parser):
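        '''Register command-line arguments: progress/data-reader/iteration-range
        components plus the input/output file names and the calculation options
        (--evolution-mode, --window-frac, --obs-threshold).'''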
        self.progress.add_args(parser)
        self.data_reader.add_args(parser)
        self.iter_range.include_args['iter_step'] = True
        self.iter_range.add_args(parser)

        iogroup = parser.add_argument_group('input/output options')
        iogroup.add_argument('-a', '--assignments', default='assign.h5',
                             help='''Bin assignments and macrostate definitions are in ASSIGNMENTS
                             (default: %(default)s).''')
        iogroup.add_argument('-k', '--kinetics', default='flux_matrices.h5',
                             help='''Per-iteration flux matrices calculated by w_postanalysis_matrix
                             (default: %(default)s).''')
        iogroup.add_argument('-o', '--output', dest='output', default='kinrw.h5',
                             help='''Store results in OUTPUT (default: %(default)s).''')

        cogroup = parser.add_argument_group('calculation options')
        cogroup.add_argument('-e', '--evolution-mode', choices=['cumulative', 'blocked'], default='cumulative',
                             help='''How to calculate time evolution of rate estimates.
                             ``cumulative`` evaluates rates over windows starting with --start-iter and getting
                             progressively wider to --stop-iter by steps of --step-iter.
                             ``blocked`` evaluates rates over windows of width --step-iter, the first of which
                             begins at --start-iter.''')
        cogroup.add_argument('--window-frac', type=float, default=1.0,
                             help='''Fraction of iterations to use in each window when running in ``cumulative`` mode.
                             The (1 - frac) fraction of iterations will be discarded from the start of each window.''')
        cogroup.add_argument('--obs-threshold', type=int, default=1,
                             help='''The minimum number of observed transitions between two states i and j necessary
                             to include fluxes in the reweighting estimate.''')

    def open_files(self):
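        '''Open the output, assignments, and kinetics HDF5 files, and verify that
        the assignments and kinetics data cover the requested iteration range.'''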
        self.output_file = h5io.WESTPAH5File(self.output_filename, 'w', creating_program=True)
        h5io.stamp_creator_data(self.output_file)
        self.assignments_file = h5io.WESTPAH5File(self.assignments_filename, 'r')  # , driver='core', backing_store=False)
        self.kinetics_file = h5io.WESTPAH5File(self.kinetics_filename, 'r')  # , driver='core', backing_store=False)

        if not self.iter_range.check_data_iter_range_least(self.assignments_file):
            raise ValueError('assignments data do not span the requested iterations')

        if not self.iter_range.check_data_iter_range_least(self.kinetics_file):
            raise ValueError('kinetics data do not span the requested iterations')

    def process_args(self, args):
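        '''Store parsed arguments on the tool instance and validate --window-frac;
        if no --step-iter was given, default to roughly ten equal blocks.'''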
        self.progress.process_args(args)
        self.data_reader.process_args(args)
        with self.data_reader:
            self.iter_range.process_args(args, default_iter_step=None)
        if self.iter_range.iter_step is None:
            # use about 10 blocks by default
            self.iter_range.iter_step = max(1, (self.iter_range.iter_stop - self.iter_range.iter_start) // 10)

        self.output_filename = args.output
        self.assignments_filename = args.assignments
        self.kinetics_filename = args.kinetics

        self.evolution_mode = args.evolution_mode
        self.evol_window_frac = args.window_frac
        if self.evol_window_frac <= 0 or self.evol_window_frac > 1:
            raise ValueError('Parameter error -- fractional window defined by --window-frac must be in (0,1]')
        self.obs_threshold = args.obs_threshold

    def go(self):
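        '''Compute the time evolution of reweighted rate estimates.

        For each averaging window, call ``reweight`` and record the state-to-state
        fluxes (scaled to per-tau rates) together with the reweighted color, state,
        and bin probabilities. In cumulative mode with a window fraction of 1.0,
        flux and observation totals are accumulated in a streaming fashion rather
        than being recomputed from scratch for every window.
        '''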
        pi = self.progress.indicator
        with pi:
            pi.new_operation('Initializing')
            self.open_files()

            nstates = self.assignments_file.attrs['nstates']
            nbins = self.assignments_file.attrs['nbins']
            state_labels = self.assignments_file['state_labels'][...]
            state_map = self.assignments_file['state_map'][...]
            nfbins = self.kinetics_file.attrs['nrows']
            npts = self.kinetics_file.attrs['npts']

            assert nstates == len(state_labels)
            assert nfbins == nbins * nstates

            start_iter, stop_iter, step_iter = self.iter_range.iter_start, self.iter_range.iter_stop, self.iter_range.iter_step

            start_pts = range(start_iter, stop_iter, step_iter)
            flux_evol = np.zeros((len(start_pts), nstates, nstates), dtype=ci_dtype)
            color_prob_evol = np.zeros((len(start_pts), nstates))
            state_prob_evol = np.zeros((len(start_pts), nstates))
            bin_prob_evol = np.zeros((len(start_pts), nfbins))

            pi.new_operation('Calculating flux evolution', len(start_pts))

            if self.evolution_mode == 'cumulative' and self.evol_window_frac == 1.0:
                print('Using fast streaming accumulation')

                total_fluxes = np.zeros((nfbins, nfbins), weight_dtype)
                total_obs = np.zeros((nfbins, nfbins), np.int64)

                for iblock, start in enumerate(start_pts):
                    pi.progress += 1
                    stop = min(start + step_iter, stop_iter)

                    params = dict(start=start, stop=stop, nstates=nstates, nbins=nbins,
                                  state_labels=state_labels, state_map=state_map, nfbins=nfbins,
                                  total_fluxes=total_fluxes, total_obs=total_obs,
                                  h5file=self.kinetics_file, obs_threshold=self.obs_threshold)

                    rw_state_flux, rw_color_probs, rw_state_probs, rw_bin_probs, rw_bin_flux = reweight(**params)

                    for k in xrange(nstates):
                        for j in xrange(nstates):
                            # Normalize such that we report the flux per tau (tau being the weighted ensemble iteration)
                            # npts always includes a 0th time point
                            flux_evol[iblock]['expected'][k, j] = rw_state_flux[k, j] * (npts - 1)
                            flux_evol[iblock]['iter_start'][k, j] = start
                            flux_evol[iblock]['iter_stop'][k, j] = stop

                    color_prob_evol[iblock] = rw_color_probs
                    state_prob_evol[iblock] = rw_state_probs[:-1]
                    bin_prob_evol[iblock] = rw_bin_probs

            else:
                for iblock, start in enumerate(start_pts):
                    pi.progress += 1
                    stop = min(start + step_iter, stop_iter)

                    if self.evolution_mode == 'cumulative':
                        windowsize = max(1, int(self.evol_window_frac * (stop - start_iter)))
                        block_start = max(start_iter, stop - windowsize)
                    else:   # self.evolution_mode == 'blocked'
                        block_start = start

                    params = dict(start=block_start, stop=stop, nstates=nstates, nbins=nbins,
                                  state_labels=state_labels, state_map=state_map, nfbins=nfbins,
                                  total_fluxes=None, total_obs=None,
                                  h5file=self.kinetics_file, obs_threshold=self.obs_threshold)

                    rw_state_flux, rw_color_probs, rw_state_probs, rw_bin_probs, rw_bin_flux = reweight(**params)

                    for k in xrange(nstates):
                        for j in xrange(nstates):
                            # Normalize such that we report the flux per tau (tau being the weighted ensemble iteration)
                            # npts always includes a 0th time point
                            flux_evol[iblock]['expected'][k, j] = rw_state_flux[k, j] * (npts - 1)
                            flux_evol[iblock]['iter_start'][k, j] = start
                            flux_evol[iblock]['iter_stop'][k, j] = stop

                    color_prob_evol[iblock] = rw_color_probs
                    state_prob_evol[iblock] = rw_state_probs[:-1]
                    bin_prob_evol[iblock] = rw_bin_probs

            ds_flux_evol = self.output_file.create_dataset('conditional_flux_evolution', data=flux_evol, shuffle=True, compression=9)
            ds_state_prob_evol = self.output_file.create_dataset('state_prob_evolution', data=state_prob_evol, compression=9)
            ds_color_prob_evol = self.output_file.create_dataset('color_prob_evolution', data=color_prob_evol, compression=9)
            ds_bin_prob_evol = self.output_file.create_dataset('bin_prob_evolution', data=bin_prob_evol, compression=9)
            ds_state_labels = self.output_file.create_dataset('state_labels', data=state_labels)

if __name__ == '__main__':
    WPostAnalysisReweightTool().main()
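# Example invocation (illustrative only; the file names shown are the defaults
# assumed by the options above):
#
#   w_postanalysis_reweight.py -a assign.h5 -k flux_matrices.h5 -o kinrw.h5 \
#       -e cumulative --obs-threshold 1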