-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_s2c_als_linear_sd_fits_init.py
52 lines (43 loc) · 2.25 KB
/
run_s2c_als_linear_sd_fits_init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from spike2calcium_linear_ALS_SD import *
import os, sys
import warnings
warnings.filterwarnings('ignore')
from glob import glob
def _cell_id_from_path(path):
    """Return the cell identifier encoded in an exported ``.npz`` file name.

    The identifier is the first six ``_``-separated fields of the base name,
    with the final four characters dropped. When the base name has six or
    fewer fields, this strips the trailing ``.npz`` extension.
    """
    fields = os.path.basename(path).split('_')
    return "_".join(fields[:6])[:-4]


# Build the list of unique cell identifiers to fit.
# Earlier fits enumerated the original export:
#   flist = sorted(glob('GCaMP8_exported_ROIs_s2f_full/dff_v00/*.npz'))
# The current fit is based on the cleaned-up export (dff_v02).
flist = sorted(glob('GCaMP8_exported_ROIs_s2f_full/dff_v02/*.npz'))
cell_info = list(map(_cell_id_from_path, flist))
cell_info_uq = np.unique(cell_info)
# Fitting code: one independent ALS fit per cell, so cells can be dispatched in parallel.
def run_single_cell(nf,
                    data_dir='GCaMP8_exported_ROIs_s2f_full/dff_v02',
                    init_dir='GCaMP8_exported_ROIs_s2f_full/fit_als_linear_sd_results_v01',
                    out_dir='GCaMP8_exported_ROIs_s2f_full/fit_als_linear_sd_results_v04',
                    min_spikes=12):
    """Fit the linear spike-to-calcium ALS model for one cell and save it.

    Parameters
    ----------
    nf : str, numpy string scalar, or single-element array
        Cell identifier. ``{data_dir}/{nf}.npz`` is loaded.
    data_dir : str
        Directory of per-cell input files. Defaults to the cleaned-up export
        (``dff_v02``), the same directory the module-level cell list is
        built from.
    init_dir : str
        Directory holding a previous fit ``{nf}_fit.npz``. Its
        ``param_kernel`` seeds the fit when available.
    out_dir : str
        Directory where ``{nf}_fit.npz`` is written. Created if missing.
    min_spikes : int
        Trials with fewer spikes than this are discarded.

    Returns
    -------
    None
        Always ``None``. When no trial passes the validity checks, nothing is
        written. Otherwise ``param_linear`` and ``param_kernel`` are saved to
        ``out_dir``.

    Raises
    ------
    OSError
        If the input file ``{data_dir}/{nf}.npz`` cannot be read.
    KeyError
        If the input file lacks ``spike_times``, ``ca_times`` or ``ca_trace``.
    """
    import zipfile  # only needed to recognise a corrupt previous-fit archive

    # np.unique yields numpy string scalars; normalise any str-like input
    # (plain str, np.str_, single-element array) to a plain str.
    nf = str(np.asarray(nf).squeeze())

    # Use a context manager so the NpzFile handle is closed once the arrays
    # have been read into memory.
    with np.load(os.path.join(data_dir, f'{nf}.npz'), allow_pickle=True) as data:
        spike_times = data['spike_times']  # presumably per-trial spike times (ephys)
        ca_times = data['ca_times']        # presumably per-trial imaging time stamps
        ca_trace = data['ca_trace']        # presumably per-trial calcium traces

    # A trial is valid when it has enough spikes and its first spike precedes
    # its last imaging time stamp. The conditions short-circuit per trial, so
    # empty trials are rejected instead of crashing np.min / np.max.
    valid_trial = np.array(
        [len(st) > 0 and len(st) >= min_spikes and len(ct) > 0
         and np.min(st) < np.max(ct)
         for st, ct in zip(spike_times, ca_times)],
        dtype=bool,
    )
    if not valid_trial.any():
        return None
    spike_times = list(spike_times[valid_trial])
    ca_times = list(ca_times[valid_trial])
    ca_trace = list(ca_trace[valid_trial])

    # Seed the kernel from a previous fit if one is readable; otherwise use the
    # default initial kernel parameters. Only load/lookup failures fall back,
    # so interrupts and genuine bugs still propagate.
    try:
        with np.load(os.path.join(init_dir, f'{nf}_fit.npz'), allow_pickle=True) as prev_fit:
            param_kernel = prev_fit['param_kernel']
    except (OSError, KeyError, ValueError, zipfile.BadZipFile):
        param_kernel = np.array([1/1000, 0.5])

    # Four rounds of alternating least squares (linear part, then kernel),
    # then a final linear refit for the last kernel estimate.
    for _ in range(4):
        param_linear = fit_spike2calcium_linear(param_kernel, ca_trace, spike_times, ca_times)
        param_kernel = fit_spike2calcium_kernel(param_kernel, param_linear, ca_trace, spike_times, ca_times)
    param_linear = fit_spike2calcium_linear(param_kernel, ca_trace, spike_times, ca_times)

    # exist_ok=True makes this safe when many parallel workers race to create it.
    os.makedirs(out_dir, exist_ok=True)
    np.savez(os.path.join(out_dir, f'{nf}_fit'),
             param_linear=np.array(param_linear),
             param_kernel=np.array(param_kernel))
    return None
if __name__ == "__main__":
    import dask

    # Wrap the per-cell fit once, build one lazy task per unique cell,
    # and evaluate all of them in a single compute call.
    delayed_fit = dask.delayed(run_single_cell)
    tasks = [delayed_fit(cell) for cell in cell_info_uq]
    dask.compute(tasks)