-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_conv.py
158 lines (124 loc) · 6.01 KB
/
data_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from data_generating import *
def convolution_kernel(data, h):
    """Return the square root of the sliding-window energy of *data*, scaled to a peak of 1.

    Each output sample is ``sqrt(sum(data[k] ** 2))`` taken over a window of
    ``h + 1`` ones, using ``np.convolve`` in ``"same"`` mode. The result is then
    divided by its maximum so that the largest value is 1.

    :param data: 1-D array-like signal.
    :param h: window length minus one (the window has ``h + 1`` samples).
    :return: array of non-negative floats of length ``max(len(data), h + 1)``
        (the ``"same"`` mode rule of ``np.convolve``). An all-zero input yields
        all zeros rather than NaNs from a 0/0 division.
    """
    window_energy = np.convolve(np.asarray(data) ** 2, np.ones(h + 1), "same")
    envelope = np.sqrt(window_energy)
    # sqrt output is already non-negative, so no abs() is needed for the peak
    peak = np.max(envelope)
    # `!= 0` (rather than `> 0`) keeps NaN inputs propagating as before
    if peak != 0:
        envelope = envelope / peak
    return envelope
def convolution_kernel2(data, h):
    """Convolve *data* with a +1/-1 step window of total length ``h``.

    The window is ``int(h / 2)`` ones followed by ``int(h / 2)`` minus-ones.
    Each output sample is therefore the difference between the sums of two
    adjacent half-windows, and stretches of constant input give zero away from
    the edges. The output is not normalized.

    :param data: 1-D signal.
    :param h: total window length; should be even. An odd ``h`` is floored,
        giving a window of ``h - 1`` samples.
    :return: result of ``np.convolve(data, window, "same")``.
    """
    half_width = int(h / 2)
    rising = np.ones(half_width)
    step_window = np.concatenate((rising, -rising))
    return np.convolve(data, step_window, "same")
def convolution_signal_func(signal_matrix, convolution_kernal, convolution_window, normalize=False):
    """Apply a 1-D convolution kernel to every row of a 2-D signal matrix.

    :param signal_matrix: 2-D array; each row is filtered independently.
    :param convolution_kernal: callable ``kernel(row, window)`` returning an
        array the same length as ``row`` (e.g. ``convolution_kernel``).
    :param convolution_window: window parameter forwarded unchanged to the kernel.
    :param normalize: if True, divide the whole result by its single largest
        absolute value, so relative amplitudes between rows are preserved.
        An all-zero (or empty) result is returned unscaled instead of producing
        NaNs or raising.
    :return: float array of shape ``(signal_matrix.shape[0], signal_matrix.shape[1])``.
    """
    num_row = signal_matrix.shape[0]
    convolution_signal = np.zeros((num_row, signal_matrix.shape[1]))
    for i, row in enumerate(signal_matrix):
        convolution_signal[i] = convolution_kernal(row, convolution_window)
    if normalize and convolution_signal.size:
        peak = np.max(np.abs(convolution_signal))
        # guard the 0/0 case; NaN peaks still propagate as before
        if peak != 0:
            convolution_signal = convolution_signal / peak
    return convolution_signal
def undersample(convolution_data, undersample_window, undersample_threshold):
    """
    Undersample a convolved 1-D signal over fixed, non-overlapping windows.
    ======================================================
    mat, loc_int = undersample(conv, window_l, threshold)
    input:
        conv: convolved 1-D signal
        window_l: window length (in samples) for undersampling
        threshold: a window is marked as a spike when any |sample|
                   in it is >= threshold
    output:
        mat: one value per window; 0 when no sample reaches the
             threshold, otherwise the signed sample with the largest
             magnitude in that window
        loc_int: integer array of the sample indices of those peaks,
                 one entry per spike window, in increasing order
    Trailing samples that do not fill a whole window are ignored.
    ======================================================
    """
    interval_num = int(convolution_data.shape[0] / undersample_window)
    mat = np.zeros(interval_num)
    loc_list = []
    for i in range(interval_num):
        offset = undersample_window * i
        interval_abs = np.abs(convolution_data[offset:offset + undersample_window])
        if not np.any(interval_abs >= undersample_threshold):
            continue  # no spike: mat[i] stays 0
        loc = int(offset + np.argmax(interval_abs))
        mat[i] = convolution_data[loc]
        loc_list.append(loc)
    # explicit dtype so an empty result is an int array too, not float64
    return mat, np.array(loc_list, dtype=int)
def undersample_signal_func(convolution_matrix, undersample_window, undersample_threshold):
    """Undersample every row of an already-convolved signal matrix.

    No convolution happens here: each row of *convolution_matrix* is passed
    as-is to ``undersample`` with the same window and threshold.

    :return:
        undersample_signal_mat: array of shape
            ``(rows, int(columns / undersample_window))``, one undersampled
            row per input row
        spike_loc_list: list with one array per row, holding the sample
            indices ``undersample`` picked as spike locations
    """
    n_rows = convolution_matrix.shape[0]
    n_bins = int(convolution_matrix.shape[1] / undersample_window)
    undersampled = np.zeros((n_rows, n_bins))
    locations = []
    for row_idx, row in enumerate(convolution_matrix):
        binned, peak_idx = undersample(row, undersample_window, undersample_threshold)
        undersampled[row_idx] = binned
        locations.append(peak_idx)
    return undersampled, locations
def undersample_convolution_plot(signal_mat, spike_loc_list, convolution_signal_mat, undersample_window, plot_interval, time_step,
                                 path='fig/undersample plot'):
    """Plot raw signals, their convolution, detected spikes and the undersample grid.

    One subplot per row (electrode) of *signal_mat*. The raw signals and the
    convolution are each scaled by their own global max |value| for display
    only; the caller's arrays are not modified.

    :param signal_mat: 2-D array, one row per electrode.
    :param spike_loc_list: per-row sequences of sample indices (as returned by
        ``undersample_signal_func``); drawn at y=1.
    :param convolution_signal_mat: array with the same layout as *signal_mat*.
    :param undersample_window: spacing, in samples, of the grid drawn at y=0.
    :param plot_interval: ``(t_start, t_end)`` in the units of *time_step*;
        treated as half-open ``[t_start, t_end)``.
    :param time_step: time per sample.
    :param path: output file passed to ``savefig``.
    """
    num_electron = signal_mat.shape[0]
    signal_length = signal_mat.shape[1]
    start = int(plot_interval[0] / time_step)
    # Clamp to the data and build the time axis from the same sample range as the
    # slices below, so x and y always have equal length. A float np.arange over
    # plot_interval can be one element longer than the int()-floored slice.
    end = min(int(plot_interval[1] / time_step), signal_length)
    time_axis = np.arange(start, end) * time_step

    # squeeze=False keeps axs indexable even for a single electrode
    fig, axs = plt.subplots(nrows=num_electron, ncols=1, sharex=True, sharey=True,
                            figsize=(15, 15), squeeze=False)
    axs = axs[:, 0]

    # normalize signal for better visualization (skip all-zero data to avoid 0/0)
    signal_peak = np.max(np.abs(signal_mat))
    if signal_peak != 0:
        signal_mat = signal_mat / signal_peak
    conv_peak = np.max(np.abs(convolution_signal_mat))
    if conv_peak != 0:
        convolution_signal_mat = convolution_signal_mat / conv_peak

    # undersample grid inside the plotted window; identical for every electrode
    grid_point = np.arange(0, signal_length, undersample_window)
    grid_point = grid_point[(grid_point >= start) & (grid_point < end)]

    for i, ax in enumerate(axs):
        # original signal and its convolution
        ax.plot(time_axis, signal_mat[i][start:end], c='lightcoral', label='signal ')
        ax.plot(time_axis, convolution_signal_mat[i][start:end], '--', c='orange', label='convolution')
        # detected spike locations that fall inside the plotted window
        spike_loc = np.asarray(spike_loc_list[i])
        spike_time = spike_loc * time_step
        index = spike_loc[(plot_interval[0] <= spike_time) & (spike_time < plot_interval[1])]
        ax.scatter(index * time_step, np.ones(len(index)), marker='o', c='c', label='spike location')
        ax.scatter(grid_point * time_step, np.zeros(len(grid_point)), marker="*", c='g', label='undersample grid')
    axs[0].set_title('Undersample of signals by ' + str(num_electron) + ' electrons')
    axs[0].legend(loc=2)
    axs[-1].set_xlabel('time step')
    fig.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
    fig.savefig(path)
    plt.close(fig)
def undersample_convolution_heatmap_plot(undersample_mat, plot_length, time_step, path='fig/undersample heatmap'):
    """Save a seaborn heatmap of the leading columns of an undersampled matrix.

    Only the first ``len(np.arange(0, plot_length, time_step))`` columns are
    drawn.

    :param undersample_mat: 2-D array (rows x undersampled samples).
    :param plot_length: upper bound of the range used to count columns.
    :param time_step: step of that range.
    :param path: output file for ``savefig``; defaults to the previously
        hard-coded location.
    """
    num_columns = len(np.arange(0, plot_length, time_step))
    fig, ax = plt.subplots()
    # Draw on a dedicated Axes. Without ax=, sns.heatmap uses the currently
    # active Axes, which may belong to a figure left open by an earlier plot.
    sns.heatmap(undersample_mat[:, :num_columns], ax=ax)
    fig.savefig(path)
    plt.close(fig)
def undersample_plot(undersample_mat, plot_length, file_name='fig/undersample mat'):
    """Plot the first *plot_length* samples of each row in its own subplot and save it.

    :param undersample_mat: 2-D array, one row per neuron/electrode.
    :param plot_length: number of leading samples drawn per row.
    :param file_name: output file passed to ``savefig``.
    """
    num_neuron = undersample_mat.shape[0]
    # squeeze=False keeps a 2-D Axes array even when there is only one row
    fig, axs = plt.subplots(nrows=num_neuron, ncols=1, sharex=True, sharey=True, squeeze=False)
    for i, ax in enumerate(axs[:, 0]):
        ax.plot(undersample_mat[i][:plot_length])
        ax.set_title('undersample ' + str(i) + ' signals')
    fig.savefig(file_name)
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)