Showing 5 changed files with 339 additions and 0 deletions.
@@ -0,0 +1,18 @@
# Learning parameters
learning = {'rate': 0.001,
            'minEpoch': 2,
            'lrScale': 0.9,
            'batchSize': 16,  # 256,
            'lrScaleCount': 1000,
            'minValError': 0.00005}

# Feature extraction parameters: not used for spikefinder evaluations
param = {'windowLength': 1000,  # ms
         'windowShift': 1000,   # ms
         'fs': 100}             # Hz
param['stdFloor'] = 1e-3  # floor on standard deviation
param['windowLengthSamples'] = int(param['windowLength'] * param['fs'] / 1000.0)
param['windowShiftSamples'] = int(param['windowShift'] * param['fs'] / 1000.0)

# Main parameters
dataloc = '../../'  # train and test split location
maxlen = 100000  # max possible length of the calcium signal fed to the model (100 Hz, in samples)
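
# Quick sanity check of the ms-to-sample conversion above (a minimal sketch;
# assumes this file is saved as config.py, which is how the `import config`
# in the loader below resolves it): 1000 ms * 100 Hz / 1000 = 100 samples.
assert param['windowLengthSamples'] == 100
assert param['windowShiftSamples'] == 100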
@@ -0,0 +1,97 @@
# Python script for loading spikefinder data.
# For more info see https://github.com/codeneuro/spikefinder

import numpy as np
import pandas as pd
from scipy import signal
import config


def load_data_train():
    calcium_train = []
    spikes_train = []
    ids = []
    for dataset in range(10):
        calcium_train.append(np.array(pd.read_csv(config.dataloc +
                             'spikefinder.train/' + str(dataset+1) +
                             '.train.calcium.csv')))
        spikes_train.append(np.array(pd.read_csv(config.dataloc +
                            'spikefinder.train/' + str(dataset+1) +
                            '.train.spikes.csv')))
        ids.append(np.array([dataset]*calcium_train[-1].shape[1]))
    maxlen = max([c.shape[0] for c in calcium_train])
    maxlen = max(maxlen, config.maxlen)
    # maxlen_test = max([c.shape[0] for c in calcium_test])
    # maxlen = max(maxlen, maxlen_test)
    # Pad every trace to maxlen by wrapping it around, then stack all cells
    calcium_train_padded = np.hstack([np.pad(c, ((0, maxlen-c.shape[0]), (0, 0)), 'wrap') for c in calcium_train])
    spikes_train_padded = np.hstack([np.pad(c, ((0, maxlen-c.shape[0]), (0, 0)), 'wrap') for c in spikes_train])
    ids_stacked = np.hstack(ids)
    # Upweight the first five datasets, then normalise to unit mean
    sample_weight = 1. + 1.5*(ids_stacked < 5)
    sample_weight /= sample_weight.mean()
    calcium_train_padded[np.isnan(calcium_train_padded)] = 0.
    spikes_train_padded[np.isnan(spikes_train_padded)] = -1

    calcium_train_padded[spikes_train_padded < -1] = np.nan
    spikes_train_padded[spikes_train_padded < -1] = np.nan
    # Optional Gaussian smoothing of the spike trains; .copy() keeps the
    # convolution below from overwriting spikes_train_padded in place
    window = signal.gaussian(33, std=10)
    spikes_train_padd = spikes_train_padded.copy()
    for i in range(spikes_train_padded.shape[1]):
        spikes_train_padd[:, i] = np.convolve(spikes_train_padded[:, i], window, mode='same')

    calcium_train_padded[np.isnan(calcium_train_padded)] = 0.
    spikes_train_padded[np.isnan(spikes_train_padded)] = -1
    spikes_train_padd[np.isnan(spikes_train_padd)] = -1

    # Reshape to (cells, time, 1) for Keras
    calcium_train_padded = calcium_train_padded.T[:, :, np.newaxis]
    spikes_train_padded = spikes_train_padded.T[:, :, np.newaxis]
    spikes_train_padd = spikes_train_padd.T[:, :, np.newaxis]

    # Optional one-hot dataset ids (used mainly for the test set)
    ids_oneshot = np.zeros((calcium_train_padded.shape[0],
                            calcium_train_padded.shape[1], 10))

    for n, i in enumerate(ids_stacked):
        ids_oneshot[n, :, i] = 1.

    # Either the raw spike trains or the Gaussian-smoothed ones can serve as targets
    return {'calcium signal padded': calcium_train_padded,
            'spikes train padded': spikes_train_padded,
            'Gaussian spikes train': spikes_train_padd}

def load_data_test():
    calcium_test = []
    spikes_test = []
    ids_test = []
    for dataset in range(5):
        calcium_test.append(np.array(pd.read_csv(config.dataloc +
                            'spikefinder.test/' + str(dataset+1) +
                            '.test.calcium.csv')))
        spikes_test.append(np.array(pd.read_csv(config.dataloc +
                           'spikefinder.test/' + str(dataset+1) +
                           '.test.spikes.csv')))
        ids_test.append(np.array([dataset]*calcium_test[-1].shape[1]))

    maxlen_test = max([c.shape[0] for c in calcium_test])
    maxlen_test = max(maxlen_test, config.maxlen)
    # Pad with NaN (not 'wrap') so padded samples can be masked out below
    calcium_test_padded = \
        np.hstack([np.pad(c, ((0, maxlen_test-c.shape[0]), (0, 0)), 'constant', constant_values=np.nan) for c in calcium_test])
    spikes_test_padded = \
        np.hstack([np.pad(c, ((0, maxlen_test-c.shape[0]), (0, 0)), 'constant', constant_values=np.nan) for c in spikes_test])

    ids_test_stacked = np.hstack(ids_test)
    calcium_test_padded[spikes_test_padded < -1] = np.nan
    spikes_test_padded[spikes_test_padded < -1] = np.nan
    spt = np.asarray(spikes_test)  # unpadded ground-truth spike trains

    calcium_test_padded[np.isnan(calcium_test_padded)] = 0.
    spikes_test_padded[np.isnan(spikes_test_padded)] = -1

    calcium_test_padded = calcium_test_padded.T[:, :, np.newaxis]
    spikes_test_padded = spikes_test_padded.T[:, :, np.newaxis]

    # One-hot dataset ids, one row per cell
    ids_oneshot_test = np.zeros((calcium_test_padded.shape[0],
                                 calcium_test_padded.shape[1], 10))
    for n, i in enumerate(ids_test_stacked):
        ids_oneshot_test[n, :, i] = 1.

    return {'calcium signal': calcium_test,
            'calcium signal padded': calcium_test_padded,
            'spikes train': spt,
            'spikes train padded': spikes_test_padded,
            'ids oneshot': ids_oneshot_test,
            'ids stacked': ids_test_stacked}
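
# A toy illustration of the two preprocessing steps above (made-up values):
# 'wrap' padding tiles a short trace to the target length, and convolving
# with signal.gaussian spreads each spike into a smooth bump.
if __name__ == '__main__':
    trace = np.array([[1.0], [2.0], [3.0]])           # (time=3, cells=1)
    padded = np.pad(trace, ((0, 4), (0, 0)), 'wrap')  # tile to time=7
    print(padded.ravel())                             # [1. 2. 3. 1. 2. 3. 1.]

    spikes = np.zeros(33)
    spikes[16] = 1.0                                  # one spike mid-trace
    window = signal.gaussian(33, std=10)              # same window as load_data_train
    smoothed = np.convolve(spikes, window, mode='same')
    print(smoothed.argmax(), smoothed.max())          # 16 1.0 (peak stays on the spike)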
@@ -0,0 +1,8 @@
# python 3.8.10
tensorflow==2.5.0
keras==2.4.3
scikit-learn==0.24.2
scipy==1.7.0
numpy==1.19.5
h5py==3.1.0
pandas==1.2.5
@@ -0,0 +1,142 @@
import numpy as np
from scipy.stats import spearmanr
from keras.models import load_model
from keras import backend as K
from sklearn.metrics import roc_curve, auc
from datasets import load_data_test
from optparse import OptionParser

# Testing script for spikefinder

def score(a, b, method, downsample=4):
    """
    Estimate similarity scores between two sets of results, one trace at a time.
    """
    # Only the scorers defined in this file; the original spikefinder metric
    # set also had 'loglik' and 'info', which are not included here.
    methods = {
        'corr': _corr,
        'auc': _auc,
        'rank': _rank
    }
    if method not in methods.keys():
        raise Exception('scoring method not one of: %s' % ' '.join(methods.keys()))

    func = methods[method]

    result = []
    for i in range(a.shape[0]):
        x = a[i, :]
        y = b[i, :]
        # Cut the padded prediction back to the original (unpadded) length of
        # the dataset this trace belongs to; spike_npt and id_staked_t are
        # module-level globals set in __main__ below.
        x = x[:len(spike_npt[id_staked_t[i]])]
        ml = min([len(x), len(y)])

        x = x[0:ml]
        y = y[0:ml]
        naninds = np.isnan(x) | np.isnan(y)
        x = x[~naninds]
        y = y[~naninds]
        x = _downsample(x, downsample)
        y = _downsample(y, downsample)

        ml = min([len(x), len(y)])

        x = x[0:ml]
        y = y[0:ml]

        if not len(x) == len(y):
            raise Exception('mismatched lengths %s and %s' % (len(x), len(y)))

        result.append(func(x, y))

    return result

def _corr(x, y):
    return np.corrcoef(x, y)[0, 1]


def _rank(x, y):
    return spearmanr(x, y).correlation


def _auc(x, y):
    fpr, tpr, thresholds = roc_curve(y > 0, x)
    return auc(fpr, tpr)

def _downsample(signal, factor):
    """
    Downsample signal by summing neighboring values.
    @type  signal: array_like
    @param signal: one-dimensional signal to be downsampled
    @type  factor: int
    @param factor: this many neighboring values are summed per output sample
    @rtype: ndarray
    @return: downsampled signal
    """

    if factor < 2:
        return np.asarray(signal)

    # Sliding length-`factor` sums, keeping every `factor`-th window, i.e.
    # non-overlapping block sums (block means up to a constant factor)
    return np.convolve(np.asarray(signal).ravel(), np.ones(factor), 'valid')[::factor]
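
# Quick check of _downsample on made-up values (a minimal sketch): with
# factor 4, each output sample is the sum of one non-overlapping block of
# four inputs -- a constant multiple of the block mean, so the
# correlation-based scores above are unaffected.
_example = _downsample(np.arange(8), 4)
assert list(_example) == [6.0, 22.0]  # sums of [0,1,2,3] and [4,5,6,7]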
def model_test(model, test_dataset):
    # model.load_weights('model/model_conv_11_5')
    test_ip = test_dataset['calcium signal padded']
    pred_test = model.predict(test_ip)
    gt_test = np.reshape(test_dataset['spikes train padded'], (test_ip.shape[0], -1))
    pred_test = np.reshape(pred_test, (test_ip.shape[0], -1))
    corrs = np.asarray(score(pred_test, gt_test, method='corr'))
    ranks = np.asarray(score(pred_test, gt_test, method='rank'))
    aucs = np.asarray(score(pred_test, gt_test, method='auc'))
    # Average each measure per test dataset (id_staked_t is a global set in __main__)
    measures = []
    for i in range(5):
        corre = np.mean(corrs[id_staked_t == i])
        ranke = np.mean(ranks[id_staked_t == i])
        auce = np.mean(aucs[id_staked_t == i])
        measures.append([corre, ranke, auce])
    return measures


def correlation_coefficient_loss(y_true, y_pred):
    # Pearson-correlation loss; unused when the model is loaded with
    # compile=False, kept here to match the training script
    x = y_true
    y = y_pred
    mx = K.mean(x, axis=1, keepdims=True)
    my = K.mean(y, axis=1, keepdims=True)
    xm, ym = x - mx, y - my
    r_num = K.sum(xm * ym, axis=1)
    r_den = K.sqrt(K.sum(K.square(xm), axis=1) * K.sum(K.square(ym), axis=1))
    r = r_num / r_den
    r = K.maximum(K.minimum(r, 1.0), -1.0)
    return 1 - K.square(r)

if __name__ == '__main__':

    usage = 'USAGE: %prog model_path'
    parser = OptionParser(usage=usage)
    opts, args = parser.parse_args()

    if len(args) != 1:
        parser.usage += '\n\n' + parser.format_option_help()
        parser.error('Wrong number of arguments')

    model = args[0]  # model file location
    test_dataset = load_data_test()
    # Globals used by score() and model_test() above
    id_staked_t = test_dataset['ids stacked']
    spike_npt = test_dataset['spikes train']

    m = load_model(model, compile=False)
    results = model_test(m, test_dataset)
    print(results)
@@ -0,0 +1,74 @@
from optparse import OptionParser
import os
import keras
import numpy
from keras import backend as K
from datasets import load_data_train
import config

# Training script for s2s

usage = 'USAGE: python train_s2s.py model_outdir'

parser = OptionParser(usage=usage)
opts, args = parser.parse_args()

if len(args) != 1:
    parser.usage += '\n\n' + parser.format_option_help()
    parser.error('Wrong number of arguments')

model_dir = args[0]

train_dataset = load_data_train()
train_ip = train_dataset['calcium signal padded']
train_op = train_dataset['Gaussian spikes train']

inputFeatDim = train_ip.shape[1]  # number of time samples per trace

def correlation_coefficient_loss(y_true, y_pred):
    # 1 - r^2, where r is the per-sequence Pearson correlation along time
    x = y_true
    y = y_pred
    mx = K.mean(x, axis=1, keepdims=True)
    my = K.mean(y, axis=1, keepdims=True)
    xm, ym = x - mx, y - my
    r_num = K.sum(xm * ym, axis=1)
    r_den = K.sqrt(K.sum(K.square(xm), axis=1) * K.sum(K.square(ym), axis=1))
    r = r_num / r_den
    r = K.maximum(K.minimum(r, 1.0), -1.0)
    return 1 - K.square(r)
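
# Sanity check of the loss on toy tensors (made-up values, assuming the
# tf.keras backend from requirements.txt): an affine copy of the target has
# per-sequence correlation r = 1, so the loss 1 - r^2 is ~0.
_y = numpy.random.randn(2, 1000, 1).astype('float32')
_loss = correlation_coefficient_loss(K.constant(_y), K.constant(3.0 * _y + 1.0))
assert float(K.eval(_loss).max()) < 1e-3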

# Initialise learning parameters and models
s = keras.optimizers.Adam(lr=config.learning['rate'], decay=0)

# Model definition
numpy.random.seed(25)
m = keras.models.Sequential()
m.add(keras.layers.Reshape((100000, 1), input_shape=(100000, 1)))
m.add(keras.layers.Conv1D(filters=30, kernel_size=100, strides=1, padding='same', use_bias=False))  # kernel_constraint=non_neg()
m.add(keras.layers.Activation('relu'))

m.add(keras.layers.TimeDistributed(keras.layers.Dense(30), input_shape=(m.output_shape[1], m.output_shape[2])))
m.add(keras.layers.Activation('relu'))
m.add(keras.layers.Dropout(0.2))

m.add(keras.layers.TimeDistributed(keras.layers.Dense(30)))
m.add(keras.layers.Activation('relu'))
m.add(keras.layers.Dropout(0.2))

m.add(keras.layers.TimeDistributed(keras.layers.Dense(30)))
m.add(keras.layers.Activation('relu'))
m.add(keras.layers.Dropout(0.2))

# Decoder: add a dummy spatial axis so Conv2DTranspose can act as a learned
# 1-D smoothing kernel over time, then squeeze the dummy axis away and trim
m.add(keras.layers.Lambda(lambda x: K.expand_dims(x, axis=2)))
m.add(keras.layers.Conv2DTranspose(filters=1, kernel_size=(100, 1), strides=(1, 1), padding='same', use_bias=False))
m.add(keras.layers.Lambda(lambda x: K.squeeze(x, axis=3)))
m.add(keras.layers.Lambda(lambda x: x[:, :100000, :]))
m.summary()
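
# Shape flow through the decoder above, for reference (matches m.summary()):
#   (None, 100000, 30) --expand_dims--> (None, 100000, 1, 30)
#   --Conv2DTranspose(1, (100, 1), 'same')--> (None, 100000, 1, 1)
#   --squeeze(axis=3)--> (None, 100000, 1) --slice--> (None, 100000, 1)
assert m.output_shape == (None, 100000, 1)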

# Training
m.compile(loss=correlation_coefficient_loss, optimizer=s, metrics=['mse'])
r = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=config.learning['lrScale'], patience=4,
                                      verbose=1, min_delta=config.learning['minValError'], cooldown=1,
                                      min_lr=config.learning['rate'])  # note: min_lr equal to the initial rate means the rate is never actually reduced
e = keras.callbacks.EarlyStopping(monitor='val_loss', patience=6, verbose=1)

h = [m.fit(train_ip, train_op, batch_size=20, epochs=100, verbose=2, validation_split=0.2, callbacks=[r, e])]
m.save(os.path.join(model_dir, 'model_s2s.h5'), overwrite=True)