Skip to content

Commit

Permalink
new rate from pyemma
Browse files Browse the repository at this point in the history
  • Loading branch information
Chenggong committed Jul 20, 2023
1 parent 215b32e commit e608ac5
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 47 deletions.
157 changes: 111 additions & 46 deletions Sfilter/util/MSM.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import networkx as nx
import matplotlib.pyplot as plt
from .output_wrapper import read_k_cylinder

import pyemma



Expand Down Expand Up @@ -171,17 +171,39 @@ def calc_state_array(self, merge_list=None):

def get_transition_matrix(self, lag_step=1):
"""
compute transition matrix
compute states transition matrix.
a numpy array, each element is the number of transition from state i to state j
"""""
state_num = max([max(traj) for traj in self.state_array])
tran_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64)
f_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64)
#for traj in self.state_array:
# state_start = traj[:-lag_step]
# state_end = traj[lag_step:]
# for m_step in np.array([state_start, state_end]).T:
# tran_matrix[m_step[0], m_step[1]] += 1
for traj in self.state_array:
state_start = traj[:-lag_step]
state_end = traj[lag_step:]
for m_step in np.array([state_start, state_end]).T:
tran_matrix[m_step[0], m_step[1]] += 1
return tran_matrix
f_matrix[m_step[0], m_step[1]] += 1

return f_matrix

def f_matrix_2_rate_matrix(self, f_matrix, physical_time):
"""
compute rate matrix
a numpy array, each element is the rate from state i to state j
input:
f_matrix: a numpy array, each element is the number of transition from state i to state j
physical_time: physical time for each step
return:
rate_matrix: a numpy array, each element is the rate from state i to state j
"""
rate_matrix = np.array(f_matrix, dtype=np.float64) # convert int to float
for i in range(rate_matrix.shape[0]):
rate_matrix[i, :] /= self.node_counter[i] * physical_time
rate_matrix[i, i] = 0
return rate_matrix

def get_rate_matrix(self, lag_step=1, physical_time=None):
"""
Expand All @@ -200,25 +222,31 @@ def get_rate_matrix(self, lag_step=1, physical_time=None):
else:
raise ValueError("physical_time is not given, and time_step is not equal")

t_matrix = self.get_transition_matrix(lag_step)
rate_matrix = np.array(t_matrix, dtype=np.float64)
for i in range(rate_matrix.shape[0]):
rate_matrix[i, :] /= self.node_counter[i] * physical_time
rate_matrix[i, i] = 0
f_matrix = self.get_transition_matrix(lag_step)
rate_matrix = self.f_matrix_2_rate_matrix(f_matrix, physical_time)
return rate_matrix

def get_transition_probability(self, lag_step=1):
def f_matrix_2_transition_probability(self, f_matrix):
"""
compute transition probability
compute transition probability matrix (between steps)
return: transition_probability_matrix
a numpy array, each element is the probability of transition from state i to state j
The sum of each row is 1.
"""
t_matrix = self.get_transition_matrix(lag_step)
p_matrix = np.array(t_matrix, dtype=np.float64)
p_matrix = np.array(f_matrix, dtype=np.float64) # int to float
p_matrix /= p_matrix.sum(axis=1, keepdims=True) # normalize each row
return p_matrix

def get_transition_probability(self, lag_step=1):
"""
compute transition probability matrix (between steps)
return: transition_probability_matrix
a numpy array, each element is the probability of transition from state i to state j
The sum of each row is 1.
"""
f_matrix = self.get_transition_matrix(lag_step)
return self.f_matrix_2_transition_probability(f_matrix)

def get_CK_test(self, lag_step=1, test_time=[2, 4]):
"""
run Chapman-Kolmogorov test
Expand Down Expand Up @@ -303,8 +331,8 @@ def get_matrix(self, lag_step=1, physical_time=None):
"""
calculate transition matrix, rate_matrix, and transition probability
return:
t_matrix: each element is the number of transition from state i to state j
net_t_matrix: each element is the number of net event between state i to state j (t_matrix - t_matrix.T)
f_matrix: each element is the number of flux from state i to state j
net_f_matrix: each element is the number of net event between state i to state j (f_ij - f_ji)
rate_matrix: each element is the rate (number of event / observation time) from state i to state j
p_matrix: each element is the probability of transition from state i to state j
input:
Expand All @@ -317,15 +345,11 @@ def get_matrix(self, lag_step=1, physical_time=None):
else:
raise ValueError("physical_time is not given, and time_step is not equal")

t_matrix = self.get_transition_matrix(lag_step)
net_t_matrix = t_matrix - t_matrix.T
rate_matrix = np.array(t_matrix, dtype=np.float64)
for i in range(rate_matrix.shape[0]):
rate_matrix[i, :] /= self.node_counter[i] * physical_time
rate_matrix[i, i] = 0
p_matrix = np.array(t_matrix, dtype=np.float64)
p_matrix /= p_matrix.sum(axis=1, keepdims=True)
return t_matrix, net_t_matrix, rate_matrix, p_matrix
f_matrix = self.get_transition_matrix(lag_step) # transition between states (not steps)
net_t_matrix = f_matrix - f_matrix.T
rate_matrix = self.f_matrix_2_rate_matrix(f_matrix, physical_time)
p_matrix = self.f_matrix_2_transition_probability(f_matrix)
return f_matrix, net_t_matrix, rate_matrix, p_matrix

def get_resident_time(self):
"""
Expand Down Expand Up @@ -356,7 +380,7 @@ def find_merge_states(self, cut_off=0.01, lag_step=1, physical_time=None, method
"""
if physical_time is None: # use the time_step that was read from file
if np.allclose(self.time_step, self.time_step[0]):
physical_time = self.time_step[0]
physical_time = self.time_step[0] * lag_step
else:
raise ValueError("physical_time is not given, and time_step is not equal")

Expand Down Expand Up @@ -446,6 +470,67 @@ def lump_MFPT(self, node_cut_off=0.01, min_node=3):
raise ValueError("merge error " + str(n_0) + str(n_1))
return True, merge_list_new, [self.int_2_s[n_0], self.int_2_s[n_1]]

def get_pyemma_TPT_rate(self):
rate_matrix = np.zeros((len(self.int_2_s), len(self.int_2_s)))
msm = pyemma.msm.estimate_markov_model(self.state_array, lag=1,
reversible=False, dt_traj=str(self.time_step[0])+" ps")
for i in range(len(self.int_2_s)):
for j in range(len(self.int_2_s)):
if i != j:
rate_matrix[i, j] = pyemma.msm.tpt(msm, [i], [j]).rate
else:
rate_matrix[i, j] = 0
return rate_matrix


def lump_pyemma_TPT_rate(self, node_cut_off=0.01, min_node=3):
total_count = self.node_counter.total()
for n, node_count in self.node_counter.items():
if node_count / total_count < node_cut_off:
break
# check node number
if n < min_node:
return False, copy.deepcopy(self.merge_list), [0, 0]
# get rate matrix using pyemma TPT
msm_pyemma = pyemma.msm.estimate_markov_model(self.state_array, lag=1,
reversible=False, dt_traj=str(self.time_step[0])+" ps")
rate_matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
if i != j:
rate_matrix[i, j] = pyemma.msm.tpt(msm_pyemma, [i], [j]).rate
else:
rate_matrix[i, j] = 0
rate_list = []
for i in range(n):
for j in range(0, i):
rate_list.append([i, j, rate_matrix[i, j] * rate_matrix[j, i]])
rate_list = sorted(rate_list, key=lambda x: x[2], reverse=True)
n_0, n_1, rate = rate_list[0]
if len(self.int_2_s[n_0]) > 1 and len(self.int_2_s[n_1]) > 1:
merge_list_new = copy.deepcopy(self.merge_list)
for node in self.merge_list:
for node2 in self.merge_list:
if node == self.int_2_s[n_0] and node2 == self.int_2_s[n_1]:
merge_list_new.remove(node)
merge_list_new.remove(node2)
merge_list_new.append(node + node2)
elif len(self.int_2_s[n_0]) > 1 or len(self.int_2_s[n_1]) > 1:
merge_list_new = []
for node in self.merge_list:
if node == self.int_2_s[n_0]:
merge_list_new.append(node + self.int_2_s[n_1])
elif node == self.int_2_s[n_1]:
merge_list_new.append(self.int_2_s[n_0] + node)
else:
merge_list_new.append(node)
elif len(self.int_2_s[n_0]) == 1 or len(self.int_2_s[n_1]) == 1:
merge_list_new = copy.deepcopy(self.merge_list)
merge_list_new.append(self.int_2_s[n_0] + self.int_2_s[n_1])
else:
raise ValueError("merge error " + str(n_0) + str(n_1))
return True, merge_list_new, [self.int_2_s[n_0], self.int_2_s[n_1]]



def merge_until(self, rate_cut_off, rate_square_cut_off, node_cut_off=0.01, step_cut_off=30, lag_step=1, physical_time=None,
Expand Down Expand Up @@ -635,25 +720,5 @@ def computer_pos(strings):
return x, y


def get_transition_matrix(state_arrays, begin=0, lag_time=1):
state_num = max([max(traj) for traj in state_arrays])
tran_matrix = np.zeros((state_num + 1, state_num + 1), dtype=np.int64)
for traj in state_arrays:
state_start = traj[begin:-lag_time]
state_end = traj[begin + lag_time:]
for m_step in np.array([state_start, state_end]).T:
tran_matrix[m_step[0], m_step[1]] += 1
return tran_matrix


def get_distribution(state_arrays):
flattened = [num for sublist in state_arrays for num in sublist]
return Counter(flattened)


def get_rate_matrix(state_arrays, phy_time, begin=0, lag_time=1):
tran_matrix = np.array(get_transition_matrix(state_arrays, begin, lag_time), dtype=np.float64)
counter = get_distribution(state_arrays)
for i in range(tran_matrix.shape[0]):
tran_matrix[i, :] /= counter[i] * phy_time
return tran_matrix
22 changes: 21 additions & 1 deletion test/test_MSM.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ def test_SF_msm_set_state_str(self):
self.assertDictEqual(msm.state_counter, {"A": 11, "B": 7, "C": 3})
self.assertDictEqual(msm.node_counter, {0: 18, 1: 3})

def test_SF_msm_get_transition_matrix(self):
msm = MSM.SF_msm([])
msm.set_state_str(["A A B C A B C A B C A B C A B".split()])
msm.calc_state_array()
f_matrix_1 = msm.get_transition_matrix(lag_step=1)
f_matrix_2 = msm.get_transition_matrix(lag_step=2)
f_matrix_3 = msm.get_transition_matrix(lag_step=3)
self.assertListEqual(f_matrix_1.tolist(), [[1, 5, 0], [0, 0, 4], [4, 0, 0]])
self.assertListEqual(f_matrix_2.tolist(), [[0, 1, 4], [4, 0, 0], [0, 4, 0]])
self.assertListEqual(f_matrix_3.tolist(), [[4, 0, 1], [0, 4, 0], [0, 0, 3]])
msm.time_step = [1]
r_1 = msm.get_rate_matrix(lag_step=1)
r_2 = msm.get_rate_matrix(lag_step=2)
r_3 = msm.get_rate_matrix(lag_step=3)
self.assertListEqual(r_1.tolist(), [[0, 5/6, 0], [0, 0, 4/5], [4/4, 0, 0]])
self.assertListEqual(r_2.tolist(), [[0, 1/6, 4/6], [4/5, 0, 0], [0, 4/4, 0]])
self.assertListEqual(r_3.tolist(), [[0, 0, 1/6], [0, 0, 0], [0, 0, 0]])



def test_SF_msm_get_matrix(self):
msm = MSM.SF_msm([])
msm.set_state_str(["A B A B A B C D".split(),
Expand Down Expand Up @@ -145,7 +165,7 @@ def test_SF_msm_merge_until_03(self):
"A B A B A B A B A B A D C D D C C D D".split(),
"B A B A B A B A B A B D D C D C C D D E D".split()])
msm.calc_state_array()
reason = msm.merge_until(rate_cut_off=0.00, rate_square_cut_off=0.00, node_cut_off=0.017, lag_step=1, physical_time=1, method="rate_square", min_node=1)
reason = msm.merge_until(rate_cut_off=0.00, rate_square_cut_off=0.00, node_cut_off=0.017, lag_step=1, physical_time=1, method="rate_square", min_node=2)
t_matrix, net_t_matrix, rate_matrix, p_matrix = msm.get_matrix(lag_step=1, physical_time=1)
print()
print(reason)
Expand Down

0 comments on commit e608ac5

Please sign in to comment.