diff --git a/README.md b/README.md index 5472f797..e61d0fb6 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,31 @@ This toolbox will: - For Myomatrix data: - Combine OpenEphys data into a single binary and automatically remove broken channels - Extract and save the sync signal sent from the behavioural task - - Perform spike sorting with a modified version of Kilosort 3.0 (wider templates) + - Perform spike sorting with a modified version of Kilosort 3.0 - Combine similar units, calculate motor unit statistics, export back to phy +## Folder Tree Structure +![Alt text](images/folder_tree_structure.png) + ## Installation ### Requirements +Currently, using a Linux-based OS is recommended. The code has been tested on Ubuntu and CentOS. Windows support is experimental and may require additional changes. + Many processing steps require a CUDA capable GPU. - For Neuropixel data, a GPU with at least 10GB of onboard RAM is recommended - - For Myomatrix data, currently only GPUs with compute capability 8.0, 8.7, or 9.0 are supported due to shared thread memory requirements + - For Myomatrix data, currently only GPUs with compute capability >=5.0 are supported due to shared thread memory requirements + +Required MATLAB Toolboxes: + - Parallel Computing Toolbox + - Signal Processing Toolbox + - Statistics and Machine Learning Toolbox + +Nvidia Driver: + - Linux: >=450.80.02 + - Windows: >=452.39 + +CUDA Toolkit (Automatically installed with micromamba/conda environment): + - 11.3 ### Instructions These installation instructions were tested on the Computational Brain Science Group Server 'CBS GPU 10GB' image, and the Compute Canada servers. They may need to be adjusted if running on another machine type. @@ -26,9 +43,16 @@ Clone a copy of the repository on your local machine (for example, in the home d git clone https://github.com/JonathanAMichaels/PixelProcessingPipeline.git -After cloning, you can either configure a virtualenv or conda environment to run the pipeline +After cloning, you can either configure a virtualenv, conda, or micromamba environment to run the pipeline + +#### Micromamba Environment (Option 1, recommended) +To install micromamba and set up a micromamba environment, follow these steps: -#### Virtual Environment + "${SHELL}" <(curl -L micro.mamba.pm/install.sh) + micromamba env create -f environment.yml + micromamba activate pipeline + +#### Virtual Environment (Option 2) To set up a virtualenv environment, follow these steps: virtualenv ~/pipeline @@ -41,20 +65,26 @@ Install the cupy version that matches your version of nvcc. For example, if runn shows version 10.1, then run pip install cupy-cuda101 - -#### Conda Environment -We can also create a conda environment to run the file as opposed to a virtual environment by following these steps: +#### Conda Environment (Option 3, untested) +To set up a conda environment, follow these steps: + + wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh + bash Miniforge3-Linux-x86_64.sh + conda init conda env create -f environment.yml + conda activate pipeline #### Final Installation Steps -If you are processing Myomatrix data, open matlab and confirm that all mex files compile by running +Open matlab and confirm that all mex files compile by running + WARNING: make sure to activate the pipeline environment before running these commands + matlab -nodesktop cd PixelProcessingPipeline/sorting/Kilosort-3.0/CUDA/ mexGPUall -Compile codes necessary for drift estimation and install supplementary packages +(Optional) Compile codes necessary for drift estimation and install supplementary packages cd PixelProcessingPipeline/registration/spikes_localization_registration python3 setup.py build_ext --inplace @@ -66,11 +96,17 @@ Compile codes necessary for drift estimation and install supplementary packages ## Usage Organize each experiment into one directory with a Neuropixel folder inside (e.g. 041422_g0), a Myomatrix folder (e.g. 2022-04-14_09-48-02_myo, which must have _myo at the end) and any .kinarm data files generated. -The Myomatrix folder must be organized either as 'folder_myo/Record Node ###/continuous/' for binary open ephys data, -or as 'folder_myo/Record Node ###/***.continuous' for open ephys format data. +The Myomatrix folder must be organized either as 'folder_myo/Record Node ###/continuous/' for binary open ephys data, or as 'folder_myo/Record Node ###/***.continuous' for open ephys format data. + +Each time a sort is performed, a new folder will be created in the experiment directory with the date and time of the sort. Inside this folder will be the sorted data, the phy output files, and a copy of the ops used to sort the data. The original OpenEphys data will not be modified. + +#### Micromamba Activation +Every time you open a new terminal, you must activate the environment. If micromamba was used, activate the environment using + + micromamba activate pipeline #### VirtualEnv Activation -Every time you open a new terminal, you must activate current source. If virtualenv was used, activate the source using +If virtualenv was used, activate the source using source ~/pipeline/bin/activate @@ -82,19 +118,47 @@ If a conda environment was used, activate it using #### Final Usage Steps The first time you process an experiment, call - python3 pipeline.py -f /path_to_experiment_folder + python pipeline.py -f "/path/to/sessionYYYYMMDD" -This will generate a config.yaml file in that directory with all the relevant parameters for that experiment generated automatically. Open that file with any text editor and add any session specific information to the Session parameter section. For example, if you collected Myomatrix data you must specify which channels belong to which electrode and which channel contains the sync information, since this information cannot be generated automatically. +This will generate a `config.yaml` file in that directory with all the relevant parameters for that experiment generated automatically. Open that file with any text editor and add any session specific information to the Session parameter section. For example, if you collected Myomatrix data you must specify which channels belong to which electrode and which channel contains the sync information, since this information cannot be generated automatically. +##### Configuration Commands Editing the main configuration file can be done by running the command below: - python3 pipeline.py -f /path_to_experiment_folder -config + python pipeline.py -f "/path/to/sessionYYYYMMDD" -config + +To edit the configuration file for the processing Myomatrix data, run + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -myo_config + +To edit the configuration file for the processing Neuropixel data, run + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -neuro_config + +##### Spike Sorting Commands +To run a sort on the Myomatrix data, run + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -myo_sort + +To run a sort on the Neuropixel data, run + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -neuro_sort -If the config.yaml is correct, you can run the pipeline with all steps, for example +##### Plotting with Phy Command +For plotting the latest myomatrix sort with Phy GUI, run - python3 pipeline.py -f /path_to_experiment_folder -full + python pipeline.py -f "/path/to/sessionYYYYMMDD" -myo_phy -Alternatively, you can call any combination of +For plotting a previously saved myomatrix sort with Phy GUI, call below with the corresponding datestring + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -d YYYYMMDD_HHMMSS -myo_phy + +##### Chaining Commands Together +If the `config.yaml` is correct, you can run the pipeline with all steps, for example + + python pipeline.py -f "/path/to/sessionYYYYMMDD" -full + +Alternatively, you can call any combination of: -config -registration @@ -103,21 +167,18 @@ Alternatively, you can call any combination of -neuro_post -myo_config -myo_sort - -myo_post + -myo_phy -lfp_extract -to perform only those steps. For example, if you are processing Myomatrix data, run +to perform only those steps. For example, if you want to configure and immediately spike sort, run - python3 pipeline.py -f /path_to_experiment_folder -myo_sort -myo_post + python pipeline.py -f "/path/to/sessionYYYYMMDD" -config -myo_config -myo_sort -To edit the configuration file for the processing Myomatrix data, run +If you want to run a grid search over a range of KS parameters, edit the `Kilosort_gridsearch_config.py` +file under the sorting folder to include all variable combinations you want to try. Be aware of the combinatorics so you don't generate more sorts than you expected. Then open the config file and set the gridsearch parameter to True, for example by running - python3 pipeline.py -f /path_to_experiment_folder -myo_config - -To edit the configuration file for the processing Neuropixel data, run + python pipeline.py -f "/path/to/sessionYYYYMMDD" -config -myo_sort - python3 pipeline.py -f /path_to_experiment_folder -neuro_config - ## Extensions This code does not currently process .kinarm files or combine behavioural information with synced neural data. This may be added at a later date. diff --git a/config_template.yaml b/config_template.yaml index 0ab8998c..58b75152 100644 --- a/config_template.yaml +++ b/config_template.yaml @@ -1,6 +1,11 @@ --- # Configuration file # Auto-generated parameters +# GPU to use for kilosort, (a list of integers, e.g., [1,2,5,6]) +GPU_to_use: [0] +# number of Kilosort jobs to run at once (must be <= number of GPUs, will run in parallel if >1) +# setting num_KS_jobs > 1 can only be used when do_KS_param_gridsearch is True +num_KS_jobs: 1 # path of working directory (must be provided in command line) folder: # path to neuropixel directory @@ -9,6 +14,8 @@ neuropixel: num_neuropixels: # path to myomatrix directory (should end in _myo) myomatrix: +# specify recordings to process (can be [all] or a list of integers, e.g., [1,2,5,6]) +recordings: [1] # whether to concatenate myomatrix data concatenate_myo_data: False # set bandpass filter settings for myomatrix data @@ -25,24 +32,31 @@ script_dir: Registration: # Sorting parameters Sorting: + num_KS_components: 9 + do_KS_param_gridsearch: False # Session-specific parameters Session: trange: - 0 - 0 - myo_chan_map_file: - - bipolar_test_kilosortChanMap.mat - - bipolar_test_kilosortChanMap.mat + myo_chan_map_file: + - linear_16ch_RF400_kilosortChanMap_unitSpacing.mat + - linear_16ch_RF400_kilosortChanMap_unitSpacing.mat myo_chan_list: - [1, 16] - [17, 32] - # Remove bad channels from myomatrix data - # This can be a sequence of booleans or a sequence of lists: - # True for automatic bad channel removal, False to include all channels. - # Provide an integer list of channels to remove those channels for that session, e.g., [1,2,3,4] + ## Remove bad channels from myomatrix data + # This can be a sequence of booleans, of strings, or of lists: + # Booleans: True for automatic bad channel removal (defaults to reject below median), False to include all channels. + # Strings: Provide 'median', 'mean', 'mean-1std', 'percentileXX', or 'lowestYY' (XX,YY are numeric, 0 1: - raise SystemExit("There shouldn't be more than one config file in here (something went wrong)") + raise SystemExit( + "There shouldn't be more than one config file in here (something went wrong)" + ) elif len(config_file) == 0: - print('No config file found - creating one now') + print("No config file found - creating one now") create_config(script_folder, folder) - config_file = find('config.yaml', folder) + config_file = find("config.yaml", folder) config_file = config_file[0] if config: - if os.name == 'posix': # detect Unix + if os.name == "posix": # detect Unix subprocess.run(f"nano {config_file}", shell=True, check=True) - print('Configuration done.') - elif os.name == 'nt': # detect Windows + print("Configuration done.") + elif os.name == "nt": # detect Windows subprocess.run(f"notepad {config_file}", shell=True, check=True) - print('Configuration done.') - + print("Configuration done.") + # Load config -print('Using config file ' + config_file) -config = yaml.load(open(config_file, 'r'), Loader=yaml.RoundTripLoader) +print("Using config file " + config_file) +# make round-trip loader +yaml = YAML() +with open(config_file) as f: + config = yaml.load(f) # Check config for missing information and attempt to auto-fill -config['folder'] = folder +config["folder"] = folder temp_folder = glob.glob(folder + "/*_g0") if len(temp_folder) > 1: raise SystemExit("There shouldn't be more than one Neuropixel folder") elif len(temp_folder) == 0: - print('No Neuropixel data in this recording session') - config['neuropixel'] = '' + print("No Neuropixel data in this recording session") + config["neuropixel"] = "" else: if os.path.isdir(temp_folder[0]): - config['neuropixel'] = temp_folder[0] + config["neuropixel"] = temp_folder[0] else: raise SystemExit("Provided folder is not valid") -if config['neuropixel'] != '': - temp_folder = glob.glob(config['neuropixel'] + '/' + '*_g*') - config['num_neuropixels'] = len(temp_folder) - print('Using neuropixel folder ' + config['neuropixel'] + ' containing ' + - str(config['num_neuropixels']) + ' neuropixel') +if config["neuropixel"] != "": + temp_folder = glob.glob(config["neuropixel"] + "/" + "*_g*") + config["num_neuropixels"] = len(temp_folder) + print( + "Using neuropixel folder " + + config["neuropixel"] + + " containing " + + str(config["num_neuropixels"]) + + " neuropixel" + ) else: - config['num_neuropixels'] = 0 -temp_folder = glob.glob(folder + '/*_myo') + config["num_neuropixels"] = 0 + +temp_folder = glob.glob(folder + "/*_myo") if len(temp_folder) > 1: SystemExit("There shouldn't be more than one Myomatrix folder") elif len(temp_folder) == 0: - print('No Myomatrix data in this recording session') - config['myomatrix'] = '' + print("No Myomatrix data in this recording session") + config["myomatrix"] = "" else: if os.path.isdir(temp_folder[0]): - config['myomatrix'] = temp_folder[0] -if config['myomatrix'] != '': - print('Using myomatrix folder ' + config['myomatrix']) + config["myomatrix"] = temp_folder[0] + +# ensure global fields are present in config +if config["myomatrix"] != "": + print("Using myomatrix folder " + config["myomatrix"]) +if not "GPU_to_use" in config: + config["GPU_to_use"] = [0] +if not "num_KS_jobs" in config: + config["num_KS_jobs"] = 1 +if not "recordings" in config: + config["recordings"] = [1] if not "concatenate_myo_data" in config: - config['concatenate_myo_data'] = False + config["concatenate_myo_data"] = False if not "myo_data_passband" in config: - config['myo_data_passband'] = [250, 5000] + config["myo_data_passband"] = [250, 5000] if not "myo_data_sampling_rate" in config: - config['myo_data_sampling_rate'] = 30000 -if not "remove_bad_myo_chans" in config: - config['remove_bad_myo_chans'] = False - + config["myo_data_sampling_rate"] = 30000 +# ensure Sorting fields are present in config +if not "num_KS_components" in config["Sorting"]: + config["Sorting"]["num_KS_components"] = 9 +if not "do_KS_param_gridsearch" in config["Sorting"]: + config["Sorting"]["do_KS_param_gridsearch"] = False +# ensure Session fields are present in config +if not "myo_chan_map_file" in config["Session"]: + config["myo_chan_map_file"] = [ + ["linear_16ch_RF400_kilosortChanMap_unitSpacing.mat"] + ] +if not "myo_chan_list" in config["Session"]: + config["Session"]["myo_chan_list"] = [[1, 16]] +if not "myo_analog_chan" in config["Session"]: + config["Session"]["myo_analog_chan"] = 17 +if not "myo_muscle_list" in config["Session"]: + config["Session"]["myo_muscle_list"] = [ + ["Muscle" + str(i) for i in range(len(config["Session"]["myo_chan_list"]))] + ] +if not "remove_bad_myo_chans" in config["Session"]: + config["Session"]["remove_bad_myo_chans"] = [False] * len( + config["Session"]["myo_chan_list"] + ) +if not "remove_channel_delays" in config["Session"]: + config["Session"]["remove_channel_delays"] = [False] * len( + config["Session"]["myo_chan_list"] + ) + +# input assertions +assert ( + config["num_KS_jobs"] >= 1 +), "Number of parallel jobs must be greater than or equal to 1" +assert config["recordings"][0] == "all" or all( + [ + (item == round(item) >= 1 and isinstance(item, (int, float))) + for item in config["recordings"] + ] +), "'recordings' field must be a list of positive integers, or 'all' as first element" +assert all( + [(item >= 0 and isinstance(item, int)) for item in config["GPU_to_use"]] +), "'GPU_to_use' field must be greater than or equal to 0" +assert ( + config["num_neuropixels"] >= 0 +), "Number of neuropixels must be greater than or equal to 0" +assert ( + config["Sorting"]["num_KS_components"] >= 1 +), "Number of KS components must be greater than or equal to 1" +assert ( + config["myo_data_sampling_rate"] >= 1 +), "Myomatrix sampling rate must be greater than or equal to 1" + + +# use -d option to specify which sort folder to post-process +if "-d" in opts: + date_str = args[1] + # make sure date_str is in the format YYYYMMDD_HHMMSS, YYYYMMDD_HHMMSSsss, or YYYYMMDD_HHMMSSffffff + assert ( + (len(date_str) == 15 or len(date_str) == 18 or len(date_str) == 21) + & date_str[:8].isnumeric() + & date_str[9:].isnumeric() + & (date_str[8] == "_") + ), "Argument after '-d' must be a date string in format: YYYYMMDD_HHMMSS, YYYYMMDD_HHMMSSsss, or YYYYMMDD_HHMMSSffffff" + # check if date_str is present in any of the subfolders in the config["myomatrix"] path + subfolder_list = os.listdir(config["myomatrix"]) + previous_sort_folder_to_use = [ + iFolder for iFolder in subfolder_list if date_str in iFolder + ] + assert ( + len(previous_sort_folder_to_use) > 0 + ), f'No matching subfolder found in {config["myomatrix"]} for the date string provided' + assert ( + len(previous_sort_folder_to_use) < 2 + ), f'Multiple matching subfolders found in {config["myomatrix"]} for the date string provided. Try using a more specific date string, like "YYYYMMDD_HHMMSSffffff"' + previous_sort_folder_to_use = str(previous_sort_folder_to_use[0]) +else: + if config["num_KS_jobs"] == 1: + if "-myo_phy" in opts or "-myo_post" in opts: + try: + previous_sort_folder_to_use = str( + scipy.io.loadmat(f'{config["myomatrix"]}/sorted0/ops.mat')[ + "final_myo_sorted_dir" + ][0] + ) + except FileNotFoundError: + print( + "WARNING: No ops.mat file found in sorted0 folder, not able to detect previous sort folder.\n" + " If using '-myo_phy' or '-myo_post', try using the '-d' flag to specify the datestring\n" + ) + except KeyError: + print( + "WARNING: No 'final_myo_sorted_dir' field found in ops.mat file, not able to detect previous sort folder.\n" + " If using '-myo_phy' or '-myo_post', try using the '-d' flag to specify the datestring\n" + ) + except: + raise + else: + if "-myo_phy" in opts or "-myo_post" in opts: + raise SystemExit( + "Cannot guess desired previous sort folder after parallel sorting. Please specify manually using the '-d' flag" + ) + # find MATLAB installation -if os.path.isfile('/usr/local/MATLAB/R2021a/bin/matlab'): - matlab_root = '/usr/local/MATLAB/R2021a/bin/matlab' # something else for testing locally -elif os.path.isfile('/srv/software/matlab/R2021b/bin/matlab'): - matlab_root = '/srv/software/matlab/R2021b/bin/matlab' +if os.path.isfile("/usr/local/MATLAB/R2021a/bin/matlab"): + matlab_root = ( + "/usr/local/MATLAB/R2021a/bin/matlab" # something else for testing locally + ) +elif os.path.isfile("/srv/software/matlab/R2021b/bin/matlab"): + matlab_root = "/srv/software/matlab/R2021b/bin/matlab" else: - matlab_path = glob.glob('/usr/local/MATLAB/R*') - matlab_root = matlab_path[0] + '/bin/matlab' - -# Search myomatrix folder for existing concatenated_data folder, if it exists, it will be used -concatDataPath = find('concatenated_data', config['myomatrix']) -if len(concatDataPath) > 1: - raise SystemExit("There shouldn't be more than one concatenated_data folder inside the myomatrix data folder") -elif len(concatDataPath) < 1 & config['concatenate_myo_data']: - #no concatenated data folder was found - print("No concatenated files found, concatenating data from data in recording folders") - path_to_add = script_folder + '/sorting/myomatrix/' - os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(\'' + - path_to_add + f'\'); concatenate_myo_data(\'{config["myomatrix"]}\')"') - concatDataPath = find('concatenated_data', config['myomatrix']) - -temp = glob.glob(folder + '/*.kinarm') + matlab_path = glob.glob("/usr/local/MATLAB/R*") + matlab_root = matlab_path[0] + "/bin/matlab" + +if config["concatenate_myo_data"]: + # If concatenate_myo_data is set to true, search myomatrix folder for existing concatenated_data + # folder, if it exists, check subfolder names to see if they match the recording numbers + # specified in the config file. If they don't, create a new subfolder + # and concatenate the data into that folder. If they do, ensure that the continuous.dat file + # exists in the continuous/ folder for the matching recordings_str folder. If it doesn't, + # create a new subfolder and concatenate the data into that folder. + concatDataPath = find("concatenated_data", config["myomatrix"]) + if config["recordings"][0] == "all": + Record_Node_dir_list = [ + iDir for iDir in os.listdir(config["myomatrix"]) if "Record Node" in iDir + ] + assert ( + len(Record_Node_dir_list) == 1 + ), "Please remove all but one 'Record Node ###' folder" + Record_Node_dir = Record_Node_dir_list[0] + Experiment_dir_list = [ + iDir + for iDir in os.listdir(os.path.join(config["myomatrix"], Record_Node_dir)) + if iDir.startswith("experiment") + ] + assert ( + len(Experiment_dir_list) == 1 + ), "Please remove all but one 'experiment#' folder" + Experiment_dir = Experiment_dir_list[0] + recordings_dir_list = [ + iDir + for iDir in os.listdir( + os.path.join(config["myomatrix"], Record_Node_dir, Experiment_dir) + ) + if iDir.startswith("recording") + ] + recordings_dir_list = [ + int(i[9:]) for i in recordings_dir_list if i.startswith("recording") + ] + config["recordings"] = recordings_dir_list + recordings_str = ",".join([str(i) for i in config["recordings"]]) + + if len(concatDataPath) > 1: + raise SystemExit( + "There shouldn't be more than one concatenated_data folder in the myomatrix data folder" + ) + elif len(concatDataPath) == 1: + exact_match_recording_folder = [ + iFolder + for iFolder in os.listdir(concatDataPath[0]) + if recordings_str == iFolder + ] + if len(exact_match_recording_folder) == 1: + # now check in the continuous/ folder for the 'Acquisition_Board-100.Rhythm Data' or + # 'Rhythm_FPGA-100.0' folder, which should contain the concatenated data + continuous_folder = os.path.join( + concatDataPath[0], exact_match_recording_folder[0], "continuous" + ) + rhythm_folder = [ + iFolder + for iFolder in os.listdir(continuous_folder) + if "Rhythm" in iFolder + ] + if len(rhythm_folder) == 1: + continuous_dat_folder = os.path.join( + continuous_folder, rhythm_folder[0] + ) + # check if continuous.dat file exists in the continuous_dat_folder folder + if "continuous.dat" in os.listdir(continuous_dat_folder): + continuous_dat_is_present = True + else: + continuous_dat_is_present = False + else: + raise SystemExit( + f"There should be exactly one '*Rhythm*' folder in {continuous_folder} folder" + f"concatenated data, but found {len(rhythm_folder)}\n" + f"{rhythm_folder}" + ) + else: + continuous_dat_is_present = False + + # elif concatenating data and no continuous.dat file found in the concatenated_data folder for the + # matching recordings_str folder + if len(concatDataPath) < 1 or not continuous_dat_is_present: + print( + "Concatenated files not found, concatenating data from data in chosen recording folders" + ) + path_to_add = script_folder + "/sorting/myomatrix/" + subprocess.run( + [ + "matlab", + "-nodesktop", + "-nodisplay", + "-nosplash", + "-r", + "rehash toolboxcache; restoredefaultpath;" + f"addpath(genpath('{path_to_add}')); concatenate_myo_data('{config['myomatrix']}', {{{config['recordings']}}})", + ], + check=True, + ) + concatDataPath = find("concatenated_data", config["myomatrix"]) + print( + f"Using newly concatenated data at {concatDataPath[0]+'/'+recordings_str}" + ) + # elif setting is enabled and concatenated data was found with requested recordings present + elif recordings_str in os.listdir(concatDataPath[0]): + print( + f"Using existing concatenated data at {concatDataPath[0]+'/'+recordings_str}" + ) + concatDataPath = concatDataPath[0] + "/" + recordings_str +else: + print("Not concatenating data") + assert ( + len(config["recordings"]) == 1 + ), "Only one recording can be specified in recordings list if concatenate_myo_data is False" + recordings_str = str(config["recordings"][0]) + +# set chosen GPUs in environment variable +GPU_str = ",".join([str(i) for i in config["GPU_to_use"]]) +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = GPU_str + +temp = glob.glob(folder + "/*.kinarm") if len(temp) == 0: - print('No kinarm data in this recording session') - config['kinarm'] = '' + print("No kinarm data in this recording session") + config["kinarm"] = "" else: - config['kinarm'] = temp -if config['kinarm'] != '': - print('Found kinarm data files') -config['script_dir'] = script_folder + config["kinarm"] = temp +if config["kinarm"] != "": + print("Found kinarm data files") +config["script_dir"] = script_folder if in_cluster: - config['in_cluster'] = True + config["in_cluster"] = True else: - config['in_cluster'] = False -config['registration_final'] = registration_final + config["in_cluster"] = False +config["registration_final"] = registration_final # Save config file with up-to-date information -yaml.dump(config, open(config_file, 'w'), Dumper=yaml.RoundTripDumper) +with open(config_file, "w") as f: + yaml.dump(config, f) # Proceed with registration if registration or registration_final: registration_function(config) # Prepare common kilosort config -config_kilosort = yaml.safe_load(open(config_file, 'r')) -config_kilosort['myomatrix_number'] = 1 -config_kilosort['channel_list'] = 1 +with open(config_file) as f: + config_kilosort = yaml.load(f) +config_kilosort["myomatrix_number"] = 1 +config_kilosort["channel_list"] = 1 # if f"{config['script_dir']}/tmp" folder does not exist if not os.path.isdir(f"{config['script_dir']}/tmp"): os.mkdir(f"{config['script_dir']}/tmp") +# Convenience function to edit neuro sorting config file if neuro_config: - if os.name == 'posix': # detect Unix - subprocess.run(f"nano {config['script_dir']}/sorting/Kilosort_run.m", shell=True, check=True) - subprocess.run(f"nano {config['script_dir']}/sorting/resorter/neuropixel_call.m", shell=True, check=True) + if os.name == "posix": # detect Unix + subprocess.run( + f"nano {config['script_dir']}/sorting/Kilosort_run.m", + shell=True, + check=True, + ) print('Configuration for "-neuro_sort" done.') - elif os.name == 'nt': # detect Windows - subprocess.run(f"notepad {config['script_dir']}/sorting/Kilosort_run.m", shell=True, check=True) + subprocess.run( + f"nano {config['script_dir']}/sorting/resorter/neuropixel_call.m", + shell=True, + check=True, + ) + print('Configuration for "-neuro_post" done.') + elif os.name == "nt": # detect Windows + subprocess.run( + f"notepad {config['script_dir']}/sorting/Kilosort_run.m", + shell=True, + check=True, + ) print('Configuration for "-neuro_sort" done.') # Proceed with neural spike sorting if neuro_sort: - config_kilosort = {'script_dir': config['script_dir'], 'trange': np.array(config['Session']['trange'])} - config_kilosort['type'] = 1 - neuro_folders = glob.glob(config['neuropixel'] + '/*_g*') - path_to_add = script_folder + '/sorting/' - for pixel in range(config['num_neuropixels']): - config_kilosort['neuropixel_folder'] = neuro_folders[pixel] - tmp = glob.glob(neuro_folders[pixel] + '/*_t*.imec' + str(pixel) + '.ap.bin') - config_kilosort['neuropixel'] = tmp[0] - if len(find('sync.mat', config_kilosort['neuropixel_folder'])) > 0: - print('Found existing sync file') + config_kilosort = { + "script_dir": config["script_dir"], + "trange": np.array(config["Session"]["trange"]), + } + config_kilosort["type"] = 1 + neuro_folders = glob.glob(config["neuropixel"] + "/*_g*") + path_to_add = script_folder + "/sorting/" + for pixel in range(config["num_neuropixels"]): + config_kilosort["neuropixel_folder"] = neuro_folders[pixel] + tmp = glob.glob(neuro_folders[pixel] + "/*_t*.imec" + str(pixel) + ".ap.bin") + config_kilosort["neuropixel"] = tmp[0] + if len(find("sync.mat", config_kilosort["neuropixel_folder"])) > 0: + print("Found existing sync file") else: - print('Extracting sync signal from ' + config_kilosort['neuropixel'] + ' and saving') + print( + "Extracting sync signal from " + + config_kilosort["neuropixel"] + + " and saving" + ) extract_sync(config_kilosort) - #print('Starting drift correction of ' + config_kilosort['neuropixel']) - #kilosort(config_kilosort) + # print('Starting drift correction of ' + config_kilosort['neuropixel']) + # kilosort(config_kilosort) - print('Starting spike sorting of ' + config_kilosort['neuropixel']) + print("Starting spike sorting of " + config_kilosort["neuropixel"]) scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - # os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(\'' + - # path_to_add + '\'); Kilosort_run"') - subprocess.run(["matlab", "-nodisplay", "-nosplash", "-nodesktop", "-r", - f"addpath(genpath('{path_to_add}')); Kilosort_run_czuba"], check=True) + subprocess.run( + [ + "matlab", + "-nodisplay", + "-nosplash", + "-nodesktop", + "-r", + f"addpath(genpath('{path_to_add}')); Kilosort_run_czuba", + ], + check=True, + ) - print('Starting alf post-processing of ' + config_kilosort['neuropixel']) - alf_dir = Path(config_kilosort['neuropixel_folder'] + '/sorted/alf') + print("Starting alf post-processing of " + config_kilosort["neuropixel"]) + alf_dir = Path(config_kilosort["neuropixel_folder"] + "/sorted/alf") shutil.rmtree(alf_dir, ignore_errors=True) - ks_dir = Path(config_kilosort['neuropixel_folder'] + '/sorted') - ks2_to_alf(ks_dir, Path(config_kilosort['neuropixel']), alf_dir) + ks_dir = Path(config_kilosort["neuropixel_folder"] + "/sorted") + ks2_to_alf(ks_dir, Path(config_kilosort["neuropixel"]), alf_dir) # Proceed with neuro post-processing if neuro_post: - config_kilosort = {'script_dir': config['script_dir']} - neuro_folders = glob.glob(config['neuropixel'] + '/*_g*') - path_to_add = script_folder + '/sorting/' - for pixel in range(config['num_neuropixels']): - config_kilosort['neuropixel_folder'] = neuro_folders[pixel] + '/kilosort2/sorter_output' + config_kilosort = {"script_dir": config["script_dir"]} + neuro_folders = glob.glob(config["neuropixel"] + "/*_g*") + path_to_add = script_folder + "/sorting/" + for pixel in range(config["num_neuropixels"]): + config_kilosort["neuropixel_folder"] = ( + neuro_folders[pixel] + "/kilosort2/sorter_output" + ) scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - # os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(genpath(\'' + - # path_to_add + '\')); neuropixel_call"') - subprocess.run(["matlab", "-nodisplay", "-nosplash", "-nodesktop", "-r", - f"addpath(genpath('{path_to_add}')); neuropixel_call"], check=True) + subprocess.run( + [ + "matlab", + "-nodisplay", + "-nosplash", + "-nodesktop", + "-r", + f"addpath(genpath('{path_to_add}')); neuropixel_call", + ], + check=True, + ) if myo_config: - if os.name == 'posix': # detect Unix - subprocess.run(f"nano {config['script_dir']}/sorting/Kilosort_run_myo_3.m", shell=True, check=True) - subprocess.run(f"nano {config['script_dir']}/sorting/resorter/myomatrix_call.m", shell=True, check=True) + if os.name == "posix": # detect Unix + subprocess.run( + f"nano {config['script_dir']}/sorting/Kilosort_run_myo_3_czuba.m", + shell=True, + check=True, + ) print('Configuration for "-myo_sort" done.') - elif os.name == 'nt': # detect Windows - subprocess.run(f"notepad {config['script_dir']}/sorting/Kilosort_run_myo_3.m", shell=True, check=True) + subprocess.run( + f"nano {config['script_dir']}/sorting/resorter/myomatrix_call.m", + shell=True, + check=True, + ) + print('Configuration for "-myo_post" done.') + elif os.name == "nt": # detect Windows + subprocess.run( + f"notepad {config['script_dir']}/sorting/Kilosort_run_myo_3_czuba.m", + shell=True, + check=True, + ) print('Configuration for "-myo_sort" done.') # Proceed with myo processing and spike sorting if myo_sort: - config_kilosort = {'myomatrix': config['myomatrix'], 'script_dir': config['script_dir'], - 'myo_data_passband': np.array(config['myo_data_passband'],dtype=float), - 'myo_data_sampling_rate': float(config['myo_data_sampling_rate']), - 'trange': np.array(config['Session']['trange']), - 'sync_chan': int(config['Session']['myo_analog_chan'])} - path_to_add = script_folder + '/sorting/' - for myomatrix in range(len(config['Session']['myo_chan_list'])): - if len(concatDataPath)==1: - config_kilosort['myomatrix_data'] = concatDataPath - print(f"Using concatenated data from: {concatDataPath[0]}") + config_kilosort = { + "GPU_to_use": np.array(config["GPU_to_use"], dtype=int), + "num_KS_jobs": int(config["num_KS_jobs"]), + "myomatrix": config["myomatrix"], + "script_dir": config["script_dir"], + "recordings": np.array(config["recordings"], dtype=int) + if type(config["recordings"][0]) != str + else config["recordings"], + "myo_data_passband": np.array(config["myo_data_passband"], dtype=float), + "myo_data_sampling_rate": float(config["myo_data_sampling_rate"]), + "num_KS_components": np.array( + config["Sorting"]["num_KS_components"], dtype=int + ), + "trange": np.array(config["Session"]["trange"]), + "sync_chan": int(config["Session"]["myo_analog_chan"]), + } + path_to_add = script_folder + "/sorting/" + for myomatrix in range(len(config["Session"]["myo_chan_list"])): + if config["concatenate_myo_data"]: + config_kilosort["myomatrix_data"] = concatDataPath else: - f = glob.glob(config_kilosort['myomatrix'] + '/Record*') - config_kilosort['myomatrix_data'] = f[0] + # find match to recording folder using recordings_str + f = find("recording" + str(config["recordings"][0]), config["myomatrix"]) + config_kilosort["myomatrix_data"] = f[0] print(f"Using data from: {f[0]}") - config_kilosort['myo_sorted_dir'] = config_kilosort['myomatrix'] + '/sorted' + str(myomatrix) - config_kilosort['myomatrix_num'] = myomatrix - config_kilosort['myo_chan_map_file'] = os.path.join(config['script_dir'],'geometries', - config['Session']['myo_chan_map_file'][myomatrix]) - config_kilosort['chans'] = np.array(config['Session']['myo_chan_list'][myomatrix]) - config_kilosort['remove_bad_myo_chans'] = np.array(config['Session']['remove_bad_myo_chans'][myomatrix]) - config_kilosort['num_chans'] = config['Session']['myo_chan_list'][myomatrix][1] - \ - config['Session']['myo_chan_list'][myomatrix][0] + 1 - + config_kilosort["myo_sorted_dir"] = ( + config_kilosort["myomatrix"] + "/sorted" + str(myomatrix) + ) + config_kilosort["myomatrix_num"] = myomatrix + config_kilosort["myo_chan_map_file"] = os.path.join( + config["script_dir"], + "geometries", + config["Session"]["myo_chan_map_file"][myomatrix], + ) + config_kilosort["chans"] = np.array( + config["Session"]["myo_chan_list"][myomatrix] + ) + config_kilosort["remove_bad_myo_chans"] = np.array( + config["Session"]["remove_bad_myo_chans"][myomatrix] + ) + config_kilosort["remove_channel_delays"] = np.array( + config["Session"]["remove_channel_delays"][myomatrix] + ) + config_kilosort["num_chans"] = ( + config["Session"]["myo_chan_list"][myomatrix][1] + - config["Session"]["myo_chan_list"][myomatrix][0] + + 1 + ) scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - shutil.rmtree(config_kilosort['myo_sorted_dir'], ignore_errors=True) - # os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(genpath(\'' + - # path_to_add + '\')); myomatrix_binary"') - subprocess.run(["matlab", "-nodisplay", "-nosplash", "-nodesktop", "-r", - f"addpath(genpath('{path_to_add}')); myomatrix_binary"], check=True) - - print('Starting spike sorting of ' + config_kilosort['myo_sorted_dir']) + shutil.rmtree(config_kilosort["myo_sorted_dir"], ignore_errors=True) + subprocess.run( + [ + "matlab", + "-nodisplay", + "-nosplash", + "-nodesktop", + "-r", + ( + "rehash toolboxcache; restoredefaultpath;" + f"addpath(genpath('{path_to_add}')); myomatrix_binary" + ), + ], + check=True, + ) + scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - # os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(\'' + - # path_to_add + '\'); Kilosort_run_myo_3"') - subprocess.run(["matlab", "-nodisplay", "-nosplash", "-nodesktop", "-r", - f"addpath(genpath('{path_to_add}')); Kilosort_run_myo_3"], check=True) - # extract waveforms for Phy FeatureView - subprocess.run(["phy", "extract-waveforms", "params.py"],cwd=f"{config_kilosort['myo_sorted_dir']}", check=True) + + # check if user wants to do grid search of KS params + if config["Sorting"]["do_KS_param_gridsearch"] == 1: + iParams = list( + get_KS_params_grid() + ) # get iterator of all possible param combinations + else: + # just pass an empty string to run once with chosen params + iParams = [""] + + worker_ids = np.arange(config["num_KS_jobs"]) + # create new folders if running in parallel + if config["num_KS_jobs"] > 1: + # ensure proper configuration for parallel jobs + assert config["num_KS_jobs"] <= len( + config["GPU_to_use"] + ), "Number of parallel jobs must be less than or equal to number of GPUs" + assert ( + config["Sorting"]["do_KS_param_gridsearch"] == 1 + ), "Parallel jobs can only be used when do_KS_param_gridsearch is set to True" + # create new folder for each parallel job to store results temporarily + for i in worker_ids: + # create new folder for each parallel job + new_sorted_dir = config_kilosort["myo_sorted_dir"] + str(i) + if os.path.isdir(new_sorted_dir): + shutil.rmtree(new_sorted_dir, ignore_errors=True) + shutil.copytree(config_kilosort["myo_sorted_dir"], new_sorted_dir) + # split iParams according to number of parallel jobs + iParams_split = np.array_split(iParams, config["num_KS_jobs"]) + + def run_KS_sorting(iParams, worker_id): + iParams = iter(iParams) + os.environ["CUDA_VISIBLE_DEVICES"] = str(config["GPU_to_use"][worker_id]) + save_path = ( + f"{config_kilosort['myo_sorted_dir']}{worker_id}" + if config["num_KS_jobs"] > 1 + else config_kilosort["myo_sorted_dir"] + ) + print( + f"Starting spike sorting of {save_path} on GPU {config['GPU_to_use'][worker_id]}" + ) + worker_id = str(worker_id) + with tempfile.TemporaryDirectory( + suffix=f"_worker{worker_id}" + ) as worker_dir: + while True: + # while no exhaustion of iterator + try: + these_params = next(iParams) + if type(these_params) == dict: + print( + f"Using these KS params from Kilosort_gridsearch_config.py" + ) + print(these_params) + param_keys = list(these_params.keys()) + param_keys_str = [f"'{k}'" for k in param_keys] + param_vals = list(these_params.values()) + zipped_params = zip(param_keys_str, param_vals) + flattened_params = itertools.chain.from_iterable( + zipped_params + ) + # this is a comma-separated string of key-value pairs + passable_params = ",".join(str(p) for p in flattened_params) + elif type(these_params) == str: + print(f"Using KS params from Kilosort_run_myo_3.m") + passable_params = ( + these_params # this is a string: 'default' + ) + else: + print("ERROR: KS params must be a dictionary or a string.") + raise TypeError + if config["Sorting"]["do_KS_param_gridsearch"] == 1: + command_str = f"Kilosort_run_myo_3_czuba(struct({passable_params}),{worker_id},'{str(worker_dir)}');" + else: + command_str = f"Kilosort_run_myo_3_czuba('{passable_params}',{worker_id},'{str(worker_dir)}');" + subprocess.run( + [ + "matlab", + "-nosplash", + "-nodesktop", + "-r", + ( + "rehash toolboxcache; restoredefaultpath;" + f"addpath(genpath('{path_to_add}'));" + f"{command_str}" + ), + ], + check=True, + ) + # extract waveforms for Phy FeatureView + subprocess.run( + # "phy extract-waveforms params.py", + [ + "phy", + "extract-waveforms", + "params.py", + ], + cwd=save_path, + check=True, + ) + # get number of good units and total number of clusters from rez.mat + rez = scipy.io.loadmat(f"{save_path}/rez.mat") + num_KS_clusters = str(len(rez["good"])) + # sum the 1's in the good field of ops.mat to get number of good units + num_good_units = str(sum(rez["good"])[0]) + brokenChan = scipy.io.loadmat(f"{save_path}/brokenChan.mat")[ + "brokenChan" + ] + goodChans = np.setdiff1d(np.arange(1, 17), brokenChan) + goodChans_str = ",".join(str(i) for i in goodChans) + + ## TEMP - remove this later: append git branch name to final_filename + # get git branch name + git_branches = subprocess.run( + ["git", "branch"], capture_output=True, text=True + ) + git_branches = git_branches.stdout.split("\n") + git_branches = [i.strip() for i in git_branches] + git_branch = [i for i in git_branches if i.startswith("*")][0][ + 2: + ] + + # remove spaces and single quoutes from passable_params string + time_stamp_us = datetime.datetime.now().strftime( + "%Y%m%d_%H%M%S%f" + ) + filename_friendly_params = passable_params.replace( + "'", "" + ).replace(" ", "") + final_filename = ( + f"sorted{str(myomatrix)}" + f"_{time_stamp_us}" + f"_rec-{recordings_str}" + # f"_chans-{goodChans_str}" + # f"_{num_good_units}-good-of-{num_KS_clusters}-total" + f"_{filename_friendly_params}" + # f"_{git_branch}" + ) + # remove trailing underscore if present + final_filename = ( + final_filename[:-1] + if final_filename.endswith("_") + else final_filename + ) + # store final_filename in a new ops.mat field in the sorted0 folder + ops = scipy.io.loadmat(f"{save_path}/ops.mat") + ops.update({"final_myo_sorted_dir": final_filename}) + scipy.io.savemat(f"{save_path}/ops.mat", ops) + + # copy sorted0 folder tree to a new folder with timestamp to label results by params + # this serves as a backup of the sorted0 data, so it can be loaded into Phy later + shutil.copytree( + save_path, + Path(save_path).parent.joinpath(final_filename), + ) + + except StopIteration: + if config["Sorting"]["do_KS_param_gridsearch"] == 1: + print(f"Grid search complete for worker {worker_id}") + return # exit the function + except: + if config["Sorting"]["do_KS_param_gridsearch"] == 1: + print("Error in grid search.") + else: + print("Error in sorting.") + raise # re-raise the exception + + if config["num_KS_jobs"] > 1: + # run parallel jobs + with concurrent.futures.ProcessPoolExecutor() as executor: + executor.map(run_KS_sorting, iParams_split, worker_ids) + else: + # run single job + run_KS_sorting(iParams, worker_ids[0]) # Proceed with myo post-processing if myo_post: - config_kilosort = {'script_dir': config['script_dir'], 'myomatrix': config['myomatrix']} - path_to_add = script_folder + '/sorting/' - for myomatrix in range(len(config['Session']['myo_chan_list'])): - f = glob.glob(config_kilosort['myomatrix'] + '/Record*') - - config_kilosort['myo_sorted_dir'] = config_kilosort['myomatrix'] + '/sorted' + str(myomatrix) - config_kilosort['myo_chan_map_file'] = os.path.join(config['script_dir'],'geometries', - config['Session']['myo_chan_map_file'][myomatrix]) - config_kilosort['remove_bad_myo_chans'] = np.array(config['Session']['remove_bad_myo_chans'][myomatrix]) - config_kilosort['num_chans'] = config['Session']['myo_chan_list'][myomatrix][1] - \ - config['Session']['myo_chan_list'][myomatrix][0] + 1 + config_kilosort = { + "script_dir": config["script_dir"], + "myomatrix": config["myomatrix"], + "GPU_to_use": config["GPU_to_use"], + } + path_to_add = script_folder + "/sorting/" + for myomatrix in range(len(config["Session"]["myo_chan_list"])): + f = glob.glob(config_kilosort["myomatrix"] + "/Record*") + + config_kilosort["myo_sorted_dir"] = ( + (config_kilosort["myomatrix"] + "/sorted" + str(myomatrix)) + if "-d" not in opts + else (config_kilosort["myomatrix"] + "/" + previous_sort_folder_to_use) + ) + config_kilosort["myo_chan_map_file"] = os.path.join( + config["script_dir"], + "geometries", + config["Session"]["myo_chan_map_file"][myomatrix], + ) + config_kilosort["remove_bad_myo_chans"] = np.array( + config["Session"]["remove_bad_myo_chans"][myomatrix] + ) + config_kilosort["remove_channel_delays"] = np.array( + config["Session"]["remove_channel_delays"][myomatrix] + ) + config_kilosort["num_chans"] = ( + config["Session"]["myo_chan_list"][myomatrix][1] + - config["Session"]["myo_chan_list"][myomatrix][0] + + 1 + ) scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - shutil.rmtree(config_kilosort['myo_sorted_dir'] + '/Plots', ignore_errors=True) + shutil.rmtree(config_kilosort["myo_sorted_dir"] + "/Plots", ignore_errors=True) - print('Starting resorting of ' + config_kilosort['myo_sorted_dir']) + print("Starting resorting of " + config_kilosort["myo_sorted_dir"]) scipy.io.savemat(f"{config['script_dir']}/tmp/config.mat", config_kilosort) - # get intermediate merge folders - merge_folders = Path(f"{config_kilosort['myo_sorted_dir']}/custom_merges").glob("intermediate_merge*") - # os.system(matlab_root + ' -nodisplay -nosplash -nodesktop -r "addpath(genpath(\'' + - # path_to_add + '\')); myomatrix_call"') - subprocess.run(["matlab", "-nodisplay", "-nosplash", "-nodesktop", "-r", - f"addpath(genpath('{path_to_add}')); myomatrix_call"], check=True) - - # extract waveforms for Phy FeatureView - for iDir in merge_folders: - # create symlinks to processed data - Path(f"{iDir}/proc.dat").symlink_to(Path("../../proc.dat")) - # run Phy extract-waveforms on intermediate merges - subprocess.run(["phy", "extract-waveforms", "params.py"],cwd=iDir, check=True) + ## get intermediate merge folders -- (2023-09-11) not doing intermediate merges anymore + # merge_folders = Path(f"{config_kilosort['myo_sorted_dir']}/custom_merges").glob( + # "intermediate_merge*" + # ) + subprocess.run( + [ + "matlab", + "-nodisplay", + "-nosplash", + "-nodesktop", + "-r", + ( + "rehash toolboxcache; restoredefaultpath;" + f"addpath(genpath('{path_to_add}')); myomatrix_call" + ), + ], + check=True, + ) + + # # extract waveforms for Phy FeatureView + # for iDir in merge_folders: + # # create symlinks to processed data + # Path(f"{iDir}/proc.dat").symlink_to(Path("../../proc.dat")) + # # run Phy extract-waveforms on intermediate merges + # subprocess.run(["phy", "extract-waveforms", "params.py"], cwd=iDir, check=True) + # create symlinks to processed data + Path( + f"{config_kilosort['myo_sorted_dir']}/custom_merges/final_merge/proc.dat" + ).symlink_to(Path("../../proc.dat")) # run Phy extract-waveforms on final merge - Path(f"{config_kilosort['myo_sorted_dir']}/custom_merges/final_merge/proc.dat").symlink_to(Path("../../proc.dat")) - subprocess.run(["phy", "extract-waveforms", "params.py"],cwd=f"{config_kilosort['myo_sorted_dir']}/custom_merges/final_merge", check=True) + subprocess.run( + ["phy", "extract-waveforms", "params.py"], + cwd=f"{config_kilosort['myo_sorted_dir']}/custom_merges/final_merge", + check=True, + ) + # copy sorted0 folder tree into same folder as for -myo_sort + try: + merge_path = "custom_merges/final_merge" + shutil.copytree( + Path(config_kilosort["myo_sorted_dir"]).joinpath(merge_path), + Path(config_kilosort["myo_sorted_dir"]) + .parent.joinpath(previous_sort_folder_to_use) + .joinpath(merge_path), + ) + except FileExistsError: + print( + f"WARNING: Final merge already exists in {previous_sort_folder_to_use}, files not updated" + ) + except: + raise + +# plot to show spikes overlaid on electrophysiology data, for validation purposes if myo_plot: - path_to_add = script_folder + '/sorting/' - # create default values for spike validation plot arguments, if not provided + path_to_add = script_folder + "/sorting/" + if "-d" in opts: + sorted_folder_to_plot = previous_sort_folder_to_use + args = args[1:] # remove the -d flag related argument + # create default values for validation plot arguments, if not provided if len(args) == 1: - arg1 = int(1) # default to plot chunk 1 - arg2 = 'true' # default to logical true to show all clusters + arg1 = int(1) # default to plot chunk 1 + arg2 = "true" # default to logical true to show all clusters elif len(args) == 2: arg1 = int(args[1]) - arg2 = 'true' # default to logical true to show all clusters + arg2 = "true" # default to logical true to show all clusters elif len(args) == 3: import json + arg_as_list = json.loads(args[2]) arg1 = int(args[1]) arg2 = np.array(arg_as_list).astype(int) - subprocess.run(["matlab", "-nodesktop", "-nosplash", "-r", - f"addpath(genpath('{path_to_add}')); spike_validation_plot({arg1},{arg2})"], check=True) + subprocess.run( + [ + "matlab", + "-nodesktop", + "-nosplash", + "-r", + ( + "rehash toolboxcache; restoredefaultpath;" + f"addpath(genpath('{path_to_add}')); spike_validation_plot({arg1},{arg2})" + ), + ], + check=True, + ) + +if myo_phy: + path_to_add = script_folder + "/sorting/" + if "-d" in opts: + sorted_folder_to_plot = previous_sort_folder_to_use + args = args[1:] # remove the -d flag related argument + else: + # default to sorted0 folder, may need to update to be flexible for sorted1, 2, etc. + sorted_folder_to_plot = "sorted0" + os.chdir(Path(config["myomatrix"]).joinpath(sorted_folder_to_plot)) + subprocess.run( + [ + "phy", + "template-gui", + "params.py", + ], + ) # Proceed with LFP extraction if lfp_extract: - config_kilosort['type'] = 1 - neuro_folders = glob.glob(config['neuropixel'] + '/*_g*') - for pixel in range(config['num_neuropixels']): - config_kilosort['neuropixel_folder'] = neuro_folders[pixel] - tmp = glob.glob(neuro_folders[pixel] + '/*_t*.imec' + str(pixel) + '.ap.bin') - config_kilosort['neuropixel'] = tmp[0] - if len(find('lfp.mat', config_kilosort['neuropixel_folder'])) > 0: - print('Found existing LFP file') + config_kilosort["type"] = 1 + neuro_folders = glob.glob(f"{config['neuropixel']}'/*_g*") + for pixel in range(config["num_neuropixels"]): + config_kilosort["neuropixel_folder"] = neuro_folders[pixel] + tmp = glob.glob(f"{neuro_folders[pixel]}/*_t*.imec{str(pixel)}.ap.bin") + config_kilosort["neuropixel"] = tmp[0] + if len(find("lfp.mat", config_kilosort["neuropixel_folder"])) > 0: + print("Found existing LFP file") else: - print('Extracting LFP from ' + config_kilosort['neuropixel'] + ' and saving') + print(f"Extracting LFP from {config_kilosort['neuropixel']} and saving") extract_LFP(config_kilosort) -print('Pipeline finished! You\'ve earned a break.') -print(datetime.datetime.now()) \ No newline at end of file +print("Pipeline finished! You've earned a break.") +finish_time = datetime.datetime.now() +time_elapsed = finish_time - start_time +# use strfdelta to format time elapsed +print( + ( + "Time elapsed: " + f"{strfdelta(time_elapsed, '{hours} hours, {minutes} minutes, {seconds} seconds')}" + ) +) + +# reset the terminal mode to prevent not printing user input to terminal after program exits +subprocess.run(["stty", "sane"]) diff --git a/sorting/Kilosort-3.0/CUDA/mexFilterPCs.cu b/sorting/Kilosort-3.0/CUDA/mexFilterPCs.cu index 4e9011ad..304d3faf 100644 --- a/sorting/Kilosort-3.0/CUDA/mexFilterPCs.cu +++ b/sorting/Kilosort-3.0/CUDA/mexFilterPCs.cu @@ -16,10 +16,10 @@ #include #include using namespace std; -const int Nthreads = 1024, NrankMax = 3; +const int Nthreads = 512, NrankMax = 12; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ - volatile __shared__ float sW[201*NrankMax], sdata[(Nthreads+201)*NrankMax]; + volatile __shared__ float sW[61*NrankMax], sdata[(Nthreads+61)*NrankMax]; float x, y; int tid, tid0, bid, i, nid, Nrank, NT, nt0; diff --git a/sorting/Kilosort-3.0/CUDA/mexGPUall.m b/sorting/Kilosort-3.0/CUDA/mexGPUall.m index da351052..e6d24bfd 100644 --- a/sorting/Kilosort-3.0/CUDA/mexGPUall.m +++ b/sorting/Kilosort-3.0/CUDA/mexGPUall.m @@ -2,12 +2,17 @@ % Matlab GPU library first (see README files for platform-specific % information) +% Only compile mex files used for Kilosort-3.0: +%%% -> spikedetector3PC.cu, +%%% -> mexMPnu8.cu/mexMPnu8_pcTight.cu +%%% -> mexWtW2.cu + enableStableMode = true; - mexcuda -largeArrayDims spikedetector3.cu + % mexcuda -largeArrayDims spikedetector3.cu mexcuda -largeArrayDims spikedetector3PC.cu - mexcuda -largeArrayDims mexThSpkPC.cu - mexcuda -largeArrayDims mexGetSpikes2.cu + % mexcuda -largeArrayDims mexThSpkPC.cu + % mexcuda -largeArrayDims mexGetSpikes2.cu if enableStableMode % For algorithm development purposes which require guaranteed @@ -15,16 +20,18 @@ % compile line for mexMPnu8.cu. -DENABLE_STABLEMODE must also % be specified. This version will run ~2X slower than the % non deterministic version. - mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8.cu + % mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8.cu + mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8_pcTight.cu else - mexcuda -largeArrayDims mexMPnu8.cu + % mexcuda -largeArrayDims mexMPnu8.cu + mexcuda -largeArrayDims mexMPnu8_pcTight.cu end - mexcuda -largeArrayDims mexSVDsmall2.cu + % mexcuda -largeArrayDims mexSVDsmall2.cu mexcuda -largeArrayDims mexWtW2.cu - mexcuda -largeArrayDims mexFilterPCs.cu - mexcuda -largeArrayDims mexClustering2.cu - mexcuda -largeArrayDims mexDistances2.cu + % mexcuda -largeArrayDims mexFilterPCs.cu + % mexcuda -largeArrayDims mexClustering2.cu + % mexcuda -largeArrayDims mexDistances2.cu % mex -largeArrayDims mexMPmuFEAT.cu diff --git a/sorting/Kilosort-3.0/CUDA/mexGetSpikes2.cu b/sorting/Kilosort-3.0/CUDA/mexGetSpikes2.cu index ae287c99..63f254ec 100644 --- a/sorting/Kilosort-3.0/CUDA/mexGetSpikes2.cu +++ b/sorting/Kilosort-3.0/CUDA/mexGetSpikes2.cu @@ -17,7 +17,7 @@ #include using namespace std; -const int Nthreads = 1024, maxFR = 5000, NrankMax = 6; +const int Nthreads = 1024, maxFR = 5000, NrankMax = 12; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void sumChannels(const double *Params, const float *data, float *datasum, int *kkmax, const int *iC){ @@ -58,7 +58,7 @@ __global__ void sumChannels(const double *Params, const float *data, ////////////////////////////////////////////////////////////////////////////////////////// __global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ - volatile __shared__ float sW[201*NrankMax], sdata[(Nthreads+201)]; + volatile __shared__ float sW[61*NrankMax], sdata[(Nthreads+61)]; float y; int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan; @@ -127,7 +127,7 @@ __global__ void cleanup_spikes(const double *Params, const float *err, const int *ftype, float *x, int *st, int *id, int *counter){ int lockout, indx, tid, bid, NT, tid0, j, t0; - volatile __shared__ float sdata[Nthreads+2*201+1]; + volatile __shared__ float sdata[Nthreads+2*61+1]; bool flag=0; float err0, Th; diff --git a/sorting/Kilosort-3.0/CUDA/mexMPnu8.cu b/sorting/Kilosort-3.0/CUDA/mexMPnu8.cu index c6119b35..f3f459c6 100644 --- a/sorting/Kilosort-3.0/CUDA/mexMPnu8.cu +++ b/sorting/Kilosort-3.0/CUDA/mexMPnu8.cu @@ -22,7 +22,7 @@ using namespace std; #include "mexNvidia_quicksort.cu" #endif -const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nmaxiter = 500, NchanMax = 32; +const int Nthreads = 512, maxFR = 100000, NrankMax = 12, nmaxiter = 500, NchanMax = 32; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void spaceFilter(const double *Params, const float *data, const float *U, @@ -189,7 +189,7 @@ __global__ void spaceFilterUpdate_v2(const double *Params, const double *data, c ////////////////////////////////////////////////////////////////////////////////////////// __global__ void timeFilter(const double *Params, const float *data, const float *W,float *conv_sig){ - volatile __shared__ float sW2[201*NrankMax], sW[201*NrankMax], sdata[(Nthreads+201)*NrankMax]; + volatile __shared__ float sW2[61*NrankMax], sW[61*NrankMax], sdata[(Nthreads+61)*NrankMax]; float x; int tid, tid0, bid, i, nid, Nrank, NT, Nfilt, nt0, irank; @@ -247,7 +247,7 @@ __global__ void timeFilter(const double *Params, const float *data, const float __global__ void timeFilterUpdate(const double *Params, const float *data, const float *W, const bool *UtU, float *conv_sig, const int *st, const int *id, const int *counter){ - volatile __shared__ float sW[201*NrankMax], sW2[201*NrankMax]; + volatile __shared__ float sW[61*NrankMax], sW2[61*NrankMax]; float x; int tid, tid0, bid, t, k,ind, Nrank, NT, Nfilt, nt0; @@ -285,6 +285,7 @@ __global__ void timeFilterUpdate(const double *Params, const float *data, const } ////////////////////////////////////////////////////////////////////////////////////////// +// description: compute the error for each filter to determine which one is best __global__ void bestFilter(const double *Params, const float *data, const float *mu, float *err, float *eloss, int *ftype){ @@ -382,7 +383,7 @@ __global__ void cleanup_spikes(const double *Params, const float *data, const float *mu, const float *err, const float *eloss, const int *ftype, int *st, int *id, float *x, float *y, float *z, int *counter){ - volatile __shared__ float sdata[Nthreads+2*201+1]; + volatile __shared__ float sdata[Nthreads+2*61+1]; float err0, Th; int lockout, indx, tid, bid, NT, tid0, j, id0, t0; bool flag=0; @@ -717,7 +718,7 @@ __global__ void computePCfeatures(const double *Params, const int *counter, const float *W, const float *U, const float *mu, const int *iW, const int *iC, const float *wPCA, float *featPC){ - volatile __shared__ float sPCA[2*201 * NrankMax], sW[201 * NrankMax], sU[NchanMax * NrankMax]; + volatile __shared__ float sPCA[61 * NrankMax], sW[61 * NrankMax], sU[NchanMax * NrankMax]; volatile __shared__ int iU[NchanMax]; float X = 0.0f, Y = 0.0f; @@ -766,7 +767,7 @@ __global__ void computePCfeatures(const double *Params, const int *counter, X = Y * x[ind]; // - mu[bid]); for (t=0;t>>(d_Params, d_draw, d_U, d_iC, d_iW, d_data); @@ -1112,7 +1113,7 @@ void mexFunction(int nlhs, mxArray *plhs[], plhs[5] = mxGPUCreateMxArrayOnGPU(draw); plhs[6] = mxGPUCreateMxArrayOnGPU(nsp); - const mwSize dimsfPC[] = {NchanU, 2*Nrank, minSize}; + const mwSize dimsfPC[] = {NchanU, Nrank, minSize}; plhs[7] = mxCreateNumericArray(3, dimsfPC, mxSINGLE_CLASS, mxREAL); featPC = (float*) mxGetData(plhs[7]); @@ -1121,7 +1122,7 @@ void mexFunction(int nlhs, mxArray *plhs[], cudaMemcpy(x, d_y, minSize * sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(vexp, d_x, minSize * sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(feat, d_feat, minSize * Nnearest*sizeof(float), cudaMemcpyDeviceToHost); - cudaMemcpy(featPC, d_featPC, 2*minSize * NchanU*Nrank*sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(featPC, d_featPC, minSize * NchanU*Nrank*sizeof(float), cudaMemcpyDeviceToHost); // send back an error message if useStableMode was selected but couldn't be used diff --git a/sorting/Kilosort-3.0/CUDA/mexMPnu8_pcTight.cu b/sorting/Kilosort-3.0/CUDA/mexMPnu8_pcTight.cu new file mode 100644 index 00000000..bcea544b --- /dev/null +++ b/sorting/Kilosort-3.0/CUDA/mexMPnu8_pcTight.cu @@ -0,0 +1,1212 @@ +/* + * mexMPnu8_pcTight.cu + * 2021 TBC Modified version of core Kilosort spike extraction function mexMPnu8.cu + * - tightened up window of feature PC calc more closely to waveform center + * - PCA & feat calcs used for outputs occur in [computePCfeatures] + * - version assumes 61 point waveform length [nt0], aligned to sample [nt0min] + * - compute wPCA & featPC on timepoints [6:nt0-15] == [6:45], len=40; + * - Seemed to work really well at improving usability of template features during + * manual curation + * - Testing expansion to template evaluation by adding same tweaks to [timeFilter] & [timeFilterUpdate] + * - this is distinct from changing range of nt0, because its still important for spike subtraction + * to include the tails of each waveform...we just want the projections computed on a tighter, + * more meaningful range of timepoints w/in the waveform + * + * 2023 SMO Simply changed range of indexing to [10:nt0-10] == [10:50], len=40; because peaks are centered now + * + * Compile individually with: + * mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8_pcTight.cu + * +*/ +#include +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + +#ifdef ENABLE_STABLEMODE + //for sorting according to timestamps + #include "mexNvidia_quicksort.cu" +#endif + +const int Nthreads = 1024, maxFR = 100000, NrankMax = 9, nmaxiter = 500, NchanMax = 32; + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void spaceFilter(const double *Params, const float *data, const float *U, + const int *iC, const int *iW, float *dprod){ + +// <<>> +// blockIdx = current filter/template +// blockDim = 1024 (max number of threads) +// threadIdx = used both to index channel (in synchronized portion) +// and time (in non-synchronized portion). + volatile __shared__ float sU[32*NrankMax]; + volatile __shared__ int iU[32]; + float x; + int tid, bid, i,k, Nrank, Nchan, NT, Nfilt, NchanU; + + tid = threadIdx.x; + bid = blockIdx.x; + NT = (int) Params[0]; + Nfilt = (int) Params[1]; + Nrank = (int) Params[6]; + NchanU = (int) Params[10]; //NchanNear in learnTemplates = 32 + Nchan = (int) Params[9]; + + if (tid>> + // just need to do this for all filters that have overlap with id[bid] and st[id] + // as in spaceFilter, tid = threadIdx.x is first used to index over channels and pcs + // then used to loop over time, now just from -nt0 to nt0 about the input spike time + // tidx represents time, from -nt0 to nt0 + // tidy loops through all filters that have overlap + + if (tid=0 & t>> + // just need to do this for all filters that have overlap with id[bid] and st[id] + // as in spaceFilter, tid = threadIdx.x is first used to index over channels and pcs + // then used to loop over time, now just from -nt0 to nt0 about the input spike time + // tidx represents time, from -nt0 to nt0 + // tidy loops through all filters that have overlap + + if (tid=0 & t>> +// threadIdx.x used as index over pcs in temporal templates +// (num PCs * number of timepoints = Nrank * nt0) +// Applied to data that's already been through filtering with +// the spatial templates, input data has dim Nrank x NT x Nfilt + + if(tid>> +// Same as timeFilter, except timepoints now limited to +/- nt0 about +// spike times assiged to filters that may overlap the current filter +// specified by bid. The matrix of potentially overlapping filters +// is given in UtU. + + if (tid=0 && tid0>> +// loop over timepoints + + tid0 = tid + bid * blockDim.x; + while (tid0 Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[tid0] = Cbest; + eloss[tid0] = Cbest - Cnextbest; + ftype[tid0] = ibest; + + tid0 += blockDim.x * gridDim.x; + } +} + +// THIS UPDATE DOES NOT UPDATE ELOSS? +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void bestFilterUpdate(const double *Params, const float *data, + const float *mu, float *err, float *eloss, int *ftype, const int *st, const int *id, const int *counter){ + + float Cf, Cbest, lam, b, a, Cnextbest; + int tid, ind, i,t, NT, Nfilt, ibest = 0, nt0; + + tid = threadIdx.x; + NT = (int) Params[0]; + Nfilt = (int) Params[1]; + lam = (float) Params[7]; + nt0 = (int) Params[4]; + + // we only need to compute this at updated locations + ind = counter[1] + blockIdx.x; + + if (ind=0 && t Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[t] = Cbest; + ftype[t] = ibest; + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *data, + const float *mu, const float *err, const float *eloss, const int *ftype, int *st, + int *id, float *x, float *y, float *z, int *counter){ + + volatile __shared__ float sdata[Nthreads+2*61+1]; + float err0, Th; + int lockout, indx, tid, bid, NT, tid0, j, id0, t0; + bool flag=0; + + // <<>> + lockout = (int) Params[4] - 1; // Parms[4] = nt0 + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + tid0 = bid * blockDim.x ; + Th = (float) Params[2]; + //lam = (float) Params[7]; + + while(tid0Th*Th){ + flag = 0; + for(j=-lockout;j<=lockout;j++) + if(sdata[tid+lockout+j]>err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indxTh){ + if (id[currInd]==bid){ + if (tidx==0 && threadIdx.y==0) + nsp[bid]++; + + tidy = threadIdx.y; + while (tidyThS){ + + tidy = threadIdx.y; + // only do this if the spike is "BAD" + while (tidy> +__global__ void set_idx( unsigned int *idx, const unsigned int nitems ) { + for( int i = 0; i < nitems; ++ i ) { + idx[i] = i; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// + +/* + * Host code + */ +void mexFunction(int nlhs, mxArray *plhs[], + int nrhs, mxArray const *prhs[]) +{ + /* Initialize the MathWorks GPU API. */ + mxInitGPU(); + +// only increase Shared Memory Size if needed, and the highest choice +// should be according to the Compute Capability of your GPU +// int maxbytes = 232448; // 227 KiB, Compute Capability 9.0 +// int maxbytes = 166912; // 163 KiB, Compute Capability 8.0, 8.7 +// int maxbytes = 101376; // 99 KiB, Compute Capability 8.6, 8.9 +// int maxbytes = 65536; // 64 KiB, Compute Capability 7.5 +// int maxbytes = 98304; // 96 KiB, Compute Capability 7.0, 7.2 +// int maxbytes = 49152; // 48 KiB, Compute Capability 5.0-6.2 +// cudaFuncSetAttribute(timeFilter, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes); + + /* Declare input variables*/ + double *Params, *d_Params; + unsigned int nt0, Nchan, NT, Nfilt, Nnearest, Nrank, NchanU, useStableMode; + + /* read Params and copy to GPU */ + Params = (double*) mxGetData(prhs[0]); + NT = (unsigned int) Params[0]; + Nfilt = (unsigned int) Params[1]; + nt0 = (unsigned int) Params[4]; + Nnearest = (unsigned int) Params[5]; + Nrank = (unsigned int) Params[6]; + NchanU = (unsigned int) Params[10]; + Nchan = (unsigned int) Params[9]; + useStableMode = (unsigned int) Params[16]; + + // Make a local pointer to Params, which can be passed to kernels + cudaMalloc(&d_Params, sizeof(double)*mxGetNumberOfElements(prhs[0])); + cudaMemcpy(d_Params,Params,sizeof(double)*mxGetNumberOfElements(prhs[0]),cudaMemcpyHostToDevice); + + /* collect input GPU variables*/ + mxGPUArray const *W, *iList, *U, *iC, *iW, *mu, *UtU, *wPCA; + mxGPUArray *dWU, *draw, *nsp; + const int *d_iList, *d_iC, *d_iW; + const bool *d_UtU; + int *d_st, *d_nsp, *d_ftype, *d_id, *d_counter, *d_count; + double *d_dWU; + float *d_draw, *d_err, *d_x, *d_y, *d_z, *d_dout, *d_feat, *d_data, *d_featPC, *d_eloss; + const float *d_W, *d_U, *d_mu, *d_wPCA; + + // draw is not a constant , so the data has to be "copied" over + draw = mxGPUCopyFromMxArray(prhs[1]); + d_draw = (float *)(mxGPUGetData(draw)); + U = mxGPUCreateFromMxArray(prhs[2]); + d_U = (float const *)(mxGPUGetDataReadOnly(U)); + W = mxGPUCreateFromMxArray(prhs[3]); + d_W = (float const *)(mxGPUGetDataReadOnly(W)); + mu = mxGPUCreateFromMxArray(prhs[4]); + d_mu = (float const *)(mxGPUGetDataReadOnly(mu)); + iC = mxGPUCreateFromMxArray(prhs[5]); + d_iC = (int const *)(mxGPUGetDataReadOnly(iC)); + iW = mxGPUCreateFromMxArray(prhs[6]); + d_iW = (int const *)(mxGPUGetDataReadOnly(iW)); + UtU = mxGPUCreateFromMxArray(prhs[7]); + d_UtU = (bool const *)(mxGPUGetDataReadOnly(UtU)); + iList = mxGPUCreateFromMxArray(prhs[8]); + d_iList = (int const *) (mxGPUGetDataReadOnly(iList)); + wPCA = mxGPUCreateFromMxArray(prhs[9]); + d_wPCA = (float const *)(mxGPUGetDataReadOnly(wPCA)); + + + const mwSize dimsNsp[] = {Nfilt,1}; + nsp = mxGPUCreateGPUArray(2, dimsNsp, mxINT32_CLASS, mxREAL, MX_GPU_DO_NOT_INITIALIZE); + d_nsp = (int *)(mxGPUGetData(nsp)); + const mwSize dimsdWU[] = {nt0, Nchan, Nfilt}; + dWU = mxGPUCreateGPUArray(3, dimsdWU, mxDOUBLE_CLASS, mxREAL, MX_GPU_DO_NOT_INITIALIZE); + d_dWU = (double *)(mxGPUGetData(dWU)); + + cudaMalloc(&d_dout, 2*NT * Nfilt* sizeof(float)); + cudaMalloc(&d_data, NT * Nfilt*Nrank* sizeof(float)); + + cudaMalloc(&d_err, NT * sizeof(float)); + cudaMalloc(&d_ftype, NT * sizeof(int)); + cudaMalloc(&d_eloss, NT * sizeof(float)); + cudaMalloc(&d_st, maxFR * sizeof(int)); + cudaMalloc(&d_id, maxFR * sizeof(int)); + cudaMalloc(&d_x, maxFR * sizeof(float)); + cudaMalloc(&d_y, maxFR * sizeof(float)); + cudaMalloc(&d_z, maxFR * sizeof(float)); + + cudaMalloc(&d_counter, 2*sizeof(int)); + cudaMalloc(&d_count, nmaxiter*sizeof(int)); + cudaMalloc(&d_feat, maxFR * Nnearest * sizeof(float)); + cudaMalloc(&d_featPC, maxFR * NchanU*Nrank * sizeof(float)); + + cudaMemset(d_nsp, 0, Nfilt * sizeof(int)); + cudaMemset(d_dWU, 0, Nfilt * nt0 * Nchan* sizeof(double)); + cudaMemset(d_dout, 0, NT * Nfilt * sizeof(float)); + cudaMemset(d_data, 0, Nrank * NT * Nfilt * sizeof(float)); + cudaMemset(d_counter, 0, 2*sizeof(int)); + cudaMemset(d_count, 0, nmaxiter*sizeof(int)); + cudaMemset(d_st, 0, maxFR * sizeof(int)); + cudaMemset(d_id, 0, maxFR * sizeof(int)); + cudaMemset(d_x, 0, maxFR * sizeof(float)); + cudaMemset(d_y, 0, maxFR * sizeof(float)); + cudaMemset(d_z, 0, maxFR * sizeof(float)); + cudaMemset(d_feat, 0, maxFR * Nnearest * sizeof(float)); + cudaMemset(d_featPC, 0, maxFR * NchanU*Nrank * sizeof(float)); + + int *counter; + counter = (int*) calloc(1,2 * sizeof(int)); + + cudaMemset(d_err, 0, NT * sizeof(float)); + cudaMemset(d_ftype, 0, NT * sizeof(int)); + cudaMemset(d_eloss, 0, NT * sizeof(float)); + + //allocate memory for index array, to be filled with 0->N items if sorting + //is not selected, fill with time sorted spike indicies if selected + unsigned int *d_idx; + cudaMalloc(&d_idx, maxFR * sizeof(int)); + cudaMemset(d_idx, 0, maxFR * sizeof(int)); + + //allocate arrays for sorting timestamps prior to spike subtraction from + //the data and averaging. Set to Params[17] to 1 in matlab caller + unsigned int *d_stSort; + cudaMalloc(&d_stSort, maxFR * sizeof(int)); + cudaMemset(d_stSort, 0, maxFR * sizeof(int)); + + + dim3 tpB(8, 2*nt0-1), tpF(16, Nnearest), tpS(nt0, 16), tpW(Nnearest, Nrank), tpPC(NchanU, Nrank); + + // filter the data with the spatial templates + spaceFilter<<>>(d_Params, d_draw, d_U, d_iC, d_iW, d_data); + + // filter the data with the temporal templates + timeFilter<<>>(d_Params, d_data, d_W, d_dout); + // timeFilter<<>>(d_Params, d_data, d_W, d_dout); + + // compute the best filter + bestFilter<<>>(d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype); + + // loop to find and subtract spikes + + double *d_draw64; +#ifndef ENSURE_DETERM + if (useStableMode) { + // create copy of the dataraw, d_dout, d_data as doubles for arithmetic + // number of consecutive points to convert = Params(17) (Params(18) in matlab) + cudaMalloc(&d_draw64, NT*Nchan * sizeof(double)); + convToDouble<<<100,Nthreads>>>(d_Params, d_draw, d_draw64); + } +#endif + + for(int k=0;k<(int) Params[3];k++){ //Parms[3] = nInnerIter, set to 60 final pass + // ignore peaks that are smaller than another nearby peak + cleanup_spikes<<>>(d_Params, d_dout, d_mu, d_err, d_eloss, + d_ftype, d_st, d_id, d_x, d_y, d_z, d_counter); + + // add new spikes to 2nd counter + cudaMemcpy(counter, d_counter, 2*sizeof(int), cudaMemcpyDeviceToHost); + // limit number of spike to add to feature arrays AND subtract from drez + // to maxFR. maxFR = 100000, so this limit is likely not hit for "standard" + // batch size of 65000. However, could lead to duplicate template formation + // if the limit were hit in learning templates. Should we add a warning flag? + if (counter[0]>maxFR){ + counter[0] = maxFR; + cudaMemcpy(d_counter, counter, sizeof(int), cudaMemcpyHostToDevice); + } + + // extract template features before subtraction, for counter[1] to counter[0] + // tpF(16, Nnearest), blocks are over spikes + if (Params[12]>1) + extractFEAT<<<64, tpF>>>(d_Params, d_st, d_id, d_counter, d_dout, d_iList, d_mu, d_feat); + // subtract spikes from raw data. If compile switch "ENSURE_DETERM" is on, + // use subtract_spikes_v2, which threads only over + // spikes subratcted = counter[1] up to counter[0]. + // for this calculation to be reproducible, need to sort the spikes first + + +#ifdef ENSURE_DETERM + // create set of indicies from 0 to counter[0] - counter[1] - 1 + // if useStableMode = 0, this will be passed to subtract_spikes_v2 unaltered + // and spikes will be subtracted off in the order found + // NOTE: deterministic calculations are dependent on ENABLE_STABLEMODE! + set_idx<<< 1, 1 >>>(d_idx, counter[0] - counter[1]); + #ifdef ENABLE_STABLEMODE + if (useStableMode) { + //make a copy of the timestamp array to sort + cudaMemcpy( d_stSort, d_st+counter[1], (counter[0] - counter[1])*sizeof(int), cudaMemcpyDeviceToDevice ); + int left = 0; + int right = counter[0] - counter[1] - 1; + cdp_simple_quicksort<<< 1, 1 >>>(d_stSort, d_idx, left, right, 0); + } + #endif + if (Nchan < Nthreads) { + subtract_spikes_v2<<<1, Nchan>>>(d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U); + } + else { + subtract_spikes_v2<<>>(d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U); + } + // filter the data with the spatial templates, checking only times where + // identified spikes were subtracted. Need version using a single precision copy of draw + spaceFilterUpdate<<>>(d_Params, d_draw, d_U, d_UtU, d_iC, d_iW, d_data, + d_st, d_id, d_counter); + +#else + //"Normal" mode -- recommend useStableMode, which will give mostly deterministic calculations + //useStableMode = 0 will have significant differences from run to run, but is 15-20% faster + if (useStableMode) { + subtract_spikes_v4<<>>(d_Params, d_st, d_id, d_y, d_counter, d_draw64, d_W, d_U); + // filter the data with the spatial templates, checking only times where + // identified spikes were subtracted. Need version using a double precision copy of draw + spaceFilterUpdate_v2<<>>(d_Params, d_draw64, d_U, d_UtU, d_iC, d_iW, d_data, + d_st, d_id, d_counter); + } + else { + subtract_spikes<<>>(d_Params, d_st, d_id, d_y, d_counter, d_draw, d_W, d_U); + // filter the data with the spatial templates, checking only times where + // identified spikes were subtracted. Need version using a single precision copy of draw + spaceFilterUpdate<<>>(d_Params, d_draw, d_U, d_UtU, d_iC, d_iW, d_data, + d_st, d_id, d_counter); + } +#endif + + + // filter the data with the temporal templates, checking only times where + // identified spikes were subtracted + timeFilterUpdate<<>>(d_Params, d_data, d_W, d_UtU, d_dout, + d_st, d_id, d_counter); + + // shouldn't the space filter update and time filter update also only + // be done if counter[0] - counter[1] > 0? + if (counter[0]-counter[1]>0) { + bestFilterUpdate<<>>(d_Params, d_dout, d_mu, + d_err, d_eloss, d_ftype, d_st, d_id, d_counter); + + } + // d_count records the number of spikes (tracked in d_counter[0] in each + // iteration, but is currently unused. + cudaMemcpy(d_count+k+1, d_counter, sizeof(int), cudaMemcpyDeviceToDevice); + + // copy d_counter[0] to d_counter[1]. cleanup_spikes will look for new + // spikes in the data and increment d_counter[0]; features of these new + // spikes will be added to d_featPC and then subracted out of d_out. + cudaMemcpy(d_counter+1, d_counter, sizeof(int), cudaMemcpyDeviceToDevice); + } + +#ifndef ENSURE_DETERM + if (useStableMode) { + //convert arrays back to singles for the rest of the process + convToSingle<<<100,Nthreads>>>(d_Params, d_draw64, d_draw); + } +#endif + + // compute PC features from reziduals + subtractions + if (Params[12]>0) + computePCfeatures<<>>(d_Params, d_counter, d_draw, d_st, + d_id, d_y, d_W, d_U, d_mu, d_iW, d_iC, d_wPCA, d_featPC); + + //jic addition of time sorting prior to average_snips + //get a set of indices for the sorted timestamp array + //make an array of indicies; if useStableMode = 0, this will be passed + //to average_snips unaltered + set_idx<<< 1, 1 >>>(d_idx, counter[0]); + +#ifdef ENABLE_STABLEMODE + if (useStableMode) { + //make a copy of the timestamp array to sort + cudaMemcpy( d_stSort, d_st, counter[0]*sizeof(int), cudaMemcpyDeviceToDevice ); + int left = 0; + int right = counter[0]-1; + cdp_simple_quicksort<<< 1, 1 >>>(d_stSort, d_idx, left, right, 0); + } +#endif + + + + // update dWU here by adding back to subbed spikes. + // additional parameter d_idx = array of time sorted indicies + average_snips<<>>(d_Params, d_st, d_idx, d_id, d_x, d_y, d_counter, + d_draw, d_W, d_U, d_dWU, d_nsp, d_mu, d_z); + + float *x, *feat, *featPC, *vexp; + int *st, *id; + unsigned int minSize; + if (counter[0] using namespace std; -const int Nthreads = 1024, NrankMax = 3, nt0max = 201, NchanMax = 1024; +const int Nthreads = 1024, NrankMax = 12, nt0max = 61, NchanMax = 256; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void blankdWU(const double *Params, const double *dWU, @@ -284,7 +284,8 @@ __global__ void reNormalize(const double *Params, const double *A, const double void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray const *prhs[]) { - int maxbytes = 166912; // 163 KiB +// int maxbytes = 166912; // 163 KiB + int maxbytes = 101376; // 99 KiB cudaFuncSetAttribute(getW, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes); /* Initialize the MathWorks GPU API. */ diff --git a/sorting/Kilosort-3.0/CUDA/mexSVDsmall2_czuba.cu b/sorting/Kilosort-3.0/CUDA/mexSVDsmall2_czuba.cu new file mode 100644 index 00000000..cdd8b5f7 --- /dev/null +++ b/sorting/Kilosort-3.0/CUDA/mexSVDsmall2_czuba.cu @@ -0,0 +1,386 @@ +/* + * Example of how to use the mxGPUArray API in a MEX file. This example shows + * how to write a MEX function that takes a gpuArray input and returns a + * gpuArray output, e.g. B=mexFunction(A). + * + * [ks25] updates: + * - align min, instead of max(abs) to prevent arb inversion of balanced waveforms + * Copyright 2012 The MathWorks, Inc. + */ +#include +#include +#include +#include +#include +#include "mex.h" +#include "gpu/mxGPUArray.h" +#include +#include +#include +using namespace std; + +const int Nthreads = 1024, NrankMax = 12, nt0max = 61, NchanMax = 256; + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void blankdWU(const double *Params, const double *dWU, + const int *iC, const int *iW, double *dWUblank){ + + int nt0, tidx, tidy, bid, Nchan, NchanNear, iChan; + + nt0 = (int) Params[4]; + Nchan = (int) Params[9]; + NchanNear = (int) Params[10]; + + tidx = threadIdx.x; + tidy = threadIdx.y; + + bid = blockIdx.x; + + while (tidy xmax){ + // xmax = abs(sW[t]); + // imax = t; + // } + + tid = threadIdx.x; + // shift by imax - tmax + for (k=0;k xmax){ + // xmax = abs(sWup[t]); + // imax = t; + // sgnmax = copysign(1.0f, sWup[t]); + // } + + // interpolate by imax + for (k=0;k>>(d_Params, d_dWU, d_iC, d_iW, d_dWUb); + + // compute dWU * dWU' + getwtw<<>>(d_Params, d_dWUb, d_wtw); + + // get W by power svd iterations + getW<<>>(d_Params, d_wtw, d_W); + + // compute U by W' * dWU + getU<<>>(d_Params, d_dWUb, d_W, d_U); + + // normalize U, get S, get mu, renormalize W + reNormalize<<>>(d_Params, d_A, d_B, d_W, d_U, d_mu); + + plhs[0] = mxGPUCreateMxArrayOnGPU(W); + plhs[1] = mxGPUCreateMxArrayOnGPU(U); + plhs[2] = mxGPUCreateMxArrayOnGPU(mu); + + cudaFree(d_wtw); + cudaFree(d_Params); + cudaFree(d_dWUb); + + mxGPUDestroyGPUArray(dWU); + mxGPUDestroyGPUArray(B); + mxGPUDestroyGPUArray(A); + mxGPUDestroyGPUArray(W); + mxGPUDestroyGPUArray(U); + mxGPUDestroyGPUArray(mu); + mxGPUDestroyGPUArray(iC); + mxGPUDestroyGPUArray(iW); + +} diff --git a/sorting/Kilosort-3.0/CUDA/mexThSpkPC.cu b/sorting/Kilosort-3.0/CUDA/mexThSpkPC.cu index cd5bacab..b287d5ee 100644 --- a/sorting/Kilosort-3.0/CUDA/mexThSpkPC.cu +++ b/sorting/Kilosort-3.0/CUDA/mexThSpkPC.cu @@ -17,11 +17,11 @@ #include using namespace std; -const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nt0max=201, NchanMax = 17; +const int Nthreads = 1024, maxFR = 100000, NrankMax = 12, nt0max=61, NchanMax = 17; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ - volatile __shared__ float sW[201*NrankMax], sdata[Nthreads+201]; + volatile __shared__ float sW[61*NrankMax], sdata[Nthreads+61]; float x, y; int tid, tid0, bid, i, nid, Nrank, NT, nt0; @@ -167,7 +167,7 @@ __global__ void maxChannels(const double *Params, const float *dataraw, const f ////////////////////////////////////////////////////////////////////////////////////////// __global__ void max1D(const double *Params, const float *data, float *conv_sig){ - volatile __shared__ float sdata[Nthreads+201]; + volatile __shared__ float sdata[Nthreads+61]; float y, spkTh; int tid, tid0, bid, i, NT, nt0; diff --git a/sorting/Kilosort-3.0/CUDA/mexWtW2.cu b/sorting/Kilosort-3.0/CUDA/mexWtW2.cu index bb23a393..e8db86d4 100644 --- a/sorting/Kilosort-3.0/CUDA/mexWtW2.cu +++ b/sorting/Kilosort-3.0/CUDA/mexWtW2.cu @@ -22,10 +22,10 @@ const int nblock = 32; __global__ void crossFilter(const double *Params, const float *W1, const float *W2, const float *UtU, float *WtW){ - //__shared__ float shW1[nblock*201], shW2[nblock*201]; + //__shared__ float shW1[nblock*61], shW2[nblock*61]; extern __shared__ float array[]; float* shW1 = (float*)array; - float* shW2 = (float*)&shW1[nblock*201]; + float* shW2 = (float*)&shW1[nblock*61]; float x; int nt0, tidx, tidy , bidx, bidy, i, Nfilt, t, tid1, tid2; @@ -84,7 +84,14 @@ __global__ void crossFilter(const double *Params, const float *W1, const float * void mexFunction(int nlhs, mxArray *plhs[], int nrhs, mxArray const *prhs[]) { - int maxbytes = 166912; // 163 KiB +// only increase Shared Memory Size if needed, and the highest choice +// should be according to the Compute Capability of your GPU +// int maxbytes = 232448; // 227 KiB, Compute Capability 9.0 +// int maxbytes = 166912; // 163 KiB, Compute Capability 8.0, 8.7 +// int maxbytes = 101376; // 99 KiB, Compute Capability 8.6, 8.9 +// int maxbytes = 65536; // 64 KiB, Compute Capability 7.5 +// int maxbytes = 98304; // 96 KiB, Compute Capability 7.0, 7.2 +int maxbytes = 49152; // 48 KiB, Compute Capability 5.0-6.2 cudaFuncSetAttribute(crossFilter, cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes); /* Declare input variables*/ @@ -122,7 +129,7 @@ void mexFunction(int nlhs, mxArray *plhs[], dim3 grid(1 + (Nfilt/nblock), 1 + (Nfilt/nblock)); dim3 block(nblock, nblock); - crossFilter<<>>(d_Params, d_W1, d_W2, d_UtU, d_WtW); + crossFilter<<>>(d_Params, d_W1, d_W2, d_UtU, d_WtW); plhs[0] = mxGPUCreateMxArrayOnGPU(WtW); diff --git a/sorting/Kilosort-3.0/CUDA/spikedetector3.cu b/sorting/Kilosort-3.0/CUDA/spikedetector3.cu index 2bf33430..c4de9f63 100644 --- a/sorting/Kilosort-3.0/CUDA/spikedetector3.cu +++ b/sorting/Kilosort-3.0/CUDA/spikedetector3.cu @@ -17,12 +17,12 @@ #include using namespace std; -const int Nthreads = 1024, NrankMax = 6, maxFR = 10000, nt0max=201, NchanMax = 17, nsizes = 5; +const int Nthreads = 1024, NrankMax = 12, maxFR = 10000, nt0max=61, NchanMax = 17, nsizes = 5; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ - volatile __shared__ float sW[201*NrankMax], sdata[(Nthreads+201)]; + volatile __shared__ float sW[61*NrankMax], sdata[(Nthreads+61)]; float y; int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan; @@ -118,7 +118,7 @@ __global__ void sumChannels(const double *Params, const float *data, ////////////////////////////////////////////////////////////////////////////////////////// __global__ void max1D(const double *Params, const float *data, float *conv_sig){ - volatile __shared__ float sdata[Nthreads+201]; + volatile __shared__ float sdata[Nthreads+61]; float y, spkTh; int tid, tid0, bid, i, NT, nt0, nt0min; diff --git a/sorting/Kilosort-3.0/CUDA/spikedetector3PC.cu b/sorting/Kilosort-3.0/CUDA/spikedetector3PC.cu index 8a5cf2d0..b4a95624 100644 --- a/sorting/Kilosort-3.0/CUDA/spikedetector3PC.cu +++ b/sorting/Kilosort-3.0/CUDA/spikedetector3PC.cu @@ -17,12 +17,12 @@ #include using namespace std; -const int Nthreads = 1024, NrankMax = 6, maxFR = 10000, nt0max=201, NchanMax = 17, nsizes = 5; +const int Nthreads = 1024, NrankMax = 9, maxFR = 10000, nt0max=61, NchanMax = 17, nsizes = 5; ////////////////////////////////////////////////////////////////////////////////////////// __global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ - volatile __shared__ float sW[201*NrankMax], sdata[(Nthreads+201)]; + volatile __shared__ float sW[61*NrankMax], sdata[(Nthreads+61)]; float y; int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan; @@ -98,6 +98,10 @@ __global__ void sumChannels(const double *Params, const float *data, for(j=0; j m1 - rs = rs(:, [2,1]); - rr = rr([2,1]); - mux = mu2; - mu2 = mu1; - mu1 = mux; -end + m1 = sum(wav1 .^ 2) ^ .5; + m2 = sum(wav2 .^ 2) ^ .5; -n1 = sum(rs(:,1)>.5); -n2 = sum(rs(:,2)>.5); -nmin = min(n1, n2); -% fprintf('%6.0d, %6.0d, %2.2f, %2.2f, %2.4f \n', n1, n2, rr(1), rr(2), abs(mu1-mu2)); + if m2 > m1 + rs = rs(:, [2, 1]); + rr = rr([2, 1]); + mux = mu2; + mu2 = mu1; + mu1 = mux; + end -flag = 1; -iclust = rs(:,1)>.5; -if (min(rr) .5); + n2 = sum(rs(:, 2) > .5); + nmin = min(n1, n2); + % fprintf('%6.0d, %6.0d, %2.2f, %2.2f, %2.4f \n', n1, n2, rr(1), rr(2), abs(mu1-mu2)); -if flag==1 - do_roll = 0; - r0 = mean((wav1-wav2).^2); - for j = 1:size(wroll,3) - - wav = wroll(:,:,j) * reshape(wav2, [size(wroll,2), numel(wav2)/size(wroll,2)]); - wav = wav(:)'; - - if j==1 || r0 > mean((wav1-wav).^2) - wav2_best = wav; - r0 = mean((wav1-wav2).^2); - do_roll = 1; - end - end - wav2 = wav2_best; - - rc = sum(wav1 .* wav2)/(m1 * m2); - dmu = 2 * abs(m1-m2)/(m1+m2); - if rc>.9 && dmu<.2 + flag = 1; % this flag is set to 0 if the split is vetoed + iclust = rs(:, 1) > .5; + if (min(rr) < rmin || nmin < nlow) flag = 0; -% fprintf('veto from similarity r = %2.2f, dmu = %2.2f, roll = %d \n', rc, dmu, do_roll) end -end -if use_CCG && (flag==1) - ss1 = ss(iclust); - ss2 = ss(~iclust); - [K, Qi, Q00, Q01, rir] = ccg(ss1, ss2, 500, dt); % compute the cross-correlogram between spikes in the putative new clusters - Q12 = min(Qi/max(Q00, Q01)); % refractoriness metric 1 - R = min(rir); % refractoriness metric 2 - if Q12<.25 && R<.05 % if both metrics are below threshold. -% disp('veto from CCG') - flag = 0; + % this section uses products of shifted PCA snippet projections in wroll + if flag == 1 + do_roll = 0; + r0 = mean((wav1 - wav2) .^ 2); + for j = 1:size(wroll, 3) + + wav = wroll(:, :, j) * reshape(wav2, [size(wroll, 2), numel(wav2) / size(wroll, 2)]); + wav = wav(:)'; + + if j == 1 || r0 > mean((wav1 - wav) .^ 2) + wav2_best = wav; + r0 = mean((wav1 - wav2) .^ 2); + do_roll = 1; + end + end + wav2 = wav2_best; + + rc = sum(wav1 .* wav2) / (m1 * m2); + dmu = 2 * abs(m1 - m2) / (m1 + m2); + if rc > .9 && dmu < .2 + flag = 0; + % fprintf('veto from similarity r = %2.2f, dmu = %2.2f, roll = %d \n', rc, dmu, do_roll) + end end -end -% veto from alignment here -% compute correlation of centroids + if use_CCG && (flag == 1) + ss1 = ss(iclust); + ss2 = ss(~iclust); + [K, Qi, Q00, Q01, rir] = ccg(ss1, ss2, 500, dt); % compute the cross-correlogram between spikes in the putative new clusters + Q12 = min(Qi / max(Q00, Q01)); % refractoriness metric 1 + R = min(rir); % refractoriness metric 2 + if Q12 < .25 && R < .05 % if both metrics are below threshold. + % disp('veto from CCG') + flag = 0; + end + end + % veto from alignment here + % compute correlation of centroids -if (flag==0) && (retry>0) - w = w / sum(w.^2).^.5; - clp = Xd - (Xd * w') * w; -% disp('one more try') - [x, iclust, flag] = bimodal_pursuit(clp, wroll, ss, rmin, nlow, retry-1, use_CCG); -end + if (flag == 0) && (retry > 0) + w = w / sum(w .^ 2) .^ .5; + clp = Xd - (Xd * w') * w; + % disp('one more try') + [x, iclust, flag] = bimodal_pursuit(clp, wroll, ss, rmin, nlow, retry - 1, use_CCG, nPCs); + end -% sd = sum(mu_clp.^2)^.5; + % sd = sum(mu_clp.^2)^.5; end -function u = nonrandom_projection(clp) -npow = 6; -nbase = 2; -u = make_rproj(npow, nbase); -u = u - .5; %mean(u(:)); - -Xd = clp(:, 1:npow); -ntry = size(u,2); -scmax = gpuArray.zeros(ntry, 1, 'single'); -w = gpuArray(single(u)); -w = w ./ [1:npow]'; -w = w ./ sum(w.^2, 1).^.5; - -for j = 1:ntry - x = Xd * w(:,j) ; - [r, scmax(j)] = find_split(x); -end +% this function checks all combinations of projection vectors for optimal bimodality of split +function u = nonrandom_projection(clp, nPCs) + npow = nPCs; % this is the combinatory power of the projection + nbase = 2; % this is the base for the combinatoric problem + u = make_rproj(npow, nbase); + u = u - .5; %mean(u(:)); % center the projection vectors + + Xd = clp(:, 1:npow); + ntry = size(u, 2); + scmax = gpuArray.zeros(ntry, 1, 'single'); + w = gpuArray(single(u)); + w = w ./ [1:npow]'; % is this to make the projection vectors orthogonal (maybe?) + w = w ./ sum(w .^ 2, 1) .^ .5; % this is to normalize the projection vectors + + for j = 1:ntry + x = Xd * w(:, j); % project the data + [~, scmax(j), ~, ~, ~, ~, ~] = find_split(x); % find the best split + end -[~, imax] = max(scmax); -u = gpuArray.zeros(size(clp,2), 1, 'single'); -u(1:npow) = w(:,imax); + [~, imax] = max(scmax); + u = gpuArray.zeros(size(clp, 2), 1, 'single'); + u(1:npow) = w(:, imax); end - +% this function finds the best projections to split along function u = make_rproj(npow, nbase) -u = zeros(npow, nbase^(npow-1)); -for j = 1:nbase^(npow-1) - u(:, j) = proj_comb(j-1, npow, nbase); -end + u = zeros(npow, nbase ^ (npow - 1)); % all possible combinations of projection vectors + for j = 1:nbase ^ (npow - 1) + u(:, j) = proj_comb(j - 1, npow, nbase); + end end - +% this function projects a number in base nbase to a vector of length npow function u = proj_comb(k, npow, nbase) -u = zeros(npow, 1); -u(1) = 1; -for j = 2:npow - u(j) = rem(k, nbase); - k = floor(k/nbase); -end + u = zeros(npow, 1); + u(1) = 1; + for j = 2:npow + u(j) = rem(k, nbase); + k = floor(k / nbase); + end end - diff --git a/sorting/Kilosort-3.0/clustering/extract_spikes.m b/sorting/Kilosort-3.0/clustering/extract_spikes.m index 5e4bfe1d..bb699f9a 100644 --- a/sorting/Kilosort-3.0/clustering/extract_spikes.m +++ b/sorting/Kilosort-3.0/clustering/extract_spikes.m @@ -1,121 +1,125 @@ function [rez, st3, tF] = extract_spikes(rez) -ymin = min(rez.yc); -ymax = max(rez.yc); -xmin = min(rez.xc); -xmax = max(rez.xc); - -% dmin = median(diff(unique(rez.yc))); -% fprintf('vertical pitch size is %d \n', dmin) -% rez.ops.dmin = dmin; -% rez.ops.yup = ymin:dmin/2:ymax; % centers of the upsampled y positions -% -% % dminx = median(diff(unique(rez.xc))); -% yunq = unique(rez.yc); -% mxc = zeros(numel(yunq), 1); -% for j = 1:numel(yunq) -% xc = rez.xc(rez.yc==yunq(j)); -% if numel(xc)>1 -% mxc(j) = median(diff(sort(xc))); -% end -% end -% dminx = median(mxc); -% fprintf('horizontal pitch size is %d \n', dminx) -% -% rez.ops.dminx = dminx; -% nx = round((xmax-xmin) / (dminx/2)) + 1; -% rez.ops.xup = linspace(xmin, xmax, nx); % centers of the upsampled x positions -% disp(rez.ops.xup) - -% rez.ops.yup = ymin:10:ymax; % centers of the upsampled y positions -% rez.ops.xup = xmin + [-16 0 16 32 48 64]; % centers of the upsampled x positions - -ops = rez.ops; - -spkTh = ops.Th(1); -sig = 10; -dNearActiveSite = median(diff(unique(rez.yc))); - -[ycup, xcup] = meshgrid(ops.yup, ops.xup); - -NrankPC = ops.nPCs; %!! -[wTEMP, wPCA] = extractTemplatesfromSnippets(rez, NrankPC); - -NchanNear = min(ops.Nchan, 16); %% CHANGED: was 8 -[iC, dist] = getClosestChannels2(ycup, xcup, rez.yc, rez.xc, NchanNear); - -igood = dist(1,:)1 + % mxc(j) = median(diff(sort(xc))); + % end + % end + % dminx = median(mxc); + % fprintf('horizontal pitch size is %d \n', dminx) + % + % rez.ops.dminx = dminx; + % nx = round((xmax-xmin) / (dminx/2)) + 1; + % rez.ops.xup = linspace(xmin, xmax, nx); % centers of the upsampled x positions + % disp(rez.ops.xup) + + % rez.ops.yup = ymin:10:ymax; % centers of the upsampled y positions + % rez.ops.xup = xmin + [-16 0 16 32 48 64]; % centers of the upsampled x positions + + ops = rez.ops; + + spkTh = ops.Th(1); + sig = 1000; % set microns as hardcoded BASE gaussian kernel radius for the nearest channels + dNearActiveSite = median(diff(unique(rez.yc))); + + [ycup, xcup] = meshgrid(ops.yup, ops.xup); % define the 2x upsampled grid + + NrankPC = ops.nPCs; %!! + + % [wTEMP, wPCA] = extractTemplatesfromSnippets(rez, NrankPC); + wTEMP = rez.wTEMP; + wPCA = rez.wPCA; + NchanNear = min(ops.Nchan, 16); % % CHANGED: was 8 + [iC, dist] = getClosestChannels2(ycup, xcup, rez.yc, rez.xc, NchanNear); + + igood = dist(1, :) < dNearActiveSite; + iC = iC(:, igood); + dist = dist(:, igood); + + ycup = ycup(igood); + xcup = xcup(igood); + + NchanNearUp = min(numel(ycup), 10 * NchanNear); + [iC2, dist2] = getClosestChannels2(ycup, xcup, ycup, xcup, NchanNearUp); + + nsizes = 5; + v2 = gpuArray.zeros(nsizes, size(dist, 2), 'single'); + for k = 1:nsizes + v2(k, :) = sum(exp(- 2 * dist .^ 2 / (sig * k) ^ 2), 1); + end -NchanUp = size(iC,2); + NchanUp = size(iC, 2); -% wTEMP = wPCA * (wPCA' * wTEMP); + % wTEMP = wPCA * (wPCA' * wTEMP); + t0 = 0; + id = []; + mu = []; + nsp = 0; -t0 = 0; -id = []; -mu = []; -nsp = 0; + tF = zeros(NrankPC, NchanNear, 1e6, 'single'); -tF = zeros(NrankPC, NchanNear, 1e6, 'single'); + tic + st3 = zeros(1000000, 6); -tic -st3 = zeros(1000000, 6); + for k = 1:ops.Nbatch + dataRAW = get_batch(rez.ops, k); -for k = 1:ops.Nbatch - dataRAW = get_batch(rez.ops, k); + Params = [size(dataRAW, 1) ops.Nchan ops.nt0 NchanNear NrankPC ops.nt0min spkTh NchanUp NchanNearUp sig]; - Params = [size(dataRAW,1) ops.Nchan ops.nt0 NchanNear NrankPC ops.nt0min spkTh NchanUp NchanNearUp sig]; + [dat, kkmax, st, cF, feat] = ... + spikedetector3PC(Params, dataRAW, wTEMP, iC - 1, dist, v2, iC2 - 1, dist2, wPCA); + % [dat, kkmax, st, cF] = ... + % spikedetector3(Params, dataRAW, wTEMP, iC-1, dist, v2, iC2-1, dist2); - [dat, kkmax, st, cF, feat] = ... - spikedetector3PC(Params, dataRAW, wTEMP, iC-1, dist, v2, iC2-1, dist2, wPCA); -% [dat, kkmax, st, cF] = ... -% spikedetector3(Params, dataRAW, wTEMP, iC-1, dist, v2, iC2-1, dist2); - - ns = size(st,2); - if nsp + ns>size(tF,3) - tF(:,:,end + 1e6) = 0; - st3(end + 1e6, 1) = 0; - end - - toff = ops.nt0min + t0 + ops.NT *(k-1); - st(1,:) = st(1,:) + toff; - st = double(st); - % https://github.com/MouseLand/Kilosort/issues/427 - try - st(5,:) = cF; - catch - st = [st; cF]; - end - st(6,:) = k-1; - - st3(nsp + [1:ns], :) = gather(st)'; - - tF(:, :, nsp + [1:ns]) = gather(feat); - nsp = nsp + ns; - - if rem(k,100)==1 || k==ops.Nbatch - fprintf('%2.2f sec, %d batches, %d spikes \n', toc, k, nsp) + ns = size(st, 2); + if nsp + ns > size(tF, 3) + tF(:, :, end +1e6) = 0; + st3(end +1e6, 1) = 0; + end + + toff = ops.nt0min + t0 + ops.NT * (k - 1); + st(1, :) = st(1, :) + toff; + st = double(st); + % https://github.com/MouseLand/Kilosort/issues/427 + try + st(5, :) = cF; + catch + st = [st; cF]; + end + st(6, :) = k - 1; + + st3(nsp + [1:ns], :) = gather(st)'; + + tF(:, :, nsp + [1:ns]) = gather(feat); + nsp = nsp + ns; + + if rem(k, round(ops.Nbatch / 10)) == 0 % print progress 10 times + fprintf('%2.2f sec, %d batches, %d spikes \n', toc, k, nsp) + end end -end -tF = tF(:, :, 1:nsp); -st3 = st3(1:nsp, :); + tF = tF(:, :, 1:nsp); % remove excess preallocated space + st3 = st3(1:nsp, :); -rez.iC = iC; -tF = permute(tF, [3, 1, 2]); + rez.iC = iC; + tF = permute(tF, [3, 1, 2]); -rez.ycup = ycup; -rez.xcup = xcup; + rez.ycup = ycup; + rez.xcup = xcup; + +end diff --git a/sorting/Kilosort-3.0/clustering/final_clustering.m b/sorting/Kilosort-3.0/clustering/final_clustering.m index 72a39c43..f438d270 100644 --- a/sorting/Kilosort-3.0/clustering/final_clustering.m +++ b/sorting/Kilosort-3.0/clustering/final_clustering.m @@ -1,160 +1,162 @@ function rez1 = final_clustering(rez, tF, st3) -wPCA = rez.wPCA; -wTEMP = rez.wTEMP; + wPCA = rez.wPCA; + wTEMP = rez.wTEMP; -iC = rez.iC; -ops = rez.ops; + iC = rez.iC; + ops = rez.ops; + wroll = []; + tlag = [-2, -1, 1, 2]; + for j = 1:length(tlag) + wroll(:, :, j) = circshift(wPCA, tlag(j), 1)' * wPCA; + end -wroll = []; -tlag = [-2, -1, 1, 2]; -for j = 1:length(tlag) - wroll(:,:,j) = circshift(wPCA, tlag(j), 1)' * wPCA; -end - -%% split templates into batches -rmin = 0.6; -nlow = 100; -n0 = 0; -use_CCG = 1; - -Nchan = rez.ops.Nchan; -Nk = size(iC,2); -yunq = unique(rez.yc); - -ktid = int32(st3(:,2)); - -uweigh = abs(rez.U(:,:,1)); -uweigh = uweigh ./ sum(uweigh,1); -ycup = sum(uweigh .* rez.yc, 1); -xcup = sum(uweigh .* rez.xc, 1); - -Nfilt = size(rez.W,2); -dWU = gpuArray.zeros(ops.nt0, ops.Nchan, Nfilt, 'double'); -for j = 1:Nfilt - dWU(:,:,j) = rez.mu(j) * squeeze(rez.W(:, j, :)) * squeeze(rez.U(:, j, :))'; -end - - -ops = rez.ops; -NchanNear = min(ops.Nchan, 16); - -[iC, mask, C2C] = getClosestChannels(rez, ops.sigmaMask, NchanNear); - + %% split templates into batches + rmin = 0.6; + nlow = 100; + n0 = 0; + use_CCG = 1; + + Nchan = rez.ops.Nchan; + Nk = size(iC, 2); + yunq = unique(rez.yc); + + ktid = int32(st3(:, 2)); + + uweigh = abs(rez.U(:, :, 1)); + uweigh = uweigh ./ sum(uweigh, 1); + ycup = sum(uweigh .* rez.yc, 1); + xcup = sum(uweigh .* rez.xc, 1); + + Nfilt = size(rez.W, 2); + dWU = gpuArray.zeros(ops.nt0, ops.Nchan, Nfilt, 'double'); + for j = 1:Nfilt + try + dWU(:, :, j) = rez.mu(j) * squeeze(rez.W(:, j, :)) * squeeze(rez.U(:, j, :))'; + catch + % try again, sometimes fails due to GPU issues + dWU(:, :, j) = rez.mu(j) * squeeze(rez.W(:, j, :)) * squeeze(rez.U(:, j, :))'; + end + end -[~, iW] = max(abs(dWU(ops.nt0min, :, :)), [], 2); -iW = int32(squeeze(iW)); + ops = rez.ops; + NchanNear = min(ops.Nchan, 16); + + [iC, mask, C2C] = getClosestChannels(rez, ops.sigmaMask, NchanNear); + + [~, iW] = max(abs(dWU(ops.nt0min, :, :)), [], 2); + iW = int32(squeeze(iW)); + + iC = gather(iC(:, iW)); + %% + ss = double(st3(:, 1)) / ops.fs; + + dmin = rez.ops.dmin; + ycenter = (min(rez.yc) + dmin - 1):(2 * dmin):(max(rez.yc) + dmin + 1); + dminx = rez.ops.dminx; + xcenter = (min(rez.xc) + dminx - 1):(2 * dminx):(max(rez.xc) + dminx + 1); + [xcenter, ycenter] = meshgrid(xcenter, ycenter); + xcenter = xcenter(:); + ycenter = ycenter(:); + + Wpca = zeros(size(wPCA, 2), Nchan, 1000, 'single'); + nst = numel(ktid); + hid = zeros(nst, 1, 'int32'); + + xy = zeros(nst, 2); + + tic + for j = 1:numel(ycenter) + if rem(j, round(numel(ycenter / 10))) == 0 % print progress 10 times + fprintf('time %2.2f, GROUP %d/%d, units %d \n', toc, j, numel(ycenter), n0) + end + y0 = ycenter(j); + x0 = xcenter(j); + xchan = (abs(ycup - y0) < dmin) & (abs(xcup - x0) < dminx); + + itemp = find(xchan); + + tin = ismember(ktid, itemp); + + if sum(tin) < 1 + continue; + end + + pid = ktid(tin); + data = tF(tin, :, :); + + ich = unique(iC(:, itemp)); + % ch_min = ich(1)-1; + % ch_max = ich(end); + + nsp = size(data, 1); + dd = zeros(nsp, size(wPCA, 2), numel(ich), 'single'); + for k = 1:length(itemp) + ix = pid == itemp(k); % find spikes on this template + [~, ia, ib] = intersect(iC(:, itemp(k)), ich); + dd(ix, :, ib) = data(ix, :, ia); + end + xy(tin, :) = spike_position(dd, wPCA, wTEMP, rez.xc(ich), rez.yc(ich)); + kid = run_pursuit(dd, nlow, rmin, n0, wroll, ss(tin), use_CCG, ops.nPCs); + + [~, ~, kid] = unique(kid); + nmax = max(kid); + for t = 1:nmax + Wpca(:, ich, t + n0) = gather(sq(mean(dd(kid == t, :, :), 1))); + end + + hid(tin) = gather(kid + n0); + n0 = n0 + nmax; + end + Wpca = Wpca(:, :, 1:n0); + toc + %% + + rez.xy = xy; + + clust_good = check_clusters(hid, ss, .2); + sum(clust_good) + + % waveform length was hardcoded at 61. Should be parametric, w/min index for consistent trough polarity + rez.W = zeros(ops.nt0, 0, ops.nEig, 'single'); + rez.U = zeros(ops.Nchan, 0, ops.nEig, 'single'); + rez.mu = zeros(1, 0, 'single'); + for t = 1:n0 % for each cluster + dWU = wPCA * gpuArray(Wpca(:, :, t)); + % align max of abs(dWU) to nt0min + [~, dWU_shift] = max(sum(abs(dWU), 2), [], 1); % shape of dWU_shift is 1 x 1 + dWU_shift = dWU_shift - ops.nt0min; % get shift needed to align max value for each channel to nt0min + dWU = circshift(dWU, -dWU_shift); % shift PC components by that amount + [w, s, u] = svdecon(dWU); + wsign = -sign(w(ops.nt0min + 1, 1)); + rez.W(:, t, :) = wsign * w(:, 1:ops.nEig); + rez.U(:, t, :) = wsign * u(:, 1:ops.nEig) * s(1:ops.nEig, 1:ops.nEig); + rez.mu(t) = sum(sum(rez.U(:, t, :) .^ 2)) ^ .5; + rez.U(:, t, :) = rez.U(:, t, :) / rez.mu(t); + end -iC = gather(iC(:,iW)); -%% -ss = double(st3(:,1)) / ops.fs; + %% + amps = sq(sum(sum(tF .^ 2, 2), 3)) .^ .5; -dmin = rez.ops.dmin; -ycenter = (min(rez.yc) + dmin-1):(2*dmin):(max(rez.yc)+dmin+1); -dminx = rez.ops.dminx; -xcenter = (min(rez.xc) + dminx-1):(2*dminx):(max(rez.xc)+dminx+1); -[xcenter, ycenter] = meshgrid(xcenter, ycenter); -xcenter = xcenter(:); -ycenter = ycenter(:); + rez1 = rez; + rez1.st3 = st3; + rez1.st3(:, 1) = rez1.st3(:, 1); % + ops.ntbuff; + rez1.st3(:, 2) = hid; + rez1.st3(:, 3) = amps; -Wpca = zeros(size(wPCA,2), Nchan, 1000, 'single'); -nst = numel(ktid); -hid = zeros(nst,1 , 'int32'); + rez1.cProj = []; + rez1.iNeigh = []; + rez1.cProjPC = []; + rez1.iNeighPC = []; + % rez1.st3(:,1) = rez1.st3(:,1); % - 30000 * 2000; -xy = zeros(nst, 2); + rez1.U = permute(Wpca, [2, 3, 1]); -tic -for j = 1:numel(ycenter) - if rem(j,5)==1 - fprintf('time %2.2f, GROUP %d/%d, units %d \n', toc, j, numel(ycenter), n0) - end - y0 = ycenter(j); - x0 = xcenter(j); - xchan = (abs(ycup - y0) < dmin) & (abs(xcup - x0) < dminx); - - itemp = find(xchan); - - tin = ismember(ktid, itemp); - - if sum(tin)<1 - continue; - end - - pid = ktid(tin); - data = tF(tin, :, :); - - ich = unique(iC(:, itemp)); -% ch_min = ich(1)-1; -% ch_max = ich(end); - - nsp = size(data,1); - dd = zeros(nsp, size(wPCA,2), numel(ich), 'single'); - for k = 1:length(itemp) - ix = pid==itemp(k); - [~,ia,ib] = intersect(iC(:,itemp(k)), ich); - dd(ix, :, ib) = data(ix,:,ia); - end - xy(tin, :) = spike_position(dd, wPCA, wTEMP, rez.xc(ich), rez.yc(ich)); - - kid = run_pursuit(dd, nlow, rmin, n0, wroll, ss(tin), use_CCG); - - [~, ~, kid] = unique(kid); - nmax = max(kid); - for t = 1:nmax - Wpca(:, ich, t + n0) = gather(sq(mean(dd(kid==t,:,:),1))); - end - - hid(tin) = gather(kid + n0); - n0 = n0 + nmax; -end -Wpca = Wpca(:,:,1:n0); -toc -%% - -rez.xy = xy; - -clust_good = check_clusters(hid, ss, .2); -sum(clust_good) - -% waveform length was hardcoded at 61. Should be parametric, w/min index for consistent trough polarity -rez.W = zeros(ops.nt0, 0, ops.nEig, 'single'); -rez.U = zeros(ops.Nchan, 0, ops.nEig, 'single'); -rez.mu = zeros(1,0, 'single'); -for t = 1:n0 - dWU = wPCA * gpuArray(Wpca(:,:,t)); - [w,s,u] = svdecon(dWU); - wsign = -sign(w(ops.nt0min+1, 1)); - rez.W(:,t,:) = wsign * w(:,1:ops.nEig); - rez.U(:,t,:) = wsign * u(:,1:ops.nEig) * s(1:ops.nEig,1:ops.nEig); - rez.mu(t) = sum(sum(rez.U(:,t,:).^2))^.5; - rez.U(:,t,:) = rez.U(:,t,:) / rez.mu(t); -end - -%% -amps = sq(sum(sum(tF.^2,2),3)).^.5; - - -rez1 = rez; -rez1.st3 = st3; -rez1.st3(:,1) = rez1.st3(:,1); % + ops.ntbuff; -rez1.st3(:,2) = hid; -rez1.st3(:,3) = amps; - - -rez1.cProj = []; -rez1.iNeigh = []; -rez1.cProjPC = []; -rez1.iNeighPC = []; -% rez1.st3(:,1) = rez1.st3(:,1); % - 30000 * 2000; - -rez1.U = permute(Wpca, [2,3,1]); - -rez1.mu = sum(sum(rez1.U.^2, 1),3).^.5; -rez1.U = rez1.U ./ rez1.mu; -rez1.mu = rez1.mu(:); - -rez1.W = reshape(rez.wPCA, [ops.nt0, 1, size(wPCA,2)]); -rez1.W = repmat(rez1.W, [1, n0, 1]); -rez1.est_contam_rate = ones(n0,1); + rez1.mu = sum(sum(rez1.U .^ 2, 1), 3) .^ .5; + rez1.U = rez1.U ./ rez1.mu; + rez1.mu = rez1.mu(:); + rez1.W = reshape(rez.wPCA, [ops.nt0, 1, size(wPCA, 2)]); + rez1.W = repmat(rez1.W, [1, n0, 1]); + rez1.est_contam_rate = ones(n0, 1); diff --git a/sorting/Kilosort-3.0/clustering/find_split.m b/sorting/Kilosort-3.0/clustering/find_split.m index a234e378..d7395199 100644 --- a/sorting/Kilosort-3.0/clustering/find_split.m +++ b/sorting/Kilosort-3.0/clustering/find_split.m @@ -1,5 +1,5 @@ function [r, scmax, p, m0, mu1, mu2, sig] = find_split(x) - +% scmax is the score of the split x = gather(x); qbar = [.001, .999]; nbins = 1001; diff --git a/sorting/Kilosort-3.0/clustering/run_pursuit.m b/sorting/Kilosort-3.0/clustering/run_pursuit.m index 1d368af9..d8ad96a5 100644 --- a/sorting/Kilosort-3.0/clustering/run_pursuit.m +++ b/sorting/Kilosort-3.0/clustering/run_pursuit.m @@ -1,79 +1,76 @@ -function kid = run_pursuit(data, nlow, rmin, n0, wroll, ss, use_CCG) - -Xd = gpuArray(data(:, :)); -amps = sum(Xd.^2, 2).^.5; - -kid = zeros(size(Xd,1), 1); - -aj = zeros(1000,1); -for j = 1:1000 - ind = find(kid==0); -% fprintf('cluster %d\n', n0+j) - [ix, xold, xnew] = break_a_cluster(Xd(ind, :), wroll, ss(ind), nlow, rmin, use_CCG); - - aj(j) = gather(mean(amps(ind(ix)))); -% fprintf('amps = %2.2f \n\n', aj(j)); - kid(ind(ix)) = j; - - if length(ix) == length(ind) - break; - end - -end -aj = aj(1:j); +function kid = run_pursuit(data, nlow, rmin, n0, wroll, ss, use_CCG, nPCs) -end + Xd = gpuArray(data(:, :)); + amps = sum(Xd .^ 2, 2) .^ .5; % get amplitudes by square root of sum of squares + kid = zeros(size(Xd, 1), 1); % these are the cluster ID assignments! -function [ix, xold, x] = break_a_cluster(data,wroll, ss, nlow, rmin, use_CCG) -ix = 1:size(data,1); + aj = zeros(1000, 1); % this is all the amplitudes of the clusters + for j = 1:1000 + ind = find(kid == 0); % these are the spikes inds not assigned to a cluster yet + % fprintf('cluster %d\n', n0+j) + [ix, xold, xnew] = break_a_cluster(Xd(ind, :), wroll, ss(ind), nlow, rmin, use_CCG, nPCs); -xold = []; -dt = 1/1000; -for j = 1:10 - dd = data(ix, :); - if length(ix) < 2 * nlow - x = []; -% disp('done with this cluster (too small)') - break; - end - - [x, iclust, flag] = bimodal_pursuit(dd, wroll, ss(ix), rmin, nlow, 1, use_CCG); - - if flag==0 -% disp('done with this cluster') - break; - end + aj(j) = gather(mean(amps(ind(ix)))); % mean amplitude of this cluster + % fprintf('amps = %2.2f \n\n', aj(j)); + kid(ind(ix)) = j; % assign these spikes to this cluster - ix = ix(iclust); - xold = x; -end + if length(ix) == length(ind) + break; + end -end + end + aj = aj(1:j); +end +% this function breaks a cluster into two clusters if it found to be bimodal +function [ix, xold, x] = break_a_cluster(data, wroll, ss, nlow, rmin, use_CCG, nPCs) + ix = 1:size(data, 1); + + xold = []; + dt = 1/1000; + for j = 1:10 + dd = data(ix, :); + if length(ix) < 2 * nlow + x = []; + % disp('done with this cluster (too small)') + break; + end + + [x, iclust, flag] = bimodal_pursuit(dd, wroll, ss(ix), rmin, nlow, 1, use_CCG, nPCs); + + if flag == 0 + % disp('done with this cluster') + break; + end + + ix = ix(iclust); % these are the spikes that remain in the cluster + xold = x; + end -function dd = grab_data(rez, y0, ktid) -xchan = abs(rez.yc - y0) < 20; -itemp = find(xchan(tmp_chan)); +end +% function dd = grab_data(rez, y0, ktid) +% xchan = abs(rez.yc - y0) < 20; +% itemp = find(xchan(tmp_chan)); -tin = ismember(ktid, itemp); -pid = ktid(tin); -data = rez.cProjPC(tin, :, :); +% tin = ismember(ktid, itemp); +% pid = ktid(tin); +% data = rez.cProjPC(tin, :, :); -iC = rez.iNeighPC(1:16, :); +% iC = rez.iNeighPC(1:16, :); -ich = unique(iC(:, itemp)); -ch_min = ich(1)-1; -ch_max = ich(end); +% ich = unique(iC(:, itemp)); +% ch_min = ich(1) - 1; +% ch_max = ich(end); -nsp = size(data,1); -dd = zeros(nsp, size(data,2), ch_max-ch_min, 'single'); -for j = 1:length(itemp) - ix = pid==itemp(j); - dd(ix, :, iC(:,itemp(j))-ch_min) = data(ix,:,:); -end -dd = dd(:,:); +% nsp = size(data, 1); +% dd = zeros(nsp, size(data, 2), ch_max - ch_min, 'single'); +% for j = 1:length(itemp) +% ix = pid == itemp(j); +% dd(ix, :, iC(:, itemp(j)) - ch_min) = data(ix, :, :); +% end +% dd = dd(:, :); -end +% end diff --git a/sorting/Kilosort-3.0/clustering/template_learning.m b/sorting/Kilosort-3.0/clustering/template_learning.m index 8b606268..601c5ca0 100644 --- a/sorting/Kilosort-3.0/clustering/template_learning.m +++ b/sorting/Kilosort-3.0/clustering/template_learning.m @@ -1,123 +1,171 @@ -function rez = template_learning(rez, tF, st3) +function [rez, spike_times_for_kid] = template_learning(rez, tF, st3) -wPCA = rez.wPCA; -iC = rez.iC; -ops = rez.ops; + wPCA = rez.wPCA; % shape is #PC components, #channels + iC = rez.iC; + ops = rez.ops; + xcup = rez.xcup; + ycup = rez.ycup; -xcup = rez.xcup; -ycup = rez.ycup; - -wroll = []; -tlag = [-2, -1, 1, 2]; -for j = 1:length(tlag) - wroll(:,:,j) = circshift(wPCA, tlag(j), 1)' * wPCA; -end - -%% split templates into batches -rmin = 0.6; -nlow = 100; -n0 = 0; -use_CCG = 0; - -Nchan = rez.ops.Nchan; -Nk = size(iC,2); -yunq = unique(rez.yc); - -ktid = int32(st3(:,2)) + 1; -tmp_chan = iC(1, :); -ss = double(st3(:,1)) / ops.fs; - -dmin = rez.ops.dmin; -ycenter = (min(rez.yc) + dmin-1):(2*dmin):(max(rez.yc)+dmin+1); -dminx = rez.ops.dminx; -xcenter = (min(rez.xc) + dminx-1):(2*dminx):(max(rez.xc)+dminx+1); -[xcenter, ycenter] = meshgrid(xcenter, ycenter); -xcenter = xcenter(:); -ycenter = ycenter(:); - -Wpca = zeros(6, Nchan, 1000, 'single'); -nst = numel(ktid); -hid = zeros(nst,1 , 'int32'); - -% ycup = rez.yc; - - -tic - -for j = 1:numel(ycenter) - if rem(j,5)==1 - fprintf('time %2.2f, GROUP %d/%d, units %d \n', toc, j, numel(ycenter), n0) - end - - y0 = ycenter(j); - x0 = xcenter(j); - xchan = (abs(ycup - y0) < dmin) & (abs(xcup - x0) < dminx); - itemp = find(xchan); - - if isempty(itemp) - continue; + wroll = []; + tlag = [-2, -1, 1, 2]; + % compute a product of each PC component with itself at 4 different time lags + for j = 1:length(tlag) + wroll(:, :, j) = circshift(wPCA, tlag(j), 1)' * wPCA; end - tin = ismember(ktid, itemp); - pid = ktid(tin); - data = tF(tin, :, :); - - if isempty(data) - continue; - end -% size(data) - - %https://github.com/MouseLand/Kilosort/issues/427 - try - ich = unique(iC(:, itemp)); - catch - tmpS = iC(:, itemp); - ich = unique(tmpS); + + %% split templates into batches by grid location + rmin = 0.6; % minimum correlation between templates + nlow = 100; % minimum number of spikes needed to keep a template + n0 = 0; % number of clusters so far + use_CCG = 0; + + Nchan = rez.ops.Nchan; + Nk = size(iC, 2); + yunq = unique(rez.yc); + + ktid = int32(st3(:, 2)) + 1; % get upsampled grid location of each spike + tmp_chan = iC(1, :); + ss = double(st3(:, 1)) / ops.fs; % spike times in seconds + + dmin = rez.ops.dmin; + ycenter = (min(rez.yc) + dmin - 1):(2 * dmin):(max(rez.yc) + dmin + 1); + dminx = rez.ops.dminx; + xcenter = (min(rez.xc) + dminx - 1):(2 * dminx):(max(rez.xc) + dminx + 1); + [xcenter, ycenter] = meshgrid(xcenter, ycenter); % define grid of electrode locations + xcenter = xcenter(:); + ycenter = ycenter(:); + + Wpca = zeros(ops.nPCs, Nchan, 1000, 'single'); + spike_times_for_kid = cell(1000, 1); + nst = numel(ktid); % number of spikes + hid = zeros(nst, 1, 'int32'); + + % flag for plotting + ops.fig = getOr(ops, 'fig', 1); + + tic + for j = 1:numel(ycenter) % process spikes found for each y grid location + if rem(j, round(numel(ycenter) / 10)) == 0 % print progress at most 10 times + fprintf('time %2.2f, grid loc. grp. %d/%d, units %d \n', toc, j, numel(ycenter), n0) + end + + y0 = ycenter(j); % get y electrode locations + x0 = xcenter(j); % get x electrode locations + % exclude grid locations that are not by electrodes + xchan = (abs(ycup - y0) < dmin) & (abs(xcup - x0) < dminx); + itemp = find(xchan); % get nearby electrode location indices for templates + + if isempty(itemp) + continue; + end + tin = ismember(ktid, itemp); % get bitmask for spikes near electrodes + pid = ktid(tin); % get spike ids for spikes near electrodes + data = tF(tin, :, :); % exclude PC convolutions for spikes far from electrodes + + if isempty(data) + continue; + end + % size(data) + + %https://github.com/MouseLand/Kilosort/issues/427 + try + ich = unique(iC(:, itemp)); + catch + tmpS = iC(:, itemp); + ich = unique(tmpS); + end + % ch_min = ich(1)-1; + % ch_max = ich(end); + + if numel(ich) < 1 + continue; + end + + nsp = size(data, 1); + dd = zeros(nsp, ops.nPCs, numel(ich), 'single'); % #spikes, #PC components, #channels + for k = 1:length(itemp) % for each template + ix = pid == itemp(k); % ix is bitmask for spikes near this electrode + % how to go from channels to different order, ib is indeces ordered like ich + [~, ia, ib] = intersect(iC(:, itemp(k)), ich); + dd(ix, :, ib) = data(ix, :, ia); % dd is just the PC convolutions ordered by distance from electrode + end + + kid = run_pursuit(dd, nlow, rmin, n0, wroll, ss(tin), use_CCG, ops.nPCs); + + [~, ~, kid] = unique(kid); % make cluster ids consecutive + nmax = max(kid); % number of clusters found + for t = 1:nmax % for each cluster + % Wpca(:, ch_min+1:ch_max, t + n0) = gather(sq(mean(dd(kid==t,:,:),1))); + % compute mean PC coordinates for each cluster of spikes, there is a separate PC space for each channel + Wpca(:, ich, t + n0) = gather(sq(mean(dd(kid == t, :, :), 1))); + spike_times_for_kid{t + n0} = st3(tin(kid == t), 1); % get spike times for each cluster + end + + hid(tin) = gather(kid + n0); + n0 = n0 + nmax; end -% ch_min = ich(1)-1; -% ch_max = ich(end); - - if numel(ich)<1 - continue; + Wpca = Wpca(:, :, 1:n0); + % Wpca = cat(2, Wpca, zeros(size(Wpca,1), ops.nEig-size(Wpca, 2), size(Wpca, 3), 'single')); + spike_times_for_kid = spike_times_for_kid(1:n0); + % plot mean PC coordinates for each cluster for each channel and cluster (not that useful) + % if ops.fig + % ichc = gather(ich); + % nPCs = size(Wpca, 1); + % figure(12) + % for k=1:n0; for ichcs=1:numel(ichc); for iPC=1:nPCs; scatter(ichcs*ones('like',Wpca(:,ichcs,k)), Wpca(:,ichcs,k)+40*k); end; end; end + % % plot top 3 PC coordinates for each cluster for all channels in each cluster + % figure(13) + % for k=1:n0 + % scatter3(Wpca(1,1,k), Wpca(2,1,k), Wpca(3,1,k), 20, [1-k/n0,k/n0,k/n0]); + % hold on; + % end + % xlabel('PC1'); + % ylabel('PC2'); + % zlabel('PC3'); + % set(gca,'DataAspectRatio',[1 1 1]); + % end + toc + %% + % Ncomps = min(ops.nEig, size(Wpca, 2)); + Ncomps = ops.nEig; + rez.W = zeros(ops.nt0, 0, Ncomps, 'single'); + rez.U = zeros(ops.Nchan, 0, Ncomps, 'single'); + rez.mu = zeros(1, 0, 'single'); + if ops.fig + figure(14); hold on; + RGB_colors = rand(n0, 3); end - - nsp = size(data,1); - dd = zeros(nsp, 6, numel(ich), 'single'); - for k = 1:length(itemp) - ix = pid==itemp(k); - % how to go from channels to different order - [~,ia,ib] = intersect(iC(:,itemp(k)), ich); - dd(ix, :, ib) = data(ix,:,ia); + for t = 1:n0 % for each cluster + dWU = wPCA * gpuArray(Wpca(:, :, t)); % multiply PC components by mean PC coordinates for each cluster + % shape of dWU is nt0 x Nchan + % take absolute value, then sum across channels, then find a shift to align max of abs(dWU) to nt0min + [~, dWU_shift] = max(sum(abs(dWU), 2), [], 1); % shape of dWU_shift is 1 x 1 + dWU_shift = dWU_shift - ops.nt0min; % get shift needed to align max value for each channel to nt0min + dWU = circshift(dWU, -dWU_shift); % shift PC components by that amount + [w, s, u] = svdecon(dWU); % compute SVD of that product to deconstruct it into spatial and temporal components + wsign = -sign(w(ops.nt0min, 1)); % flip sign of waveform if necessary, for consistency + % vvv save first Ncomps components of W, containing final rotation matrix + rez.W(:, t, :) = gather(wsign * w(:, 1:Ncomps)); + % vvv save first Ncomps components of U, containing initial rotation and scaling matrix + rez.U(:, t, :) = gather(wsign * u(:, 1:Ncomps) * s(1:Ncomps, 1:Ncomps)); + rez.mu(t) = gather(sum(sum(rez.U(:, t, :) .^ 2)) ^ .5); % get norm of U + rez.U(:, t, :) = rez.U(:, t, :) / rez.mu(t); % normalize U + if ops.fig + for iloc = 1:numel(ich) % for each channel + % plot PC component reconstructions for each channel, with color based on cluster + plot(dWU(:, iloc) - 10 * iloc, 'color', RGB_colors(t, :)); + end + end end - kid = run_pursuit(dd, nlow, rmin, n0, wroll, ss(tin), use_CCG); - - [~, ~, kid] = unique(kid); - nmax = max(kid); - for t = 1:nmax -% Wpca(:, ch_min+1:ch_max, t + n0) = gather(sq(mean(dd(kid==t,:,:),1))); - Wpca(:, ich, t + n0) = gather(sq(mean(dd(kid==t,:,:),1))); + if ops.fig + title('First Multi-Channel Templates (Color Coded by Cluster)'); + xlabel('Time'); + ylabel('Channel'); end - - hid(tin) = gather(kid + n0); - n0 = n0 + nmax; -end -Wpca = Wpca(:,:,1:n0); -toc -%% -rez.W = zeros(ops.nt0, 0, 3, 'single'); -rez.U = zeros(ops.Nchan, 0, 3, 'single'); -rez.mu = zeros(1,0, 'single'); -for t = 1:n0 - dWU = wPCA * gpuArray(Wpca(:,:,t)); - [w,s,u] = svdecon(dWU); - wsign = -sign(w(ops.nt0min,1)); - rez.W(:,t,:) = gather(wsign * w(:,1:3)); - rez.U(:,t,:) = gather(wsign * u(:,1:3) * s(1:3,1:3)); - rez.mu(t) = gather(sum(sum(rez.U(:,t,:).^2))^.5); - rez.U(:,t,:) = rez.U(:,t,:) / rez.mu(t); + %% + rez.ops.wPCA = wPCA; + % remove any NaNs from rez.W + rez.W(isnan(rez.W)) = 0; end - -%% -rez.ops.wPCA = wPCA; - diff --git a/sorting/Kilosort-3.0/mainLoop/extractTemplatesfromSnippets.m b/sorting/Kilosort-3.0/mainLoop/extractTemplatesfromSnippets.m index 5ec5771e..d185c6f5 100644 --- a/sorting/Kilosort-3.0/mainLoop/extractTemplatesfromSnippets.m +++ b/sorting/Kilosort-3.0/mainLoop/extractTemplatesfromSnippets.m @@ -1,73 +1,446 @@ function [wTEMP, wPCA] = extractTemplatesfromSnippets(rez, nPCs) -% this function is very similar to extractPCfromSnippets. -% outputs not just the PC waveforms, but also the template "prototype", -% basically k-means clustering of 1D waveforms. + % this function is very similar to extractPCfromSnippets. + % outputs not just the PC waveforms, but also the template "prototype", + % basically k-means clustering of 1D waveforms. -ops = rez.ops; + ops = rez.ops; -% skip every this many batches -nskip = getOr(ops, 'nskip', 25); + % skip every this many batches + nskip = getOr(ops, 'nskip', 25); -Nbatch = rez.temp.Nbatch; -NT = ops.NT; -batchstart = 0:NT:NT*Nbatch; + Nbatch = rez.temp.Nbatch; + NT = ops.NT; + batchstart = 0:NT:NT * Nbatch; -fid = fopen(ops.fproc, 'r'); % open the preprocessed data file + fid = fopen(ops.fproc, 'r'); % open the preprocessed data file -k = 0; -dd = gpuArray.zeros(ops.nt0, 5e4, 'single'); % preallocate matrix to hold 1D spike snippets -for ibatch = 1:nskip:Nbatch - offset = 2 * ops.Nchan*batchstart(ibatch); - fseek(fid, offset, 'bof'); - dat = fread(fid, [ops.Nchan NT], '*int16'); - dat = dat'; + k = 0; + dd = gpuArray.zeros(ops.nt0, 5e4, 'single'); % preallocate matrix to hold 1D spike snippets + if ops.fig % PLOTTING + figure(1); hold on; + end + for ibatch = 1:nskip:Nbatch + offset = 2 * ops.Nchan * batchstart(ibatch); + fseek(fid, offset, 'bof'); + dat = fread(fid, [ops.Nchan NT], '*int16'); + dat = dat'; + + % move data to GPU and scale it back to unit variance + dataRAW = gpuArray(dat); + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; + + % find isolated spikes from each batch + [row, col] = isolated_peaks_multithreshold(-abs(dataRAW), ops, ibatch); + + % for each peak, get the voltage snippet from that channel + clips = get_SpikeSample(dataRAW, row, col, ops, 0); + c = sq(clips(:, :)); + if ops.fig == 1 % PLOTTING + plot(c) + end + if k + size(c, 2) > size(dd, 2) + dd(:, 2 * size(dd, 2)) = 0; + end - % move data to GPU and scale it back to unit variance - dataRAW = gpuArray(dat); - dataRAW = single(dataRAW); - dataRAW = dataRAW / ops.scaleproc; + dd(:, k + [1:size(c, 2)]) = c; + k = k + size(c, 2); + if k > 1e5 + break; + end + end + fclose(fid); + if ops.fig == 1 % PLOTTING + title('local isolated spikes (1D voltage waveforms)'); + end + % discard empty samples + dd = dd(:, 1:k); + % window definition + % window = tukeywin(size(wTEMP, 1), 0.9); + sigma_time = 0.125; % ms of the gaussian kernel to focus penalty on central region of waveforms + sigma = ops.fs * sigma_time / 1000; % samples + gaussian_window = gausswin(size(dd, 1), (size(dd, 1) - 1) / (2 * sigma)); + zeros_for_tukey = zeros(size(dd, 1), 1); + percent_tukey_coverage = 80; + partial_tukey_window = tukeywin(ceil(size(dd, 1) * percent_tukey_coverage / 100), 0.5); + % put middle of partial tukey at the center of zeros_for_tukey + zeros_for_tukey(ceil(size(dd, 1) / 2) - ceil(size(partial_tukey_window, 1) / 2) + ... + 1:ceil(size(dd, 1) / 2) + floor(size(partial_tukey_window, 1) / 2)) = partial_tukey_window; + tukey_window = zeros_for_tukey; + + dd_windowed = dd .* tukey_window; + % align max absolute peaks to the center of the template (ops.nt0min) + [~, peak_indexes] = max(abs(dd_windowed), [], 1); + spikes_shifts = peak_indexes - ops.nt0min; + dd_aligned = gpuArray(nan(size(dd))); + dd_windowed_aligned = gpuArray(nan(size(dd))); + for i = 1:size(dd_windowed, 2) + dd_windowed_aligned(:, i) = circshift(dd_windowed(:, i), -spikes_shifts(i)); + dd_aligned(:, i) = circshift(dd(:, i), -spikes_shifts(i)); + end + dd_cpu = double(gather(dd_windowed_aligned)); + % PCA is computed on the windowed data + [U, ~, ~] = svdecon(dd_cpu); % the PCs are just the left singular vectors of the waveforms + % if ops.fig == 1 % PLOTTING + % figure(4); hold on; + % for i = 1:nPCs + % plot(U(:,i)+i*1); + % end + % title(strcat("Top ", num2str(nPCs), " PCs")); + % end + wPCA = gpuArray(single(U(:, 1:nPCs))); % take as many as needed + % adjust the arbitrary sign of the first PC so its peak is downward + wPCA(:, 1) =- wPCA(:, 1) * sign(wPCA(ops.nt0min, 1)); - % find isolated spikes from each batch - [row, col, mu] = isolated_peaks_new(dataRAW, ops); + use_kmeans = true; + % initialize the template clustering + if use_kmeans + dd_pca = wPCA' * dd_aligned; + % compute k-means clustering of the waveforms + rng('default'); rng(1); % initializing random number generator for reproducibility + % stream = RandStream('mlfg6331_64'); % Random number stream + p = gcp('nocreate'); % If no pool, do not create new one. + if isempty(p) + options = statset('UseParallel', 0); + num_jobs = 12; + else + options = statset('UseParallel', 1); %'UseSubstreams', 1,'Streams', stream); + num_jobs = p.NumWorkers; + end + % try to use all available jobs, but if it fails, halve the number of jobs and + % if it still fails, halve it again, until it doesn't fail + % if all else fails, just run it sequentially + try + [cluster_id, ~, ~, Dist_from_K] = kmeans(dd_pca', nPCs, 'Distance', 'sqeuclidean', ... + 'MaxIter', 10000, 'Replicates', num_jobs, 'Display', 'final', 'Options', options); + catch + disp('k-means failed in parallel, running sequentially instead') + [cluster_id, ~, ~, Dist_from_K] = kmeans(dd_pca', nPCs, 'Distance', 'sqeuclidean', ... + 'MaxIter', 10000, 'Replicates', num_jobs, 'Display', 'final'); + end + spikes = gpuArray(nan(size(dd_aligned))); + number_of_spikes_to_use = nan(nPCs, 1); + for K = 1:nPCs + % use 90% of the closest spikes to the cluster center, to avoid outliers + number_of_close_spikes = floor(sum(cluster_id == K) * 0.9); + [~, min_dist_idxs] = mink(Dist_from_K(:, K), number_of_close_spikes); + spikes_to_use = dd_aligned(:, cluster_id == K & ismember(1:length(cluster_id), min_dist_idxs)'); + number_of_spikes_to_use(K) = size(spikes_to_use, 2); + % choose closest spikes to the cluster center + spikes(:, sum(number_of_spikes_to_use(1:K - 1)) + ... + 1:sum(number_of_spikes_to_use, 'omitnan')) = spikes_to_use; + end + % drop nan values + spikes = spikes(:, ~isnan(spikes(1, :))); - % for each peak, get the voltage snippet from that channel - clips = get_SpikeSample(dataRAW, row, col, ops, 0); + % dbstop in extractTemplatesfromSnippets.m at 140 + total_num_spikes_used = sum(number_of_spikes_to_use); + % take max of wave peaks + [peak_amplitudes, ~] = max(spikes, [], 1); - c = sq(clips(:, :)); + wave_choice_left_bounds = [1; cumsum(number_of_spikes_to_use)]; + wave_choice_right_bounds = wave_choice_left_bounds(2:end); + wave_choice_left_bounds = wave_choice_left_bounds(1:end - 1); + N_waves_between_choices = wave_choice_right_bounds - wave_choice_left_bounds; + % define wave_choice_boundaries as all 1 std dev above and below the cluster centers + % this avoids noisy spikes being chosen as templates + else + % wTEMP = dd(:, randperm(size(dd,2), nPCs)); + % sort dd by largest peak amplitude, with positive peaks first + % if negative peaks are larger, those spikes will be added later in the array + % check if each spikes max or min is larger + % if the max is larger, use the max_idx, otherwise use the min_idx + % this is to make sure that the spikes are segregated by polarity + % and that the largest spikes are used first + maxes_for_each_spike = max(dd); + mins_for_each_spike = min(dd); + max_larger_mask = maxes_for_each_spike > abs(mins_for_each_spike); + [max_peaks, max_larger_sorted_idx] = sort(max(dd(:, max_larger_mask)), 'descend'); + [min_peaks, min_larger_sorted_idx] = sort(min(dd(:, ~max_larger_mask)), 'descend'); + if ~isempty(max_peaks) && ~isempty(min_peaks) + peak_amplitudes = [max_peaks, min_peaks]; + elseif ~isempty(max_peaks) + peak_amplitudes = max_peaks; + elseif ~isempty(min_peaks) + peak_amplitudes = min_peaks; + else + error('No spikes found in the data!') + end + % sort the spikes by amplitude, with positive and negative peaks separate + max_mask_idx = find(max_larger_mask); + min_mask_idx = find(~max_larger_mask); + idx = [max_mask_idx(max_larger_sorted_idx), min_mask_idx(min_larger_sorted_idx)]; + total_num_spikes_used = length(idx); - if k+size(c,2)>size(dd,2) - dd(:, 2*size(dd,2)) = 0; + % assign non-uniform wave choice boundaries based on the amplutide of the peaks + fraction_of_N_peaks = ceil(0.02 * length(peak_amplitudes)); + % get even distribution of spike amplitudes, treating positive and negative peaks separately + % also skip the first and last chunks to avoid outliers + num_max_peak_boundaries = ceil(length(max_peaks) / length(peak_amplitudes) * nPCs); + num_min_peak_boundaries = nPCs - num_max_peak_boundaries; + if ~isempty(max_peaks) && ~isempty(min_peaks) % if there are both positive and negative peaks + max_peak_boundaries = linspace(max_peaks(min(fraction_of_N_peaks, length(max_peaks))), ... + max_peaks(end), num_max_peak_boundaries); + min_peak_boundaries = linspace(min_peaks(1), min_peaks(end - min(fraction_of_N_peaks, ... + length(min_peaks))), num_min_peak_boundaries); + % combine the boundaries + peak_boundaries = [max_peaks(1), max_peak_boundaries, min_peak_boundaries(2:end), ... + min_peaks(end)]; + elseif ~isempty(max_peaks) % if there are only positive peaks + max_peak_boundaries = linspace(max_peaks(min(fraction_of_N_peaks, length(max_peaks))), ... + max_peaks(end), nPCs); + peak_boundaries = [max_peaks(1), max_peak_boundaries, max_peaks(end)]; + elseif ~isempty(min_peaks) % if there are only negative peaks + min_peak_boundaries = linspace(min_peaks(1), min_peaks(end - min(fraction_of_N_peaks, ... + length(min_peaks))), nPCs); + peak_boundaries = [min_peaks(1), min_peak_boundaries, min_peaks(end)]; + else + error('No spikes found in the data!') + end + % find the closest peak to each boundary + [~, wave_choice_boundaries] = min(abs(peak_amplitudes - peak_boundaries'), [], 2); + % N_waves_between_choices = diff(wave_choice_boundaries); + + % assign uniform wave choice boundaries + % uniform_wave_choice_boundaries = round(linspace(1, length(peaks), nPCs+1)); + + % wave_choice_left_bounds = wave_choice_boundaries(1:end-1); + group_percent_expansion = 8; % percent + group_size_expansion = ceil(group_percent_expansion / 2/100 * length(peak_amplitudes)); + % define the boundaries, preventing negative values + wave_choice_left_expand = max(wave_choice_boundaries - group_size_expansion, 1); + wave_choice_left_bounds = wave_choice_left_expand(1:end - 1); + wave_choice_right_expand = min(wave_choice_boundaries + group_size_expansion, ... + length(peak_amplitudes)); + wave_choice_right_bounds = wave_choice_right_expand(2:end); + N_waves_between_choices = wave_choice_right_bounds - wave_choice_left_bounds; + disp(['Chunks overlap by ', num2str(group_percent_expansion), '%, which is ', ... + num2str(group_size_expansion), ' spikes']) + disp('Number of spikes in each chunk: ') + disp(N_waves_between_choices') + end + if ops.fig % PLOTTING + figure(22); + plot((1:length(peak_amplitudes)) * ops.nt0, peak_amplitudes, 'm'); hold on; + for iPeak = 1:length(peak_amplitudes) + if mod(iPeak, 2) == 0 + color = 'k'; + if use_kmeans + plot((-ops.nt0min + 1:ops.nt0min - 1) + iPeak * ops.nt0, spikes(:, iPeak)', ... + 'DisplayName', num2str(iPeak), 'Color', color) + else + plot((-ops.nt0min + 1:ops.nt0min - 1) + iPeak * ops.nt0, dd(:, idx(iPeak))', ... + 'DisplayName', num2str(iPeak), 'Color', color) + end + end + end + title('peak amplitudes for each spike') + % plot vertical lines at the boundaries + for iBoundary = 1:length(wave_choice_left_bounds) + % different color for each chunk, from magenta to cyan + color = [1 - iBoundary / length(wave_choice_left_bounds), iBoundary / ... + length(wave_choice_left_bounds), 1]; + plot([1, 1] * wave_choice_left_bounds(iBoundary) * ops.nt0, ylim, ... + 'Color', color, 'LineWidth', 2) + plot([1, 1] * wave_choice_right_bounds(iBoundary) * ops.nt0, ylim, ... + 'Color', color, 'LineWidth', 2) + end + end + if use_kmeans + wTEMP = spikes(:, wave_choice_left_bounds); % start with first spike in each cluster + else + plot((1:length(peak_amplitudes)) * ops.nt0, peak_amplitudes, 'm', 'LineWidth', 3); + wTEMP = dd(:, idx(wave_choice_left_bounds)); % initialize with a smooth range of amplitudes end + largest_CC_idx = 1; + N_tries_for_largest_CC_idx_so_far = 0; + best_CC_idxs = wave_choice_left_bounds; + lowest_total_cost_for_each_chunk = 1e12 * ones(1, length(wave_choice_left_bounds)); + iter = 1; + + while iter < 2 + % replace with next wave in chunk to check correlation, + % each wave index withing the chunk starting at the wave_choice_left_bounds + if use_kmeans + descriptor = 'k-means cluster'; + wTEMP(:, largest_CC_idx) = spikes(:, wave_choice_left_bounds(largest_CC_idx) + ... + N_tries_for_largest_CC_idx_so_far); + else + descriptor = 'chunk'; + wTEMP(:, largest_CC_idx) = dd(:, idx(wave_choice_left_bounds(largest_CC_idx) + ... + N_tries_for_largest_CC_idx_so_far)); + end + % multiply waveforms by a Gaussian with the sigma value + % this is to make the correlation more sensitive to the central shape of the waveform + % wTEMP_for_corr = wTEMP .* gausswin(size(wTEMP, 1), (size(wTEMP, 1) - 1) / (2 * sigma)); + % align largest peak of each template to ops.nt0min before checking correlation to avoid + % the correlation being sensitive to the alignment of the waveforms + wTEMP_for_corr = wTEMP .* gaussian_window; % use window to focus on central region of waveforms + CC = corr(wTEMP_for_corr); + + %% section to compute terms of the cost function + % get residual of the waveform for this row of the CC matrix + % sum the absolute value of the residual, scale by the absolute value of wTEMP_for_corr + % this is to avoid using the waves with non-central shapes, by using it as a cost function + wTEMP_gaussian_residual = sum(abs(wTEMP(:, largest_CC_idx) - ... + wTEMP_for_corr(:, largest_CC_idx))) / sum(abs(wTEMP_for_corr(:, largest_CC_idx))); + + % compute the penalty for the highest single correlation in the CC matrix for this template choice + if largest_CC_idx == 1 + highest_template_similarity_penalty = max(max(CC(largest_CC_idx, 2:end), ... + max(CC(2:end, largest_CC_idx)))); + else + highest_template_similarity_penalty = max(max(CC(1:largest_CC_idx - 1, largest_CC_idx)), ... + max(CC(largest_CC_idx, 1:largest_CC_idx - 1))); + end + highest_template_similarity_penalty = max(highest_template_similarity_penalty, 0); % make sure it is not negative - dd(:, k + [1:size(c,2)]) = c; - k = k + size(c,2); - if k>1e5 - break; + % compute sum of all similarities with other previous template choices + % negative correlations are not penalized + if largest_CC_idx == 1 + corr_sum_with_other_template_choices = sum(max(CC(largest_CC_idx, 2:end), 0)) + ... + sum(max(CC(2:end, largest_CC_idx), 0)); + else + corr_sum_with_other_template_choices = sum(sum(max(CC(1:largest_CC_idx - 1, largest_CC_idx), 0))) + ... + sum(sum(max(CC(largest_CC_idx, 1:largest_CC_idx - 1), 0))); + end + + % compute the total cost for this wave choice, cheapest waveform will be chosen + corr_sum_with_other_template_choices_term = corr_sum_with_other_template_choices; + wTEMP_gaussian_residual_term = nPCs * wTEMP_gaussian_residual; + highest_template_similarity_penalty_term = nPCs * highest_template_similarity_penalty; + total_cost_for_wave = corr_sum_with_other_template_choices_term + ... + wTEMP_gaussian_residual_term + highest_template_similarity_penalty_term; + + % choose the best wave for this chunk/cluster + % check the cost, and ensure that the wave has not already been chosen before in a previous chunk + if (total_cost_for_wave < lowest_total_cost_for_each_chunk(largest_CC_idx)) && ... + ~ismember(wave_choice_left_bounds(largest_CC_idx) + N_tries_for_largest_CC_idx_so_far, best_CC_idxs) + lowest_total_cost_for_each_chunk(largest_CC_idx) = total_cost_for_wave; + best_CC_idxs(largest_CC_idx) = wave_choice_left_bounds(largest_CC_idx) + N_tries_for_largest_CC_idx_so_far; + end + N_tries_for_largest_CC_idx_so_far = N_tries_for_largest_CC_idx_so_far + 1; + % terminate if we have tried all waves in the amplitude-sorted chunk + if N_tries_for_largest_CC_idx_so_far >= N_waves_between_choices(largest_CC_idx) || ... + N_tries_for_largest_CC_idx_so_far >= total_num_spikes_used + disp(strcat("Tried all waves in "+descriptor + " ", num2str(largest_CC_idx), ... + ", using wave idx with best CC: ", num2str(best_CC_idxs(largest_CC_idx)))) + % sorted_CC = sort(CC(largest_CC_idx,:), 'descend'); + disp(strcat("Total cross-channel correlation for this cluster ", num2str(sum(abs(CC(largest_CC_idx, :)))))) + disp("Residual cost for this "+descriptor + " was " + num2str(wTEMP_gaussian_residual)) + disp("Highest template similarity penalty for this "+descriptor + " was " + num2str(highest_template_similarity_penalty)) + disp("Percent of influence for each term: ") + corr_sum_with_other_template_choices_percent = ... + corr_sum_with_other_template_choices_term / total_cost_for_wave * 100; + wTEMP_gaussian_residual_percent = wTEMP_gaussian_residual_term / total_cost_for_wave * 100; + highest_template_similarity_penalty_percent = ... + highest_template_similarity_penalty_term / total_cost_for_wave * 100; + disp([corr_sum_with_other_template_choices_percent, wTEMP_gaussian_residual_percent, ... + highest_template_similarity_penalty_percent]) + disp("Final cost for this "+descriptor + " was: " + num2str(lowest_total_cost_for_each_chunk(largest_CC_idx))) + % disp(sorted_CC) + disp(CC) + largest_CC_idx = largest_CC_idx + 1; + N_tries_for_largest_CC_idx_so_far = 0; + if largest_CC_idx > nPCs + largest_CC_idx = 1; + % correlated_pairs = false; % terminate the while loop + disp("Final waveforms chosen:") + disp(best_CC_idxs) + disp("Final CC matrix:") + disp(CC) + iter = iter + 1; + end + end end -end -fclose(fid); - -% discard empty samples -dd = dd(:, 1:k); - -% initialize the template clustering with random waveforms -% wTEMP = dd(:, randperm(size(dd,2), nPCs)); -wTEMP = dd(:, round(linspace(1, size(dd,2), nPCs))); -wTEMP = wTEMP ./ sum(wTEMP.^2,1).^.5; % normalize them - -for i = 1:10 - % at each iteration, assign the waveform to its most correlated cluster - cc = wTEMP' * dd; - [amax, imax] = max(cc,[],1); - for j = 1:nPCs - wTEMP(:,j) = dd(:,imax==j) * amax(imax==j)'; % weighted average to get new cluster means - end - wTEMP = wTEMP ./ sum(wTEMP.^2,1).^.5; % unit normalize -end -dd = double(gather(dd)); -[U Sv V] = svdecon(dd); % the PCs are just the left singular vectors of the waveforms + if use_kmeans + wTEMP = dd(:, randperm(size(dd, 2), nPCs)); % removing this line will cause KS to not find + % spikes sometimes... variable is overwritten in the next line, so it's inconsequential + wTEMP(:, 1:nPCs) = spikes(:, best_CC_idxs(1:nPCs)); + else + wTEMP(:, 1:nPCs) = dd(:, idx(best_CC_idxs(1:nPCs))); + end + + wTEMP = wTEMP ./ sum(wTEMP .^ 2, 1) .^ .5; % normalize the templates + if ops.fig % PLOTTING + % wTEMP_for_CC_final = wTEMP .* repmat(gausswin(size(wTEMP,1), (size(wTEMP,1)-1)/(2*sigma)), 1, size(wTEMP,2)); + % specify colormap to be 'cool' + cmap = colormap(cool(nPCs)); + figure(2); hold on; + scale = 0.8; + for i = 1:nPCs + windowed_wTEMP = tukey_window .* wTEMP(:, i); + plot(wTEMP(:, i) + i * scale, 'LineWidth', 2, 'Color', cmap(i, :)); + % plot standardized Gaussian multiplied waveforms for comparison + plot(windowed_wTEMP + i * scale, 'r'); + if i == nPCs + % show gaussian window + plot(i * scale + 0.5 * tukey_window, 'g'); + plot(i * scale + 0.5 * gaussian_window, 'c'); + end + end + title('initial templates'); + pbaspect([1 2 1]) + end + + % use k-means isolated spikes for correlation calculation and averaging to ignore outlier spikes + if use_kmeans + % just take average of top 10 percent most correlated spikes to the ones chosen in wTEMP + % for each cluster. Use wave_choice_left_bounds to get the spikes to use from 'spikes' + for iCluster = 1:nPCs + % get the spikes that were chosen for this cluster + cluster_spikes = spikes(:, wave_choice_left_bounds(iCluster):wave_choice_right_bounds(iCluster)); + % get the correlation matrix for this cluster + cluster_CC = corr(cluster_spikes); + % get the top 1 percent most correlated spikes to the chosen spikes, must be at least 10 spikes + try + [~, top_CC_idxs] = maxk(cluster_CC(:, 1), max(ceil(size(cluster_CC, 1) * 0.01), 10)); + catch + % if there are less than 10 spikes in the cluster, just take the top 1 percent + [~, top_CC_idxs] = maxk(cluster_CC(:, 1), ceil(size(cluster_CC, 1) * 0.01)); + + end + + % get the top 10 percent most correlated spikes + top_CC_spikes = cluster_spikes(:, top_CC_idxs); + % average the top 10 most correlated spikes + wTEMP(:, iCluster) = mean(top_CC_spikes, 2); + end -wPCA = gpuArray(single(U(:, 1:nPCs))); % take as many as needed -wPCA(:,1) = - wPCA(:,1) * sign(wPCA(ops.nt0min,1)); % adjust the arbitrary sign of the first PC so its negativity is downward + wTEMP = wTEMP ./ sum(wTEMP .^ 2, 1) .^ .5; % standardize the new clusters + else % use all spikes for correlation calculation and averaging + for i = 1:10 + % at each iteration, assign the waveform to its most correlated cluster + CC = wTEMP' * dd; + [amax, imax] = max(CC, [], 1); % find the best cluster for each waveform + for j = 1:nPCs + wTEMP(:, j) = dd(:, imax == j) * amax(imax == j)'; % weighted average to get new cluster means + end + wTEMP = wTEMP ./ sum(wTEMP .^ 2, 1) .^ .5; % standardize the new clusters + end + end + % tukey it + wTEMP_tukeyed = wTEMP .* tukey_window; + if ops.fig % PLOTTING + figure(3); hold on; + for i = 1:nPCs + plot(wTEMP(:, i) + i * scale, 'LineWidth', 2, 'Color', cmap(i, :)); + plot(wTEMP_tukeyed(:, i) + i * scale, 'r'); + if i == nPCs + % show tukey + plot(i * scale + 0.5 * tukey_window, 'c'); + end + end + % set aspect ratio to 3, 1 + title('prototype templates'); + pbaspect([1 2 1]) + end + + if use_kmeans + % recomputing PCA on the k-means isolated spikes (90% of closest spikes to cluster center) + [U, ~, ~] = svdecon(spikes); + wPCA = gpuArray(single(U(:, 1:nPCs))); % take as many as needed + end +end diff --git a/sorting/Kilosort-3.0/mainLoop/getMeUtU.m b/sorting/Kilosort-3.0/mainLoop/getMeUtU.m index 315ff2ca..a71219a2 100644 --- a/sorting/Kilosort-3.0/mainLoop/getMeUtU.m +++ b/sorting/Kilosort-3.0/mainLoop/getMeUtU.m @@ -1,23 +1,30 @@ function [UtU, maskU, iList] = getMeUtU(iU, iC, mask, Nnearest, Nchan) -% this function determines if two templates share any channels -% iU are the channels that each template is assigned to, one main channel per template -% iC has as column K the list of neigboring channels for channel K -% mask are the weights assigned for the corresponding neighboring channels -% in iC (gaussian-decaying) + % this function determines if two templates share any channels + % iU are the channels that each template is assigned to, one main channel per template + % iC has as column K the list of neigboring channels for channel K + % mask are the weights assigned for the corresponding neighboring channels + % in iC (gaussian-decaying) -Nfilt = numel(iU); + Nfilt = numel(iU); -U = gpuArray.zeros(Nchan, Nfilt, 'single'); % create a sparse matrix with ones if a channel K belongs to a template + U = gpuArray.zeros(Nchan, Nfilt, 'single'); % create a sparse matrix with ones if a channel K belongs to a template -ix = iC(:, iU) + int32([0:Nchan:(Nchan*Nfilt-1)]); % use the template primary channel to obtain its neighboring channels from iC -U(ix) = 1; % use this as an awkward index into U + chanFiltIndexes = int32(0:Nchan:(Nchan * Nfilt - 1)); + try % need to retry the below command, sometimes it fails on first try, with: + % "An unexpected error occurred trying to launch a kernel. The CUDA error was: + % invalid argument" + ix = iC(:, iU) + chanFiltIndexes; % use the template primary channel to obtain its neighboring channels from iC + catch + ix = iC(:, iU) + chanFiltIndexes; % use the template primary channel to obtain its neighboring channels from iC + end + U(ix) = 1; % use this as an awkward index into U -UtU = (U'*U) > 0; % if this is 0, the templates had not pair of channels in common + UtU = (U' * U) > 0; % This is a matrix with ones if two templates share any channels, if this is 0, the templates had not pair of channels in common -maskU = mask(:, iU); % we also return the masks for each template, picked from the corresponding mask of their primary channel + maskU = mask(:, iU); % we also return the masks for each template, picked from the corresponding mask of their primary channel -if nargin>3 && nargout>2 - cc = UtU; - [~, isort] = sort(cc, 1, 'descend'); % sort template pairs in order of how many channels they share - iList = int32(gpuArray(isort(1:Nnearest, :))); % take the Nnearest templates for each template -end + if nargin > 3 && nargout > 2 + cc = UtU; + [~, isort] = sort(cc, 1, 'descend'); % sort template pairs in order of how many channels they share + iList = int32(gpuArray(isort(1:Nnearest, :))); % take the Nnearest templates for each template + end diff --git a/sorting/Kilosort-3.0/mainLoop/getMeWtW.m b/sorting/Kilosort-3.0/mainLoop/getMeWtW.m index 89d8c752..4dbdf121 100644 --- a/sorting/Kilosort-3.0/mainLoop/getMeWtW.m +++ b/sorting/Kilosort-3.0/mainLoop/getMeWtW.m @@ -15,7 +15,7 @@ for j = 1:Nrank % the dot product factorizes into separable products for each spatio-temporal component utu0 = U0(:,:,i)' * U0(:,:,j); % spatial products - wtw0 = mexWtW2(Params, W(:,:,i), W(:,:,j), utu0); % temporal convolutions get multiplied wit hthe spatial products + wtw0 = mexWtW2(Params, W(:,:,i), W(:,:,j), utu0); % temporal convolutions get multiplied with the spatial products wtw0 = gather(wtw0); % take this matrix off the GPU (it's big) WtW = WtW + wtw0; % add it to the full correlation array end diff --git a/sorting/Kilosort-3.0/mainLoop/trackAndSort.m b/sorting/Kilosort-3.0/mainLoop/trackAndSort.m index cec8e78c..99e41e38 100644 --- a/sorting/Kilosort-3.0/mainLoop/trackAndSort.m +++ b/sorting/Kilosort-3.0/mainLoop/trackAndSort.m @@ -1,243 +1,236 @@ function [rez, st3, fWpc] = trackAndSort(rez, varargin) -if ~isempty(varargin) - iorder = varargin{1}; -else - iorder = 1:rez.ops.Nbatch; -end -% This is the extraction phase of the optimization. -% iorder is the order in which to traverse the batches + if ~isempty(varargin) + iorder = varargin{1}; + else + iorder = 1:rez.ops.Nbatch; + end + % This is the extraction phase of the optimization. + % iorder is the order in which to traverse the batches -% iorder = 1:rez.ops.Nbatch; + % iorder = 1:rez.ops.Nbatch; -% Turn on sorting of spikes before subtracting and averaging in mpnu8 -rez.ops.useStableMode = getOr(rez.ops, 'useStableMode', 1); -useStableMode = rez.ops.useStableMode; + % Turn on sorting of spikes before subtracting and averaging in mpnu8 + rez.ops.useStableMode = getOr(rez.ops, 'useStableMode', 1); + useStableMode = rez.ops.useStableMode; -ops = rez.ops; + ops = rez.ops; -% revert to the saved templates -W = gpuArray(rez.W); -U = gpuArray(rez.U); -mu = gpuArray(rez.mu); + % revert to the saved templates + W = gpuArray(rez.W); % temporal components of the templates (transformation matrix out of SVD space) + U = gpuArray(rez.U); % spatial components of the templates (transformation matrix into SVD space) + mu = gpuArray(rez.mu); % norm of the spatial components, used for scaling the templates -Nfilt = size(W,2); -nt0 = ops.nt0; -Nchan = ops.Nchan; + Nfilt = size(W, 2); + nt0 = ops.nt0; + Nchan = ops.Nchan; -dWU = gpuArray.zeros(nt0, Nchan, Nfilt, 'double'); -for j = 1:Nfilt - dWU(:,:,j) = mu(j) * squeeze(W(:, j, :)) * squeeze(U(:, j, :))'; -end + dWU = gpuArray.zeros(nt0, Nchan, Nfilt, 'double'); + for j = 1:Nfilt + dWU(:, :, j) = mu(j) * squeeze(W(:, j, :)) * squeeze(U(:, j, :))'; + end % dWU has shape nt0 by Nchan by Nfilt + ops.fig = getOr(ops, 'fig', 1); % whether to show plots every N batches -ops.fig = getOr(ops, 'fig', 1); % whether to show plots every N batches + NrankPC = ops.nPCs; % this one is the rank of the PCs, used to detect spikes with threshold crossings + Nrank = ops.nEig; % this one is the rank of the templates + rng('default'); rng(1); % initializing random number generator -NrankPC = ops.nPCs; % this one is the rank of the PCs, used to detect spikes with threshold crossings -Nrank = ops.nEig; % this one is the rank of the templates -rng('default'); rng(1); % initializing random number generator + % move these to the GPU + wPCA = gpuArray(ops.wPCA); -% move these to the GPU -wPCA = gpuArray(ops.wPCA); + nt0min = rez.ops.nt0min; + rez.ops = ops; + nBatches = rez.temp.Nbatch; + NT = ops.NT; -nt0min = rez.ops.nt0min; -rez.ops = ops; -nBatches = rez.temp.Nbatch; -NT = ops.NT; + % two variables for the same thing? number of nearest channels to each primary channel + NchanNear = min(ops.Nchan, 16); + Nnearest = min(ops.Nchan, 32); + % decay of gaussian spatial mask centered on a channel + sigmaMask = ops.sigmaMask; -% two variables for the same thing? number of nearest channels to each primary channel -NchanNear = min(ops.Nchan, 16); -Nnearest = min(ops.Nchan, 32); - -% decay of gaussian spatial mask centered on a channel -sigmaMask = ops.sigmaMask; - -% spike threshold for finding missed spikes in residuals -ops.spkTh = -6; % why am I overwriting this here? - -batchstart = 0:NT:NT*nBatches; - -% find the closest NchanNear channels, and the masks for those channels -[iC, mask, C2C] = getClosestChannels(rez, sigmaMask, NchanNear); - -niter = numel(iorder); - -% this is the absolute temporal offset in seconds corresponding to the start of the -% spike sorted time segment -t0 = ceil(rez.ops.trange(1) * ops.fs); - -nInnerIter = 60; % this is for SVD for the power iteration - -% schedule of learning rates for the model fitting part -% starts small and goes high, it corresponds approximately to the number of spikes -% from the past that were averaged to give rise to the current template -pm = exp(-1/400); - -Nsum = min(Nchan,7); % how many channels to extend out the waveform in mexgetspikes -% lots of parameters passed into the CUDA scripts -Params = double([NT Nfilt ops.Th(1) nInnerIter nt0 Nnearest ... - Nrank ops.lam pm Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1) useStableMode]); - -% initialize average number of spikes per batch for each template -nsp = gpuArray.zeros(Nfilt,1, 'int32'); - -% extract ALL features on the last pass -Params(13) = 2; % this is a flag to output features (PC and template features) - -% different threshold on last pass? -Params(3) = ops.Th(end); % usually the threshold is much lower on the last pass - -% kernels for subsample alignment -[Ka, Kb] = getKernels(ops, 10, 1); - -p1 = .95; % decay of nsp estimate in each batch - -% the list of channels each template lives on -% also, covariance matrix between templates -[~, iW] = max(abs(dWU(nt0min, :, :)), [], 2); -iW = int32(squeeze(iW)); -[WtW, iList] = getMeWtW(single(W), single(U), Nnearest); - -fprintf('Time %3.0fs. Final spike extraction ...\n', toc) - -fid = fopen(ops.fproc, 'r'); - -% allocate variables for collecting results -st3 = zeros(1e7, 5); % this holds spike times, clusters and other info per spike -ntot = 0; - - -% these ones store features per spike -fW = zeros(Nnearest, 1e7, 'single'); % Nnearest is the number of nearest templates to store features for -fWpc = zeros(NchanNear, 2*Nrank, 1e7, 'single'); % NchanNear is the number of nearest channels to take PC features from - - -dWU1 = dWU; - -[UtU, maskU] = getMeUtU(iW, iC, mask, Nnearest, Nchan); % this needs to change (but I don't know why!) - -for ibatch = 1:niter - k = iorder(ibatch); % k is the index of the batch in absolute terms - - % loading a single batch (same as everywhere) - offset = 2 * ops.Nchan*batchstart(k); - fseek(fid, offset, 'bof'); - dat = fread(fid, [ops.Nchan NT + ops.ntbuff], '*int16'); - dat = dat'; - dataRAW = single(gpuArray(dat))/ ops.scaleproc; - Params(1) = size(dataRAW,1); - - % decompose dWU by svd of time and space (via covariance matrix of 61 by 61 samples) - % this uses a "warm start" by remembering the W from the previous - % iteration - - % we don't need to update this anymore on every iteraton.... -% [W, U, mu] = mexSVDsmall2(Params, dWU, W, iC-1, iW-1, Ka, Kb); - - % UtU is the gram matrix of the spatial components of the low-rank SVDs - % it tells us which pairs of templates are likely to "interfere" with each other - % such as when we subtract off a template -% [UtU, maskU] = getMeUtU(iW, iC, mask, Nnearest, Nchan); % this needs to change (but I don't know why!)% - - % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - % main CUDA function in the whole codebase. does the iterative template matching - % based on the current templates, gets features for these templates if requested (featW, featPC), - % gets scores for the template fits to each spike (vexp), outputs the average of - % waveforms assigned to each cluster (dWU0), - % and probably a few more things I forget about - - [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp, errmsg] = ... - mexMPnu8(Params, dataRAW, single(U), single(W), single(mu), iC-1, iW-1, UtU, iList-1, ... - wPCA); - - % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - % Sometimes nsp can get transposed (think this has to do with it being - % a single element in one iteration, to which elements are added - % nsp, nsp0, and pm must all be row vectors (Nfilt x 1), so force nsp - % to be a row vector. - [nsprow, nspcol] = size(nsp); - if nsprowsize(st3,1) - % if we exceed the original allocated memory, double the allocated sizes - fW(:, 2*size(st3,1)) = 0; - fWpc(:,:,2*size(st3,1)) = 0; - st3(2*size(st3,1), 1) = 0; - end - - st3(irange,1) = double(st); % spike times - st3(irange,2) = double(id0+1); % spike clusters (1-indexing) - st3(irange,3) = double(x0); % template amplitudes - st3(irange,4) = double(vexp); % residual variance of this spike - st3(irange,5) = ibatch; % batch from which this spike was found - - fW(:, irange) = gather(featW); % template features for this batch - fWpc(:, :, irange) = gather(featPC); % PC features - - ntot = ntot + numel(x0); % keeps track of total number of spikes so far - - - if (rem(ibatch, 100)==1) - % this is some of the relevant diagnostic information to be printed during training - fprintf('%2.2f sec, %d / %d batches, %d units, nspks: %d, mu: %2.4f, nst0: %d \n', ... - toc, ibatch, niter, Nfilt, ntot, median(mu), numel(st0)) - - % these diagnostic figures should be mostly self-explanatory - if ops.fig - if ibatch==1 - figHand = figure; - else - figure(figHand); - end - - make_fig(W, U, mu, nsp) + % spike threshold for finding missed spikes in residuals + % ops.spkTh = -6; % why am I overwriting this here? + + batchstart = 0:NT:NT * nBatches; + + % find the closest NchanNear channels, and the masks for those channels + [iC, mask, C2C] = getClosestChannels(rez, sigmaMask, NchanNear); + + niter = numel(iorder); + + % this is the absolute temporal offset in seconds corresponding to the start of the + % spike sorted time segment + t0 = ceil(rez.ops.trange(1) * ops.fs); + + nInnerIter = 60; % this is for SVD for the power iteration + + % schedule of learning rates for the model fitting part + % starts small and goes high, it corresponds approximately to the number of spikes + % from the past that were averaged to give rise to the current template + pm = exp(-1/400); + + Nsum = min(Nchan, 7); % how many channels to extend out the waveform in mexgetspikes + % lots of parameters passed into the CUDA scripts + Params = double([NT Nfilt ops.Th(1) nInnerIter nt0 Nnearest ... + Nrank ops.lam pm Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1) useStableMode]); + + % initialize average number of spikes per batch for each template + nsp = gpuArray.zeros(Nfilt, 1, 'int32'); + + % extract ALL features on the last pass + Params(13) = 2; % this is a flag to output features (PC and template features) + + % different threshold on last pass? + Params(3) = ops.Th(end); % usually the threshold is much lower on the last pass + + % kernels for subsample alignment, unused + [Ka, Kb] = getKernels(ops, 10, 1); + + p1 = .95; % decay of nsp estimate in each batch + + % the list of channels each template lives on + % also, covariance matrix between templates + [~, iW] = max(abs(dWU(nt0min, :, :)), [], 2); + iW = int32(squeeze(iW)); + [WtW, iList] = getMeWtW(single(W), single(U), Nnearest); + + fprintf('Time %3.0fs. Final spike extraction ...\n', toc) + + fid = fopen(ops.fproc, 'r'); + + % allocate variables for collecting results + st3 = zeros(1e7, 5); % this holds spike times, clusters and other info per spike + ntot = 0; + + % these ones store features per spike + fW = zeros(Nnearest, 1e7, 'single'); % Nnearest is the number of nearest templates to store features for + fWpc = zeros(NchanNear, ops.nPCs, 1e7, 'single'); % NchanNear is the number of nearest channels to take PC features from + + dWU1 = dWU; + % get covariance matrix of the spatial components, whether templates were shared across channels + [UtU, maskU] = getMeUtU(iW, iC, mask, Nnearest, Nchan); % this needs to change (but I don't know why!) + + for ibatch = 1:niter % loop over batches, in order determined by drift correction (stored in ibatch) + k = iorder(ibatch); % k is the index of the batch in absolute terms + + % loading a single batch (same as everywhere) + offset = 2 * ops.Nchan * batchstart(k); + fseek(fid, offset, 'bof'); + dat = fread(fid, [ops.Nchan NT + ops.ntbuff], '*int16'); + dat = dat'; + dataRAW = single(gpuArray(dat)) / ops.scaleproc; + Params(1) = size(dataRAW, 1); + + % decompose dWU by svd of time and space (via covariance matrix of 61 by 61 samples) + % this uses a "warm start" by remembering the W from the previous + % iteration + + % we don't need to update this anymore on every iteraton.... + % [W, U, mu] = mexSVDsmall2(Params, dWU, W, iC-1, iW-1, Ka, Kb); + + % UtU is the gram matrix of the spatial components of the low-rank SVDs + % it tells us which pairs of templates are likely to "interfere" with each other + % such as when we subtract off a template + % [UtU, maskU] = getMeUtU(iW, iC, mask, Nnearest, Nchan); % this needs to change (but I don't know why!)% + + % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + + % main CUDA function in the whole codebase. does the iterative template matching + % based on the current templates, gets features for these templates if requested (featW, featPC), + % gets scores for the template fits to each spike (vexp), outputs the average of + % waveforms assigned to each cluster (dWU0), + % and probably a few more things I forget about + [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp, errmsg] = ... + mexMPnu8_pcTight(Params, dataRAW, single(U), single(W), single(mu), iC - 1, iW - 1, UtU, iList - 1, ... + wPCA); + % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + + % Sometimes nsp can get transposed (think this has to do with it being + % a single element in one iteration, to which elements are added + % nsp, nsp0, and pm must all be row vectors (Nfilt x 1), so force nsp + % to be a row vector. + [nsprow, nspcol] = size(nsp); + if nsprow < nspcol + nsp = nsp'; end - end -end -fclose(fid); -toc -% discards the unused portion of the arrays -st3 = st3(1:ntot, :); -fW = fW(:, 1:ntot); -fWpc = fWpc(:,:, 1:ntot); + % updates the templates as a running average weighted by recency + % since some clusters have different number of spikes, we need to apply the + % exp(pm) factor several times, and fexp is the resulting update factor + % for each template + dWU1 = dWU1 + dWU0; + nsp = nsp + nsp0; + + % nsp just gets updated according to the fixed factor p1 + + % \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ + + % during the final extraction pass, this keesp track of all spikes and features + + % we memorize the spatio-temporal decomposition of the waveforms at this batch + % this is currently only used in the GUI to provide an accurate reconstruction + % of the raw data at this time + + % we carefully assign the correct absolute times to spikes found in this batch + toff = nt0min + t0 + NT * (k - 1); + st = toff + double(st0); + + irange = ntot + [1:numel(x0)]; % spikes and features go into these indices + + if ntot + numel(x0) > size(st3, 1) + % if we exceed the original allocated memory, double the allocated sizes + fW(:, 2 * size(st3, 1)) = 0; + fWpc(:, :, 2 * size(st3, 1)) = 0; + st3(2 * size(st3, 1), 1) = 0; + end + + st3(irange, 1) = double(st); % spike times + st3(irange, 2) = double(id0 + 1); % spike clusters (1-indexing) + st3(irange, 3) = double(x0); % template amplitudes + st3(irange, 4) = double(vexp); % residual variance of this spike + st3(irange, 5) = ibatch; % batch from which this spike was found + + fW(:, irange) = gather(featW); % template features for this batch + fWpc(:, :, irange) = gather(featPC); % PC features + + ntot = ntot + numel(x0); % keeps track of total number of spikes so far + + if (rem(ibatch, round(niter / 10)) == 1) % every 10 % of the batches + % this is some of the relevant diagnostic information to be printed during training + % fprintf('%2.2f sec, %d / %d batches, %d units, nspks: %d, mu: %2.4f, nst0: %d \n', ... + % toc, ibatch, niter, Nfilt, ntot, median(mu), numel(st0)) + + % these diagnostic figures should be mostly self-explanatory + if ops.fig + if ibatch == 1 + figHand = figure; + else + figure(figHand); + end + + make_fig(W, U, mu, nsp) + end + end + end + fclose(fid); + toc -rez.dWU = dWU1 ./ (1e-10 + single(reshape(nsp, [1,1,Nfilt]))); -rez.nsp = nsp; + % discards the unused portion of the arrays + st3 = st3(1:ntot, :); + fW = fW(:, 1:ntot); + fWpc = fWpc(:, :, 1:ntot); -rez.iC = iC; + rez.dWU = dWU1 ./ (1e-10 + single(reshape(nsp, [1, 1, Nfilt]))); + rez.nsp = nsp; -fWpc = permute(fWpc, [3, 2, 1]); + rez.iC = iC; -%% + fWpc = permute(fWpc, [3, 2, 1]); + % dbstop in trackAndSort at 241 + % disp('stop') + %% diff --git a/sorting/Kilosort-3.0/plot_templates_on_raw_data_fast.m b/sorting/Kilosort-3.0/plot_templates_on_raw_data_fast.m new file mode 100644 index 00000000..922b9783 --- /dev/null +++ b/sorting/Kilosort-3.0/plot_templates_on_raw_data_fast.m @@ -0,0 +1,59 @@ + +function plot_templates_on_raw_data_fast(rez, st3) + ops = rez.ops; + fid = fopen(ops.fproc, 'r'); % open the preprocessed data file + Nbatch = rez.temp.Nbatch; + NT = ops.NT; + batchstart = 0:NT:NT * Nbatch; + Nfilt = size(rez.W, 2); + nt0 = ops.nt0; + Nchan = ops.Nchan; + RGBA_colors = [rand(Nfilt, 3) 0.7*ones(Nfilt, 1)]; + for ibatch = 2:2:4 + offset = 2 * ops.Nchan * batchstart(ibatch); + fseek(fid, offset, 'bof'); + dat = fread(fid, [ops.Nchan NT], '*int16'); + dat = dat'; + + % move data to GPU and scale it back to unit variance + dataRAW = dat; + dataRAW = single(dataRAW); + dataRAW = dataRAW / ops.scaleproc; % dataRAW is size + + % add offset to each channel, shift time to correct batch offset, then plot + spacing = 30; + dataRAW = dataRAW + spacing*ops.chanMap'; + batch_time = (1:size(dataRAW, 1)) + batchstart(ibatch); + figure(15+round(ibatch/2)); hold on; + % plot the raw data + plot(repmat(batch_time', 1, size(dataRAW, 2)), dataRAW, 'k', 'LineWidth', 1); + + + % next plot each template for each cluster, WU(:,:,j), on top of the raw data at each cluster's corresponding spike time + % use only valid channel locations and scale the color by cluster ID, along RGB + spike_times_in_batch_for_each_cluster = cell(Nfilt, 1); + for jfilt = 5:10 % only show first four clusters %Nfilt + % WU(:, :, j) = rez.mu(j) * squeeze(rez.W(:, j, :)) * squeeze(rez.U(:, j, :))'; + spike_times_in_batch_for_this_cluster = st3(st3(:, 2) == jfilt & st3(:, 1) > batchstart(ibatch) & st3(:, 1) < batchstart(ibatch+1), 1); + Nspikes = length(spike_times_in_batch_for_this_cluster); + if Nspikes == 0 + continue + end + spike_times_in_batch_for_each_cluster{jfilt} = spike_times_in_batch_for_this_cluster; + unit_var_cluster_waveforms = rez.dWU(:, :, jfilt)./std(rez.dWU(:, :, jfilt)); + % get the 1D waveform for each channel in WU, and plot it at the corresponding spike time and channel location + cluster_waveforms_for_all_channels = unit_var_cluster_waveforms + spacing*ops.chanMap'; + % create a time range centered on the spike time for all spikes, put result in 2D matrix + time_ranges_for_each_spike_time = repmat((-nt0/2:nt0/2-1)', 1, Nspikes); + offset_time_ranges_for_each_spike_time = time_ranges_for_each_spike_time + spike_times_in_batch_for_this_cluster'; + offset_time_ranges_for_each_spike_time_rep = repmat(offset_time_ranges_for_each_spike_time, 1, Nchan); + cluster_waveforms_for_all_channels_rep = zeros(size(offset_time_ranges_for_each_spike_time_rep)); + for ichan = 1:Nchan + cluster_waveforms_for_all_channels_rep(:, ((ichan-1)*Nspikes:ichan*Nspikes-1)+1) = repmat(cluster_waveforms_for_all_channels(:, ichan), 1, Nspikes); + end + % plot all waveforms for this cluster + plot(offset_time_ranges_for_each_spike_time_rep, cluster_waveforms_for_all_channels_rep, 'Color',RGBA_colors(jfilt,:)); + end % rez.dWU has shape nt0 by Nchan by Nfilt + title(["Prototype template matches for batch ", num2str(ibatch)]) + end +end \ No newline at end of file diff --git a/sorting/Kilosort-3.0/postProcess/find_merges.m b/sorting/Kilosort-3.0/postProcess/find_merges.m index c8f6538f..a90b99ab 100644 --- a/sorting/Kilosort-3.0/postProcess/find_merges.m +++ b/sorting/Kilosort-3.0/postProcess/find_merges.m @@ -50,7 +50,7 @@ rez.K_CCG = {}; end -for j = 1:Nk +for j = 1:Nk % now we traverse the neurons to find pairs to merge based on template correlation s1 = rez.st3(rez.st3(:,2)==isort(j), 1)/ops.fs; % find all spikes from this cluster if numel(s1)~=nspk(isort(j)) fprintf('lost track of spike counts') %this is a check for myself to make sure new cluster are combined correctly into bigger clusters @@ -60,7 +60,7 @@ ienu = find(ccsort<.7, 1) - 1; % find the first pair which has too low of a correlation - % for all pairs above 0.5 correlation + % for all pairs above 0.7, check if they are refractory for k = 1:ienu s2 = rez.st3(rez.st3(:,2)==ix(k), 1)/ops.fs; % find the spikes of the pair % compute cross-correlograms, refractoriness scores (Qi and rir), and normalization for these scores diff --git a/sorting/Kilosort-3.0/preProcess/datashift2.m b/sorting/Kilosort-3.0/preProcess/datashift2.m index 55db4284..c3489ad7 100644 --- a/sorting/Kilosort-3.0/preProcess/datashift2.m +++ b/sorting/Kilosort-3.0/preProcess/datashift2.m @@ -1,169 +1,162 @@ function rez = datashift2(rez, do_correction) -NrankPC = 6; -[wTEMP, wPCA] = extractTemplatesfromSnippets(rez, NrankPC); -rez.wTEMP = gather(wTEMP); -rez.wPCA = gather(wPCA); - -ops = rez.ops; - -% The min and max of the y and x ranges of the channels -ymin = min(rez.yc); -ymax = max(rez.yc); -xmin = min(rez.xc); -xmax = max(rez.xc); - -dmin = median(diff(unique(rez.yc))); -fprintf('vertical pitch size is %d \n', dmin) -rez.ops.dmin = dmin; -rez.ops.yup = ymin:dmin/2:ymax; % centers of the upsampled y positions - -% dminx = median(diff(unique(rez.xc))); -yunq = unique(rez.yc); -mxc = zeros(numel(yunq), 1); -for j = 1:numel(yunq) - xc = rez.xc(rez.yc==yunq(j)); - if numel(xc)>1 - mxc(j) = median(diff(sort(xc))); + NrankPC = rez.ops.nPCs; + [wTEMP, wPCA] = extractTemplatesfromSnippets(rez, NrankPC); + rez.wTEMP = gather(wTEMP); + rez.wPCA = gather(wPCA); + + ops = rez.ops; + + % The min and max of the y and x ranges of the channels + ymin = min(rez.yc); + ymax = max(rez.yc); + xmin = min(rez.xc); + xmax = max(rez.xc); + + dmin = median(diff(unique(rez.yc))); + fprintf('vertical pitch size is %d \n', dmin) + rez.ops.dmin = dmin; + rez.ops.yup = ymin:dmin / 2:ymax; % centers of the upsampled y positions + + % dminx = median(diff(unique(rez.xc))); + yunq = unique(rez.yc); + mxc = zeros(numel(yunq), 1); + for j = 1:numel(yunq) + xc = rez.xc(rez.yc == yunq(j)); + if numel(xc) > 1 + mxc(j) = median(diff(sort(xc))); + end end -end -dminx = max(5, median(mxc)); -fprintf('horizontal pitch size is %d \n', dminx) - -rez.ops.dminx = dminx; -nx = round((xmax-xmin) / (dminx/2)) + 1; -rez.ops.xup = linspace(xmin, xmax, nx); % centers of the upsampled x positions -disp(rez.ops.xup) - - -if getOr(rez.ops, 'nblocks', 1)==0 - rez.iorig = 1:rez.temp.Nbatch; - return; -end - - - -% binning width across Y (um) -dd = 5; -% min and max for the range of depths -dmin = ymin - 1; -dmax = 1 + ceil((ymax-dmin)/dd); -disp(dmax) - - -spkTh = 10; % same as the usual "template amplitude", but for the generic templates - -% Extract all the spikes across the recording that are captured by the -% generic templates. Very few real spikes are missed in this way. -[st3, rez] = standalone_detector(rez, spkTh); -%% - -% detected depths -% dep = st3(:,2); -% dep = dep - dmin; - -Nbatches = rez.temp.Nbatch; -% which batch each spike is coming from -batch_id = st3(:,5); %ceil(st3(:,1)/dt); - -% preallocate matrix of counts with 20 bins, spaced logarithmically -F = zeros(dmax, 20, Nbatches); -for t = 1:Nbatches - % find spikes in this batch - ix = find(batch_id==t); - - % subtract offset - dep = st3(ix,2) - dmin; - - % amplitude bin relative to the minimum possible value - amp = log10(min(99, st3(ix,3))) - log10(spkTh); - - % normalization by maximum possible value - amp = amp / (log10(100) - log10(spkTh)); - - % multiply by 20 to distribute a [0,1] variable into 20 bins - % sparse is very useful here to do this binning quickly - M = sparse(ceil(dep/dd), ceil(1e-5 + amp * 20), ones(numel(ix), 1), dmax, 20); - - % the counts themselves are taken on a logarithmic scale (some neurons - % fire too much!) - F(:, :, t) = log2(1+M); -end - -%% -% determine registration offsets -ysamp = dmin + dd * [1:dmax] - dd/2; -[imin,yblk, F0, F0m] = align_block2(F, ysamp, ops.nblocks); - -if isfield(rez, 'F0') - d0 = align_pairs(rez.F0, F0); - % concatenate the shifts - imin = imin - d0; -end - -%% -if getOr(ops, 'fig', 1) - figure; - set(gcf, 'Color', 'w') - - % plot the shift trace in um - plot(imin * dd) - box off - xlabel('batch number') - ylabel('drift (um)') - title('Estimated drift traces') - drawnow - - figure; - set(gcf, 'Color', 'w') - % raster plot of all spikes at their original depths - st_shift = st3(:,2); %+ imin(batch_id)' * dd; - for j = spkTh:100 - % for each amplitude bin, plot all the spikes of that size in the - % same shade of gray - ix = st3(:, 3)==j; % the amplitudes are rounded to integers - plot(st3(ix, 1)/ops.fs, st_shift(ix), '.', 'color', [1 1 1] * max(0, 1-j/40)) % the marker color here has been carefully tuned - hold on + dminx = max(5, median(mxc)); + fprintf('horizontal pitch size is %d \n', dminx) + + rez.ops.dminx = dminx; + nx = round((xmax - xmin) / (dminx / 2)) + 1; + rez.ops.xup = linspace(xmin, xmax, nx); % centers of the upsampled x positions + disp(rez.ops.xup) + + if getOr(rez.ops, 'nblocks', 1) == 0 + rez.iorig = 1:rez.temp.Nbatch; + return; end - axis tight - box off - - xlabel('time (sec)') - ylabel('spike position (um)') - title('Drift map') - -end -%% -% convert to um -dshift = imin * dd; - -% this is not really used any more, should get taken out eventually -[~, rez.iorig] = sort(mean(dshift, 2)); - -if do_correction - % sigma for the Gaussian process smoothing - sig = rez.ops.sig; - % register the data batch by batch - dprev = gpuArray.zeros(ops.ntbuff,ops.Nchan, 'single'); - for ibatch = 1:Nbatches - dprev = shift_batch_on_disk2(rez, ibatch, dshift(ibatch, :), yblk, sig, dprev); + + % binning width across Y (um) + dd = 5; + % min and max for the range of depths + dmin = ymin - 1; + dmax = 1 + ceil((ymax - dmin) / dd); + disp(dmax) + + spkTh = 10; % same as the usual "template amplitude", but for the generic templates + + % Extract all the spikes across the recording that are captured by the + % generic templates. Very few real spikes are missed in this way. + [st3, rez] = standalone_detector(rez, spkTh); + %% + + % detected depths + % dep = st3(:,2); + % dep = dep - dmin; + + Nbatches = rez.temp.Nbatch; + % which batch each spike is coming from + batch_id = st3(:, 5); %ceil(st3(:,1)/dt); + + % preallocate matrix of counts with 20 bins, spaced logarithmically + F = zeros(dmax, 20, Nbatches); + for t = 1:Nbatches + % find spikes in this batch + ix = find(batch_id == t); + + % subtract offset + dep = st3(ix, 2) - dmin; + + % amplitude bin relative to the minimum possible value + amp = log10(min(99, st3(ix, 3))) - log10(spkTh); + + % normalization by maximum possible value + amp = amp / (log10(100) - log10(spkTh)); + + % multiply by 20 to distribute a [0,1] variable into 20 bins + % sparse is very useful here to do this binning quickly + M = sparse(ceil(dep / dd), ceil(1e-5 + amp * 20), ones(numel(ix), 1), dmax, 20); + + % the counts themselves are taken on a logarithmic scale (some neurons + % fire too much!) + F(:, :, t) = log2(1 + M); end - fprintf('time %2.2f, Shifted up/down %d batches. \n', toc, Nbatches) -else - fprintf('time %2.2f, Skipped shifting %d batches. \n', toc, Nbatches) -end -% keep track of dshift -rez.dshift = dshift; -% keep track of original spikes -rez.st0 = st3; -rez.F = F; -rez.F0 = F0; -rez.F0m = F0m; + %% + % determine registration offsets + ysamp = dmin + dd * [1:dmax] - dd / 2; + [imin, yblk, F0, F0m] = align_block2(F, ysamp, ops.nblocks); -% next, we can just run a normal spike sorter, like Kilosort1, and forget about the transformation that has happened in here + if isfield(rez, 'F0') + d0 = align_pairs(rez.F0, F0); + % concatenate the shifts + imin = imin - d0; + end + + %% + if getOr(ops, 'fig', 1) + figure; + set(gcf, 'Color', 'w') + + % plot the shift trace in um + plot(imin * dd) + box off + xlabel('batch number') + ylabel('drift (um)') + title('Estimated drift traces') + drawnow + + figure; + set(gcf, 'Color', 'w') + % raster plot of all spikes at their original depths + st_shift = st3(:, 2); %+ imin(batch_id)' * dd; + for j = spkTh:100 + % for each amplitude bin, plot all the spikes of that size in the + % same shade of gray + ix = st3(:, 3) == j; % the amplitudes are rounded to integers + plot(st3(ix, 1) / ops.fs, st_shift(ix), '.', 'color', [1 1 1] * max(0, 1 - j / 40)) % the marker color here has been carefully tuned + hold on + end + axis tight + box off + + xlabel('time (sec)') + ylabel('spike position (um)') + title('Drift map') -%% + end + %% + % convert to um + dshift = imin * dd; + + % this is not really used any more, should get taken out eventually + [~, rez.iorig] = sort(mean(dshift, 2)); + + if do_correction + % sigma for the Gaussian process smoothing + sig = rez.ops.sig; + % register the data batch by batch + dprev = gpuArray.zeros(ops.ntbuff, ops.Nchan, 'single'); + for ibatch = 1:Nbatches + dprev = shift_batch_on_disk2(rez, ibatch, dshift(ibatch, :), yblk, sig, dprev); + end + fprintf('time %2.2f, Shifted up/down %d batches. \n', toc, Nbatches) + else + fprintf('time %2.2f, Skipped shifting %d batches. \n', toc, Nbatches) + end + % keep track of dshift + rez.dshift = dshift; + % keep track of original spikes + rez.st0 = st3; + rez.F = F; + rez.F0 = F0; + rez.F0m = F0m; + % next, we can just run a normal spike sorter, like Kilosort1, and forget about the transformation that has happened in here + %% diff --git a/sorting/Kilosort-3.0/preProcess/get_channel_delays.m b/sorting/Kilosort-3.0/preProcess/get_channel_delays.m new file mode 100644 index 00000000..a561bdd8 --- /dev/null +++ b/sorting/Kilosort-3.0/preProcess/get_channel_delays.m @@ -0,0 +1,87 @@ +function [channelDelays] = get_channel_delays(rez) +% based on a subset of the data, compute the channel delays to maximize cross correlations +% this requires temporal filtering first (gpufilter) + +ops = rez.ops; +Nbatch = ops.Nbatch; +twind = ops.twind; +NchanTOT = ops.NchanTOT; +NT = ops.NT; +NTbuff = ops.NTbuff; +Nchan = rez.ops.Nchan; + +fprintf('Getting channel delays... \n'); +fid = fopen(ops.fbinary, 'r'); +maxlag = ops.fs/500; % 2 ms max time shift + +% we'll estimate the cross correlation across channels from data batches +ibatch = 1; +chan_CC = zeros(2*maxlag+1, NchanTOT^2, 'single', 'gpuArray'); +while ibatch<=Nbatch + offset = max(0, twind + 2*NchanTOT*((NT - ops.ntbuff) * (ibatch-1) - 2*ops.ntbuff)); + fseek(fid, offset, 'bof'); + buff = fread(fid, [NchanTOT NTbuff], '*int16'); + + if isempty(buff) + break; + end + nsampcurr = size(buff,2); + if nsampcurr sum(last_maxes) % if these delays produce higher correlation + best_peak_locs = this_chan_corr_peak_locs; + last_maxes = these_maxes; + end +end +% remove nan values for display +chan_corr_peak_maxes(isnan(chan_corr_peak_maxes)) = 0; +% use the earliest channel as a reference to compute delays +channelDelays = gather(best_peak_locs - maxlag - 1); % -1 because of zero-lag +disp("Channel delays with best correlation computed for all channel combinations: ") +disp(reshape(chan_corr_peak_locs, Nchan, Nchan)-maxlag-1) +disp("Using channel delays with best reference channel: ") +disp(channelDelays) + +disp("Correlation values trying each reference channel: ") +disp(reshape(chan_corr_peak_maxes, Nchan, Nchan)) +disp(" + ___________________________________________________________") +disp(sum(reshape(chan_corr_peak_maxes, Nchan, Nchan))) + +disp("Using best reference channel, with maximal correlation: ") +disp(sum(last_maxes)) + +end + diff --git a/sorting/Kilosort-3.0/preProcess/get_whitening_matrix.m b/sorting/Kilosort-3.0/preProcess/get_whitening_matrix.m index 40aa3b0d..65bfd805 100644 --- a/sorting/Kilosort-3.0/preProcess/get_whitening_matrix.m +++ b/sorting/Kilosort-3.0/preProcess/get_whitening_matrix.m @@ -39,7 +39,7 @@ buff(:, nsampcurr+1:NTbuff) = repmat(buff(:,nsampcurr), 1, NTbuff-nsampcurr); end - datr = gpufilter(buff, ops, rez.ops.chanMap); % apply filters and median subtraction + datr = gpufilter(buff, ops, chanMap); % apply filters and median subtraction CC = CC + (datr' * datr)/NT; % sample covariance diff --git a/sorting/Kilosort-3.0/preProcess/get_whitening_matrix_faster.m b/sorting/Kilosort-3.0/preProcess/get_whitening_matrix_faster.m index 49db072b..076a65d0 100644 --- a/sorting/Kilosort-3.0/preProcess/get_whitening_matrix_faster.m +++ b/sorting/Kilosort-3.0/preProcess/get_whitening_matrix_faster.m @@ -8,10 +8,15 @@ NchanTOT = ops.NchanTOT; NT = ops.NT; NTbuff = ops.NTbuff; -chanMap = ops.chanMap; +numDummy = rez.ops.numDummy; +numChansToUse = NchanTOT-numDummy; +origChanMap = ops.chanMap; +chanMap = origChanMap(1:numChansToUse); Nchan = rez.ops.Nchan; -xc = rez.xc; -yc = rez.yc; +origxc = rez.xc; +origyc = rez.yc; +xc = origxc(1:numChansToUse); +yc = origyc(1:numChansToUse); % load data into patches, filter, compute covariance fprintf('Getting channel whitening matrix... \n'); @@ -44,8 +49,8 @@ % Memory map the raw data file to m.Data.x of size [nChan, nSamples) mmfRaw = dir(ops.fbinary); - nsamp = mmfRaw.bytes/2/ops.NchanTOT; - mmfRaw = memmapfile(ops.fbinary, 'Format',{'int16', [ops.NchanTOT, nsamp], 'x'}); % ignore tstart here b/c we'll account for it on each mapped read (using .twind) + nsamp = mmfRaw.bytes/2/NchanTOT; + mmfRaw = memmapfile(ops.fbinary, 'Format',{'int16', [NchanTOT, nsamp], 'x'}); % ignore tstart here b/c we'll account for it on each mapped read (using .twind) %parmmf = parallel.pool.Constant(mmf); tic parfor i = 1:length(ibatch) @@ -148,7 +153,7 @@ if ops.whiteningRange 1); + dups_to_remove = []; + for iDup = 1:length(times_which_have_duplicates) + % find the indices of the duplicates + dup_inds = find(full_row == times_which_have_duplicates(iDup)); + % find the index of the max amplitude + [~, max_ind] = max(full_mu(dup_inds)); + % remove all but the max amplitude + dup_inds(max_ind) = []; + dups_to_remove = [dups_to_remove; dup_inds]; + end + full_row(dups_to_remove) = []; + full_col(dups_to_remove) = []; + full_mu(dups_to_remove) = []; + if length(row_uniq) ~= length(full_row) + error('Duplicate removal failed, numbers of unique spikes and spikes left after removal do not match.'); + end + % disp(['Removed ' num2str(length(dups_to_remove)) ' duplicate spikes, out of ' num2str(length(full_row)) ' total spikes.']); + + % now plot the remaining spikes at each amplitude with red squares + % if ops.fig && (mod(ibatch, 20) == 0 || ibatch == ops.Nbatch) + % figure(999); hold on; + % plot(full_row + time_offset, full_col * channel_offset - full_mu, 'rs'); + % end + row = full_row; + col = full_col; + mu = full_mu; +end diff --git a/sorting/Kilosort-3.0/preProcess/isolated_peaks_new.m b/sorting/Kilosort-3.0/preProcess/isolated_peaks_new.m index 3411fbf1..3a3180a8 100644 --- a/sorting/Kilosort-3.0/preProcess/isolated_peaks_new.m +++ b/sorting/Kilosort-3.0/preProcess/isolated_peaks_new.m @@ -24,4 +24,4 @@ [row, col, mu] = find(peaks); % find the non-zero peaks, and take their amplitudes -mu = - mu; % invert the sign of the amplitudes +mu = abs(mu); % take the absolute value of the amplitudes \ No newline at end of file diff --git a/sorting/Kilosort-3.0/preProcess/loadChanMap.m b/sorting/Kilosort-3.0/preProcess/loadChanMap.m index ff8d80d7..c41b43bc 100644 --- a/sorting/Kilosort-3.0/preProcess/loadChanMap.m +++ b/sorting/Kilosort-3.0/preProcess/loadChanMap.m @@ -1,6 +1,6 @@ -function [chanMap, xcoords, ycoords, kcoords, NchanTOT] = loadChanMap(cmIn) +function [chanMap, xcoords, ycoords, kcoords, NchanTOT, numDummy] = loadChanMap(cmIn) % function [chanMap, xcoords, ycoords, kcoords] = loadChanMap(cmIn) % % Load and sanitize a channel map provided to Kilosort @@ -76,3 +76,9 @@ NchanTOT = numel(chanMap); end +if ~isfield(cmIn, 'numDummy') + numDummy = 0; +else + numDummy = cmIn.numDummy; +end +end diff --git a/sorting/Kilosort-3.0/preProcess/preprocessDataSub.m b/sorting/Kilosort-3.0/preProcess/preprocessDataSub.m index 8daae783..f66ac4b9 100644 --- a/sorting/Kilosort-3.0/preProcess/preprocessDataSub.m +++ b/sorting/Kilosort-3.0/preProcess/preprocessDataSub.m @@ -1,152 +1,336 @@ function rez = preprocessDataSub(ops) -% this function takes an ops struct, which contains all the Kilosort2 settings and file paths -% and creates a new binary file of preprocessed data, logging new variables into rez. -% The following steps are applied: -% 1) conversion to float32; -% 2) common median subtraction; -% 3) bandpass filtering; -% 4) channel whitening; -% 5) scaling to int16 values - -% track git repo(s) with new utility (see /utils/gitStatus.m) -if getOr(ops, 'useGit', 1) - ops = gitStatus(ops); -end - -tic; -ops.nt0 = getOr(ops, {'nt0'}, 61); % number of time samples for the templates (has to be <=81 due to GPU shared memory) -ops.nt0min = getOr(ops, 'nt0min', ceil(20 * ops.nt0/61)); % time sample where the negative peak should be aligned - -NT = ops.NT ; % number of timepoints per batch -NchanTOT = ops.NchanTOT; % total number of channels in the raw binary file, including dead, auxiliary etc - -bytes = get_file_size(ops.fbinary); % size in bytes of raw binary -nTimepoints = floor(bytes/NchanTOT/2); % number of total timepoints -ops.tstart = ceil(ops.trange(1) * ops.fs); % starting timepoint for processing data segment -ops.tend = min(nTimepoints, ceil(ops.trange(2) * ops.fs)); % ending timepoint -ops.sampsToRead = ops.tend-ops.tstart; % total number of samples to read -ops.twind = ops.tstart * NchanTOT*2; % skip this many bytes at the start - -Nbatch = ceil(ops.sampsToRead /NT); % number of data batches -ops.Nbatch = Nbatch; - -[chanMap, xc, yc, kcoords, NchanTOTdefault] = loadChanMap(ops.chanMap); % function to load channel map file -ops.NchanTOT = getOr(ops, 'NchanTOT', NchanTOTdefault); % if NchanTOT was left empty, then overwrite with the default - -% determine bad channels -fprintf('Time %3.0fs. Determining good channels.. \n', toc); -igood = true(size(chanMap)); - -% if isfield(ops, 'brokenChan') -% if isfile(ops.brokenChan) -% load(ops.brokenChan) -% igood = true(NchanTOT,1); -% igood(brokenChan) = false; -% end -% end - - -chanMap = chanMap(igood); %it's enough to remove bad channels from the channel map, which treats them as if they are dead -xc = xc(igood); % removes coordinates of bad channels -yc = yc(igood); -kcoords = kcoords(igood); -ops.igood = igood; - -ops.Nchan = numel(chanMap); % total number of good channels that we will spike sort -disp(ops.Nchan) -ops.Nfilt = getOr(ops, 'nfilt_factor', 4) * ops.Nchan; % upper bound on the number of templates we can have - -rez.ops = ops; % memorize ops - -rez.xc = xc; % for historical reasons, make all these copies of the channel coordinates -rez.yc = yc; -rez.xcoords = xc; -rez.ycoords = yc; -% rez.connected = connected; -rez.ops.chanMap = chanMap; -rez.ops.kcoords = kcoords; - - -NTbuff = NT + 3*ops.ntbuff; % we need buffers on both sides for filtering - -rez.ops.Nbatch = Nbatch; -rez.ops.NTbuff = NTbuff; -rez.ops.chanMap = chanMap; - - -fprintf('Time %3.0fs. Computing whitening matrix.. \n', toc); - -% this requires removing bad channels first -Wrot = get_whitening_matrix(rez); % outputs a rotation matrix (Nchan by Nchan) which whitens the zero-timelag covariance of the data -% Wrot = gpuArray.eye(size(Wrot,1), 'single'); -% Wrot = diag(Wrot); - -condition_number = cond(gather(Wrot)); -disp(['Computed the whitening matrix cond = ' num2str(condition_number)]) -if condition_number > 50 - disp('Warning: Your whitening matrix value is above 50.') - disp('High conditioning of the whitening matrix can result in noisy and poor results.') - disp('CHECK YO-SELF BEFORE YOU WRECK YO-SELF') -end - -fprintf('Time %3.0fs. Loading raw data and applying filters... \n', toc); - -fid = fopen(ops.fbinary, 'r'); % open for reading raw data -if fid<3 - error('Could not open %s for reading.',ops.fbinary); -end -fidW = fopen(ops.fproc, 'w+'); % open for writing processed data -if fidW<3 - error('Could not open %s for writing.',ops.fproc); -end - -% weights to combine batches at the edge -w_edge = linspace(0, 1, ops.ntbuff)'; -ntb = ops.ntbuff; -datr_prev = gpuArray.zeros(ntb, ops.Nchan, 'single'); - -for ibatch = 1:Nbatch - % we'll create a binary file of batches of NT samples, which overlap consecutively on ops.ntbuff samples - % in addition to that, we'll read another ops.ntbuff samples from before and after, to have as buffers for filtering - offset = max(0, ops.twind + 2*NchanTOT*(NT * (ibatch-1) - ntb)); % number of samples to start reading at. - - fseek(fid, offset, 'bof'); % fseek to batch start in raw file - - buff = fread(fid, [NchanTOT NTbuff], '*int16'); % read and reshape. Assumes int16 data (which should perhaps change to an option) - if isempty(buff) - break; % this shouldn't really happen, unless we counted data batches wrong + % this function takes an ops struct, which contains all the Kilosort2 settings and file paths + % and creates a new binary file of preprocessed data, logging new variables into rez. + % The following steps are applied: + % 1) conversion to float32; + % 2) common median subtraction; + % 3) bandpass filtering; + % 4) channel whitening; + % 5) scaling to int16 values + % + % [ks25] updates: + % - adds git tracking with complete status, revisions, & changes to kilosort repo + % - uses memory mapped file reads by default (much faster) + % - updated to "_faster" version of get_whitening_matrix (memmapped reads AND parallelized loading...mmmuch faster) + % - creates handle to memmapped preprocessed data file in: rez.ops.fprocmmf + % - disabled [linear] weighted smoothing of batches + % - seems unnecessary & potentially problematic + % - esp in cases where batch buffer [.ntbuff] is significantly longer than waveform length [.nt0] + % - removed creation of rez.temp (unclear why this existed in first place) + % - required replacing instances of [rez.temp] to normal [rez] struct throughout codebase + % + % --- + % 202x-xx-xx TBC Evolved from original Kilosort + % 2021-04-28 TBC Cleaned & commented + % 2021-05-05 TBC Updated to generate preprocessed data file ranging from t0:tend+ntbuff + % - accomodates non-zero tstart w/o sacrificing temporal correl w/ raw data file + % 2021-06-24 TBC if [ops.CAR] value >1, will use as sliding window for median outlier (spike) + % when computing median (prevents high responsivity from skewing adjacent channels) + % 2023-09-07 SMO Added option to remove channel delays, saving channel delays, and handle dummy + % channels in channel delay and whitening matrix calculations + + %% Parse ops input & setup defaults + % record date & time of Kilosort execution + % - used by addFigInfo.m + ops.datenumSorted = now; + + % tic for this function + t00 = tic; + + % track git repo(s) with new utility (see /utils/gitStatus.m) + if getOr(ops, 'useGit', 1) + ops = gitStatus(ops); end - nsampcurr = size(buff,2); % how many time samples the current batch has - if nsampcurr 0 + % append batches reaching back to start of file such that batch indices within tstart:tend are maintained + % - as consequence, first batch may start with a negative value & contain 0 + WrotDummy = eye(NchanTOT, NchanTOT); + WrotDummy(1:NchanTOT - numDummy, 1:NchanTOT - numDummy) = Wrot; + Wrot = WrotDummy; end -end -fclose(fidW); % close the files -fclose(fid); -rez.Wrot = gather(Wrot); % gather the whitening matrix as a CPU variable + disp('Whitening matrix computed...') + disp(Wrot) + + cmdLog('Loading raw data and applying filters...', toc(t00)); + if true + % open for reading raw data + fid = fopen(ops.fbinary, 'r'); + if fid < 3 + error('Could not open %s for reading.', ops.fbinary); + end + % open for writing processed data + fidW = fopen(ops.fproc, 'wb+'); + if fidW < 3 + error('Could not open %s for writing.', ops.fproc); + end + + ntb = ops.ntbuff; + + % exponential smoothing + dnom = 3; % rate of padding exponential decay + + allBatches = 1:ops.NprocBatch; + + % Progress bar in command window + % pb = progBar(allBatches, 20); + + for ibatch = allBatches + % we'll create a preprocessed binary file by reading from the raw file (NOT memory mapped) batches of NT samples, + % with each batch padded by ops.ntbuff samples from before & after, to have as buffers for filtering + bstart = procBatchStarts(ibatch) - ntb; % start reading from ntbuffer samples before first batch start + bsamp = bstart + (0:NTbuff - 1); % determine sample indices (0-based) + bsampTrue = bsamp >= 0 & bsamp <= nTimepoints; % determine number & validity of samples being read (if all are valid, this will sum to NTbuff) + bsampTrueNT = bsampTrue(ntb + (1:NT)); % validity of batch samples (excluding buffers) -fprintf('Time %3.0fs. Finished preprocessing %d batches. \n', toc, Nbatch); + offset = max(0, bstart * NchanTOT * 2); % number of BYTES to offset start of standard read operation + fseek(fid, offset, 'bof'); % fseek to batch start in raw file + dat = fread(fid, [NchanTOT sum(bsampTrue)], '*int16'); % read and reshape. Assumes int16 data (which should perhaps change to an option) + + if isempty(dat) + break; % this shouldn't really happen, unless we counted data batches wrong + else + nsampcurr = size(dat, 2); % how many time samples the current batch has + end + + % % --------------------------------------------------------------------------------------------------------------- + % % --- step inside gpufilter.m operations ---% + % - unpacked this utility function b/c unhappy with buffer padding before demeaning + + % subsample only good channels & transpose for filtering + datr = double(dat(chanMap, :))'; % dat dims now: [samples, channel] + + % --- Demean before padding --- + % subtract within-channel means from each channel + datr = datr - mean(datr, 1); % nans not possible, since just converted from int16 raw dat values + + % CAR, common average referencing by median + % ----------------------------- + % Ugly SUM(...,'omitnan') workaround for demeaning & subtracting in presence of nan without injecting spurrious zeros + % ----------------------------- + % - "mean(...'omitnan')" is marginally faster than "nanmean(...)" + % - BUT: "nanmedian(...)" is significantly faster than "median(...'omitnan')" + + if doCAR + if doCAR > 1 + % Demean across cahnnels, exclude outlier values (spikes) from mean calc + % - useful for moderate channel counts where reasonable for significant spiking + % activity to influence median value across channels + datr = sum(cat(3, datr, repmat(-nanmedian(filloutliers(datr, nan), 2), [1, NchanMapTOT])), 3, 'omitnan'); + else + datr = sum(cat(3, datr, repmat(-nanmedian(datr, 2), [1, NchanMapTOT])), 3, 'omitnan'); % subtract median across channels + end + end + + if any(isnan(datr)) + warning('NANs detected in batch raw data...inspect data validity and/or consider different CAR filtering option'); + keyboard + end + + % Now can pad first & last batches with zeros + % if nsampcurr0 and first raw batch sample is negative) + datr(~bsampTrueNT, :) = []; + + % datcpu = gather(int16(datr')); % convert to int16, and gather on the CPU side + % doesn't actually get sent to gpu right now (TBD: test if faster) + count = fwrite(fidW, int16(datr'), 'int16'); % write this batch to binary file + + %hit = pb.check(ibatch); % update progress bar in command window + % updateProgressMessage(ibatch, ops.NprocBatch, t00,100,20); + + if count ~= numel(datr) + error('Error writing batch %g to %s. Check available disk space.', ibatch, ops.fproc); + end + end + disp('Done.') + + % close the files + fclose(fidW); + fclose(fid); + end + + if getOr(ops, 'useMemMapping', 1) + % memory map [ops.fproc] file + % - don't use Offset pv-pair, precludes using samples before tstart as buffer + filename = ops.fproc; + datatype = 'int16'; + chInFile = ops.Nchan; + bytes = get_file_size(filename); % size in bytes of [new] preprocessed data file + nSamp = floor(bytes / chInFile / 2); + % memory map file + rez.ops.fprocmmf = memmapfile(filename, 'Format', {datatype, [chInFile nSamp], 'chXsamp'}); + fprintf('\tMemMapped preprocessed dat file: %s\n\tas: rez.ops.fprocmmf.Data.chXsamp\n', ops.fproc); + end + script_dir = pwd; % get directory where repo exists + load(fullfile(script_dir, '/tmp/config.mat')); + if remove_channel_delays + channelDelays = get_channel_delays(rez); + % if present, make sure dummy channels are 0 delay + if numDummy > 0 + channelDelays(NchanTOT - numDummy + 1:end) = 0; + end + rez.ops.channelDelays = channelDelays; % save channel delays to rez + % figure(222); hold on; + % remove channel delays from proc.dat by seeking through the batches + % with ibatch*NT+max(channelDelays) and shifting each delayed channel backwards + % by the appropriate amount found in channelDelays + % this will effectively move some throwaway data to the end of all batches + % but now the spikes will be aligned in time across channels + fidOff = fopen(ops.fproc, 'r+'); + if fidOff < 3 + error('Could not open %s for reading.', ops.fbinary); + end + data = fread(fidOff, [NchanTOT inf], '*int16'); % read and reshape. Assumes int16 data + % circularly shift each channel by the appropriate amount + % plot(data') + for i = 1:length(channelDelays) + data(i, :) = circshift(data(i, :), channelDelays(i)); + end + % plot(data' + max(abs(data(:)))) % plot shifted data + fseek(fidOff, 0, 'bof'); % fseek to start in raw file, to overwrite + fwrite(fidOff, data, 'int16'); + fclose(fidOff); + disp('Removed channel delays from proc.dat, which were:') + disp(channelDelays) + % save(fullfile(myo_sorted_dir, 'channelDelays.mat'), 'channelDelays') + disp('Delay information will be saved in ops.mat') + else + channelDelays = zeros(ops.NchanTOT, 1); + rez.ops.channelDelays = channelDelays; % save channel delays to rez + end + rez.Wrot = gather(Wrot); % gather the whitening matrix as a CPU variable -rez.temp.Nbatch = Nbatch; + cmdLog(sprintf('Finished preprocessing %d batches.', Nbatch), toc(t00)); + rez.temp.Nbatch = Nbatch; +end %main function diff --git a/sorting/Kilosort-3.0/utils/addFigInfo.m b/sorting/Kilosort-3.0/utils/addFigInfo.m index 95b0853f..4154bce8 100644 --- a/sorting/Kilosort-3.0/utils/addFigInfo.m +++ b/sorting/Kilosort-3.0/utils/addFigInfo.m @@ -7,9 +7,15 @@ function addFigInfo(ops, H) try figure(H); fsz = 10; % info font size - dateVer = sprintf('Sorted on: %s',datestr(now)); + % date of Kilosort processing + if isfield(ops, 'datenumSorted') && ~isempty(ops.datenumSorted) + dateVer = sprintf('Sorted on: %s',datestr(ops.datenumSorted)); + else + dateVer = sprintf('Sorted on: %s',datestr(now)); + end + try - % kilosort git source + % append kilosort git repo information gitstat = strsplit(ops.git.kilosort.status, '\n'); dateVer = [dateVer, sprintf(' Kilosort git: %s, commit %s', gitstat{1}, ops.git.kilosort.revision(1:7))]; end diff --git a/sorting/Kilosort-3.0/utils/calc_SimScore.m b/sorting/Kilosort-3.0/utils/calc_SimScore.m new file mode 100644 index 00000000..7b4147c7 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/calc_SimScore.m @@ -0,0 +1,26 @@ +function simScore = calc_SimScore(rez) + + % recalculate simScore including temporal lag + % - cranky implementation due to differences in rez.W and wPCA dimensions + % - weird/deep inconsistencies & hardcoding of how many PCs are preserved (3, 6, 2*3?', 'PaperSize', sort(type,'descend')) +end + + +set(h,'PaperUnits', 'inches') +sz = get(h,'PaperSize'); +set(h,'PaperPositionMode','manual'); +n = min([.1*sz, 0.5]); +set(h,'PaperPosition',[n/2, n/2, sz(2)-n, sz(1)-n]); + +% just keep doing this till it sticks +set(h,'PaperOrientation', ori); +orient(h, ori); %...and again (thanks Matlab) + +% silence unneeded outputs +if nargout>0 + hout = h; +end + +end %main function diff --git a/sorting/Kilosort-3.0/utils/get_batch.m b/sorting/Kilosort-3.0/utils/get_batch.m index 2b84a56d..cef8c9c3 100644 --- a/sorting/Kilosort-3.0/utils/get_batch.m +++ b/sorting/Kilosort-3.0/utils/get_batch.m @@ -1,19 +1,164 @@ -function dataRAW = get_batch(ops, ibatch) +function dat = get_batch(ops, ibatch, fid, varargin) +% function dat = get_batch(ops, ibatch, fid, varargin) +% +% Retrieve buffered batch of data from preprocessed data file [ops.fproc], +% return as gpuArray. +% +% Unified version of original Kilosort get_batch.m for either standard file reads +% or memory mapped files. Defaults to ops.useMemMapping = 1 +% +% +% INPUTS: +% [ops] standard kilosort ops struct +% [ibatch] 1-based index of batch to load +% --optional-- +% [fid] handle to memmappedfile opject or to standard fopen file +% [flags] string flags for modifying batch features: 'nobuffer', 'noscale' (see below) +% +% OUTPUTS: +% [dat] gpuArray of batch data, sized [nsamples-by-nchannels] +% - loaded directly to a gpuArray, after converting to single & scaling by ops.scaleproc +% - NOTE: dat orientation is transposed relative to saved file (for efficient usage w/in GPU code) +% +% [fid] options: +% - if is handle to a Memory Mapped file, load batch via memmapping +% - if is index to standard open file, load batch via standard reads +% - if [fid] absent or empty, follow ops.useMemMapping flag, open/initialize as needed +% (def: useMemMapping = 1) +% - if useMemMapping, will check for memmappedfile object in standard location: ops.fprocmmf +% - if ~useMemMapping, will fopen & fclose access to: ops.fproc +% +% varargin accepts modifier string flag(s): +% 'nobuffer' don't add buffer to either side of batch indices; size(dat)==[nsamp,nchan] +% 'noscale' don't scale data by ops.scaleproc +% +% If using memory mapped reads, must follow convention of preprocessDataSub.m +% - handle to memmapfile object must be in: ops.fprocmmf +% - .Data field must be: '.chXsamp' +% (i.e. rez.ops.fprocmmf = memmapfile(filename, 'Format',{datatype, [chInFile nSamp], 'chXsamp'}); +% - in theory, memory mapped reads *might* work without having set up any .fprocmmf handle in advance, +% but good chance it would cause a mess & be very slow in the least. +% ...best bet is to properly setup use of ops.useMemMapping = 1 from the get go +% +% --- +% 2021-04-13 TBC Updated with proper usage of .ntbuff +% Sends data directly to gpuArray +% memory mapped version +% 2021-04-20 TBC Unified version determines read method based on [fid] input type or ops.useMemMapping +% 2021-04-28 TBC Cleaned & commented. +% -Nbatch = ops.Nbatch; -NT = ops.NT; +%% parse inputs & determine read method from inputs -batchstart = 0:NT:NT*Nbatch; % batches start at these timepoints -offset = 2 * ops.Nchan*batchstart(ibatch); % binary file offset in bytes +useMemMapping = getOr(ops, 'useMemMapping', 1); +cleanupFid = 0; % flag to close fid, if opened w/in this function -fid = fopen(ops.fproc, 'r'); -fseek(fid, offset, 'bof'); -dat = fread(fid, [ops.Nchan NT+ops.ntbuff], '*int16'); -dat = dat'; -fclose(fid); +if nargin<3 || isempty(fid) + if useMemMapping + if isfield(ops, 'fprocmmf') && ~isempty(ops.fprocmmf) + % look for memmapped handle in ops struct + mmf = ops.fprocmmf; + else + % prob slow on-the-fly, should pass this as input (ops.fprocmmf) whenever feasible + filename = ops.fproc; + chInFile = ops.NchanTOT; + nSamp = ops.tend; + datatype = 'int16'; + mmf = memmapfile(filename, 'Format',{datatype, [chInFile nSamp], 'chXsamp'}); + end + else + % open preprocessed data file for standard reads + fid = fopen(ops.fproc, 'r'); + cleanupFid = 1; + end + +elseif isa(fid, 'memmapfile') + % input is memmapped file handle + mmf = fid; + useMemMapping = 1; + +elseif ~isempty(fid) + %`standard reads from [fid] input + useMemMapping = 0; +end -% move data to GPU and scale it -dataRAW = gpuArray(dat); -dataRAW = single(dataRAW); -dataRAW = dataRAW / ops.scaleproc; +%% Define batch start, size, & padding (if necessary) + +% check for [varargin] flags +useBuffer = ~any(contains(varargin,'nobuffer', 'IgnoreCase',1)); + +if ~any(contains(varargin,'noscale', 'IgnoreCase',1)) + scaleDat = ops.scaleproc; % scale integer data on load +else + scaleDat = 1; +end + +NT = ops.NT; % samples per batch +ntbuff = ops.ntbuff * useBuffer; % single-end buffer size (samples) +NTwin = [-ntbuff, NT+ntbuff-1]; % first & last sample of this buffered batch + % iNTbuff = -ntbuff:1:(NT+ntbuff-1); % == .NT samples padded w/ .ntbuff on either side + +tstart = ops.tstart; % accomodate non-zero first sample time +tend = ops.tend; + +% starting sample offset for this batch (in samples; 1-based index) +offset = 1 + tstart + NT*(ibatch-1); + +% sample window +sampWin = NTwin + offset; + +% prepad error check +if sampWin(1)<1 % any <=0 + prepad = -sampWin(1) + 1; + sampWin(1) = 1; +else + prepad = 0; +end + +% postpad error check +if sampWin(2)>tend % any <=0 + postpad = sampWin(2) - tend; + sampWin(2) = tend; +else + postpad = 0; +end + + +%% Read data directly to GPU +% - transpose for gpu (gpu orientation ~= dat orientation) +% - convert to singles +% - scale by [ops.scaleproc] + +if useMemMapping + % Read mmf data directly to GPU + dat = gpuArray( single(mmf.Data.chXsamp(:, sampWin(1):sampWin(2))') / scaleDat); + +else + % adapt for fseek bytes & fread inputs + % go to starting point (in bytes) + fseek(fid, 2*(ops.Nchan * (sampWin(1)-1)), 'bof'); + dat = gpuArray( single( fread(fid, [ops.Nchan, diff(sampWin)], '*int16')') / scaleDat); + + % close file if opened w/in this function + if cleanupFid + % clean up + fclose(fid); + end +end + + +%% Pad as necessary (***gpu orientation==[samples, channels]***) +dnom = 3; % rate of padding exponential decay +if prepad + % smooth padding + dat = [dat(1,:) .* exp(linspace(-prepad/dnom,0,prepad+1))'; dat(2:end, :)]; +end + +if postpad + % smooth padding + dat = [dat(1:end-1,:); dat(end,:) .* exp(linspace(0,-postpad/dnom,postpad+1))']; +end + + +end %main function diff --git a/sorting/Kilosort-3.0/utils/gitStatus.m b/sorting/Kilosort-3.0/utils/gitStatus.m index e20944e6..5bd8441f 100644 --- a/sorting/Kilosort-3.0/utils/gitStatus.m +++ b/sorting/Kilosort-3.0/utils/gitStatus.m @@ -20,7 +20,7 @@ % [REVIVAL EXAMPLE] % To revive a git repo to the source/state of Kilosort [rez] struct: % -% % load a saved [rez] struct data file [.PDS] +% % load a saved [rez] struct data file % thisFile = fullfile( myDataPath, 'rez.mat'); % load(thisFile); % ops = rez.ops; diff --git a/sorting/Kilosort-3.0/utils/plotTemplateDynamics.m b/sorting/Kilosort-3.0/utils/plotTemplateDynamics.m new file mode 100644 index 00000000..2f82b629 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/plotTemplateDynamics.m @@ -0,0 +1,103 @@ +function plotTemplateDynamics(rez, theseUnits) +% function plotTemplateDynamics(rez, theseUnits) +% +% Scrappy function to visualize changes in template dynamics across spike extraction. +% - Produces image plots of temporal (W) & spatial (U) templates across each batch of extraction +% - Overlays number of spikes extracted in each batch +% +% INPUTS: +% rez = standard rez struct from Kilosort session (must be from [ks25] codebase; must include .WA & .UA) +% theseUnits = indices of which units to plot +% - if none provided, will randomly select 24 templates to plot +% - if theseUnits=='all', will plot all units (...not recommended for >=40 total units) +% --- +% EXAMPLE +% - run ks25 sort from GUI interface +% - [ks] (a handle to the kilosort object) should be created in base workspace +% - from command window: +% plotTemplateDynamics(ks.rez, ks.rez.troubleUnits) +% --- +% 2021-xx-xx TBC Wrote it. +% + +sz = size(rez.WA); +if isempty(rez.WA) || numel(sz)<3 + fprintf(2, '\tNo record of template dynamics in this rez struct\n') + return +end + +if nargin>1 + if strcmp(theseUnits,'all') + theseUnits = 1:sz(2); + end + nplots = length(theseUnits) +else + nplots = 24; +end +spx = ceil(sqrt(nplots)); +spy = ceil(sqrt(nplots)); + +if ~exist('theseUnits','var') || isempty(theseUnits) + theseUnits = sort(randperm(sz(2),nplots)); +end + +iPC = [1,2]; % which pc to plot + +for k = 1:length(iPC) + H = figure; + set(H, 'name',sprintf('PC%d',iPC(k))); + + for i = 1:nplots + u = theseUnits(i); + + subplot(spx, spy, i) + imagesc( sq(rez.WA(:, u, iPC(k), :)) ); + box off + title(sprintf('unit % 3d || % 3d',u, rez.nsp(u) )); + hold on + if isfield(rez,'invDetected') && any(rez.invDetected(u,:)) + plot(find(rez.invDetected(u,:)), 5, '.r','markersize',5); + end + if isfield(rez,'nspA') + yyaxis right + hl = plot(rez.nspA(u,:)','.'); + set(hl.MarkerHandle, 'Style','hbar', 'size',3); + ylabel('spikes added','fontsize',8) + end + + + end + + addFigInfo(rez.ops, H) +end + +iUA = [1,2,3]; % which UA weight to plot + +for k = 1:length(iUA) + H = figure; + set(H, 'name',sprintf('UA-%d',iUA(k))); + + for i = 1:nplots + u = theseUnits(i); + + subplot(spx, spy, i) + imagesc( sq(rez.UA(:, u, iUA(k), :)) ); + box off + title(sprintf('unit % 3d || % 3d',u, rez.nsp(u) )); + set(gca, 'clim',[-.2,1]) + hold on + if isfield(rez,'invDetected') && any(rez.invDetected(u,:)) + plot(find(rez.invDetected(u,:)), 5, '.r','markersize',5); + end + if isfield(rez,'nspA') + yyaxis right + hl = plot(rez.nspA(u,:)','.'); + set(hl.MarkerHandle, 'Style','hbar', 'size',3); + ylabel('spikes added','fontsize',8) + end + + + end + + addFigInfo(rez.ops, H) +end \ No newline at end of file diff --git a/sorting/Kilosort-3.0/utils/progBar.m b/sorting/Kilosort-3.0/utils/progBar.m new file mode 100644 index 00000000..dd1a8dd8 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/progBar.m @@ -0,0 +1,114 @@ +classdef progBar < dynamicprops + % Command line text progress bar for indexed loops + % EXAMPLE + % Setup: pb = progBar(allIdx, updatePct); + % % allIdx == set of indices used in loop + % % updatePct == percentiles to trigger text updates (def=[10:10:90]) + % + % Use: for i = allIdx + % % do stuff + % pb.check(i); + % end + % + % + % 2020-07-23 TBC wrote object oriented progress bar class (czuba@utexas.edu) + + properties (Access = public) + pct + vals + vals0 + idx + n + txt + d + lims + end + + methods + % constructor + function pb = progBar(vin, pct, varargin) + % Parse inputs & setup default parameters + pp = inputParser(); + pp.addParameter('pct',10:10:90); % update points + pp.parse(varargin{:}); + argin = pp.Results; + + % Apply to object + fn = fieldnames(argin); + for i = 1:length(fn) + % add property if non-standard + if ~isprop(pb, fn{i}) + pb.addprop(fn{i}); + end + pb.(fn{i}) = argin.(fn{i}); + end + + if nargin>1 && ~isempty(pct) + if isscalar(pct) + pct = linspace(0,100,pct); + pct = pct(2:end-1); + end + pb.pct = pct; + end + + % select update values from [vin] + pb.lims = vin([1,end]); + pb.vals = prctile(vin, pb.pct); + % find nearest vals present + [~,pb.idx] = min(abs(pb.vals - vin')); + pb.vals = vin(pb.idx); + pb.vals0 = pb.vals; % backup init state + % info + pb.n = length(pb.vals)+1; + initialize(pb); + pb.d = [repmat('\b',1,pb.n+3),'\n']; + + end + + function initialize(pb) + pb.vals = pb.vals0; + pb.txt = char(kron('.|', ones(1,pb.n))); + end + + + % check for update + function out = check(pb, i) + tmp = []; + hit = false; + if nargin<2 + % display text + tmp = sprintf(['\n[',pb.txt(1:pb.n),']']); + elseif i==pb.lims(1) + % reset text + initialize(pb); + % display text + tmp = sprintf(['\n[',pb.txt(1:pb.n),']']); + + elseif i==pb.vals(1) || i==pb.lims(end) + %increment vals & text list + pb.vals = circshift(pb.vals, -1); + pb.txt = circshift(pb.txt, 1); + % update text & display + tmp = sprintf([pb.d,'[',pb.txt(1:pb.n),']']); + if i==pb.lims(end) + tmp = sprintf('%s Done.\n',tmp); + end + hit = true; + end + + if ~isempty(tmp) + fprintf(tmp) + end + if nargout>0 + out = hit; + end + + end + + function reset(pb) + initialize(pb); + end + end +end + + diff --git a/sorting/Kilosort-3.0/utils/rezMergeToPhy.m b/sorting/Kilosort-3.0/utils/rezMergeToPhy.m new file mode 100644 index 00000000..b2492a6e --- /dev/null +++ b/sorting/Kilosort-3.0/utils/rezMergeToPhy.m @@ -0,0 +1,302 @@ +function rezMergeToPhy(rez1, rez2, savePath) +% function rezMergeToPhy(rez1, rez2, savePath) +% +% Merge two kilosort rez structs into one, +% Save all requisit output files for loading merged dataset into Phy +% +% ~~~ Not Recommended ~~~ +% Integration of template & feature projections of two independent +% kilosort sessions requires recomputing all spike projections +% and reassessing template similarity across merged set. +% ...short of that, only thing recoverable from rough merge of two +% sessions is really high amp units that would probably be tracked +% just fine if they were sorted together in the first place. +% +% --- +% W.I.P.:: does not handle template or feature projections, which +% are typically excluded from rez.mat save struct, and should really +% be recomputed based on merged content (e.g. template similarity +% & feature projections of coherent clusters from each rez session) +% --- +% 2021-06-xx TBC Hacked together based on standard rezToPhy.m +% 2021-06-21 TBC Abandon all hope, ye who enter here... +% + + +%% st3 content: +% % % % From learnAndSolve8b >> runTemplates >> trackAndSort.m +% % % st3(irange,1) = double(st); % spike times +% % % st3(irange,2) = double(id0+1); % spike clusters (1-indexing) +% % % st3(irange,3) = double(x0); % template amplitudes +% % % st3(irange,4) = double(vexp); % residual variance of this spike +% % % st3(irange,5) = ibatch; % batch from which this spike was found + + +%% Parse inputs & combine rez structs + +if ~exist(savePath,'dir') + mkdir(savePath); +elseif strcmp(rez1.ops.saveDir, savePath) || strcmp(rez2.ops.saveDir, savePath) + error('Merged destination directory cannot be the same as either of the input rez structs.'); +elseif exist(savePath,'dir') + savePath = uigetdir(savePath, 'Destination exists, please confirm'); +end + + +% add index for each rez struct +rez1.rid = 1; +rez2.rid = 2; + +% combine rez structs +rez(1) = rez1; +rez(2) = rez2; +ntemps = cumsum([0, arrayfun(@(x) length(x.mu), rez)]); + + +%% clear input rez vars (excess memory overhead) +clear rez1 rez2 + + +%% clear existing/conflicting files from destination +fs = dir(fullfile(savePath, '*.npy')); +for i = 1:length(fs) + delete(fullfile(savePath, fs(i).name)); +end +if exist(fullfile(savePath, '.phy'), 'dir') + rmdir(fullfile(savePath, '.phy'), 's'); +end + + +%% Compile params from input rez structs +spikeTimes = cell2mat(arrayfun(@(x) uint64(x.st3(:,1)), rez, 'uni',0)'); +% spikeTimes = cell2mat(spikeTimes'); % concatenate + +[spikeTimes, ii] = sort(spikeTimes); + +% - add offset to template indices of second rez struct to ensure ids are unique +% - offset must match with index of concatenated template shapes as well +spikeTemplates = cell2mat(arrayfun(@(x) uint32(x.st3(:,2) + ntemps(x.rid)), rez, 'uni',0)'); +spikeTemplates = spikeTemplates(ii); +% NO: st3(:,5) is really batch#, not cluster# (!??...KS1 holdover?) +% if size(rez.st3,2)>4 +% spikeClusters = uint32(1+rez.st3(:,5)); +% end + +... unused: spikeBatch = uint32(rez.st3(:,5)); + +amplitudes = cell2mat(arrayfun(@(x) x.st3(:,3), rez, 'uni',0)'); +amplitudes = amplitudes(ii); +% Calc amplitudes to reflect temporal variations in waveform templates +isgood = cell2mat(arrayfun(@(x) x.good, rez, 'uni',0)'); + +estContam = cell2mat(arrayfun(@(x) x.est_contam_rate, rez, 'uni',0)'); + +% the following fields MUST BE IDENTICAL for both rez structs +Nchan = rez(1).ops.Nchan; + +xcoords = rez(1).xcoords(:); +ycoords = rez(1).ycoords(:); +chanMap = rez(1).ops.chanMap(:); +chanMap0ind = chanMap - 1; + +nt0 = size(rez(1).W,1); + + +U = arrayfun(@(x) x.U, rez, 'uni',0); +U = cat(2, U{:}); % must do two step for multi dimensional +W = arrayfun(@(x) x.W, rez, 'uni',0); +W = cat(2, W{:}); + +% total number of templates +Nfilt = ntemps(end);% size(W,2); + +templates = zeros(Nchan, nt0, Nfilt, 'single'); +for iNN = 1:size(templates,3) + templates(:,:,iNN) = squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))'; +end +templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels +templatesInds = repmat((0:size(templates,3)-1), size(templates,1), 1); % we include all channels so this is trivial + +%% Feature & PC projections +% Nope, this really fails. Kludge between this half-measure and Phy's readout of these features +% is no better than actually concatenating the two Kilosort sessions in the first place +% +% % cProj & cProjPC fields are typically excluded from rez.mat save because can balloon file into gigs of data +% % - simply concatenating these values is not quite legitimate, but may roughly gets the job done +% % (...with no more evils than already present in standard kilosort feature calc) +% % - Really, this should recompute features, simiilarity, & pc projections based on concatenated template set (W & U). +% % BUT that involves running through every template & spike waveform (extracted from processed data) +% % which would really need it's own CUDA function, but is probably better done w/in Phy anyway.... +% % - ...so this will have to do for now. +% if isfield(rez, 'cProj') && all(arrayfun(@(x) ~isempty(x.cProj), rez)) +% % cProj are template feature projections +% templateFeatures = cell2mat(arrayfun(@(x) x.cProj, rez, 'uni',0)'); +% % iNeigh are indices into similar **templates** & need to be adjusted to match concatenated template indices +% templateFeatureInds = arrayfun(@(x) uint32(x.iNeigh + ntemps(x.rid)), rez, 'uni',0); +% templateFeatureInds = cat(2, templateFeatureInds{:}); +% % cProjPC are PC projections onto nearby channels +% pcFeatures = arrayfun(@(x) x.cProjPC, rez, 'uni',0); +% pcFeatures = cat(1, pcFeatures{:}); +% % iNeighPC are indices into nearby **channels** & DO NOT need to be adjusted +% pcFeatureInds = arrayfun(@(x) uint32(x.iNeighPC), rez, 'uni',0); +% pcFeatureInds = cat(2, pcFeatureInds{:}); +% +% % templateFeatures = rez.cProj; +% % templateFeatureInds = uint32(rez.iNeigh); +% % pcFeatures = rez.cProjPC; +% % pcFeatureInds = uint32(rez.iNeighPC); +% end + +% Combine whitening matrix & inverse +% Here things get tricky...or maybe not. +% - rezToPhy stopped using the actual whitening matrix when transitioned to datashift method; +% 'whitening_mat.npy' (& the inverse) is just undoing the scaleproc now. +% whiteningMatrix = rez.Wrot/rez.ops.scaleproc; % pre-datashift +% whiteningMatrix = eye(size(rez.Wrot)) / rez.ops.scaleproc; % post-datashift +% So as long as both structs use the same scaleproc, this should be fine +if rez(1).ops.scaleproc ~= rez(end).ops.scaleproc + warning('Incompatible scaling parameters used in rez structs. [rez.ops.scaleproc] must be identical.') + keyboard +end +whiteningMatrix = eye(size(rez(1).Wrot)) / rez(1).ops.scaleproc; +whiteningMatrixInv = whiteningMatrix^-1; + + +%% This section should all 'just work' on the concatenated data +% here we compute the amplitude of every template... + +% unwhiten all the templates +tempsUnW = zeros(size(templates)); +for t = 1:size(templates,1) + tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv; +end + +% The amplitude on each channel is the positive peak minus the negative +tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2)); + +% The template amplitude is the amplitude of its largest channel +tempAmpsUnscaled = max(tempChanAmps,[],2); + +% assign all spikes the amplitude of their template multiplied by their +% scaling amplitudes +spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes; + +% take the average of all spike amps to get actual template amps (since +% tempScalingAmps are equal mean for all templates) +ta = clusterAverage(spikeTemplates, spikeAmps); +tids = unique(spikeTemplates); +tempAmps = zeros(ntemps(end),1); % zeros(numel(rez.mu),1); +tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike +tempAmps = tempAmps'; % gain is fixed +% gain = getOr(rez.ops, 'gain', 1); +% tempAmps = gain*tempAmps'; % for consistency, make first dimension template number + +if ~isempty(savePath) + fileID = fopen(fullfile(savePath, 'cluster_KSLabel.tsv'),'w'); + fprintf(fileID, 'cluster_id%sKSLabel', char(9)); + fprintf(fileID, char([13 10])); + + fileIDCP = fopen(fullfile(savePath, 'cluster_ContamPct.tsv'),'w'); + fprintf(fileIDCP, 'cluster_id%sContamPct', char(9)); + fprintf(fileIDCP, char([13 10])); + + fileIDA = fopen(fullfile(savePath, 'cluster_Amplitude.tsv'),'w'); + fprintf(fileIDA, 'cluster_id%sAmplitude', char(9)); + fprintf(fileIDA, char([13 10])); + + for j = 1:length(isgood) + if isgood(j) + fprintf(fileID, '%d%sgood', j-1, char(9)); + else + fprintf(fileID, '%d%smua', j-1, char(9)); + end + fprintf(fileID, char([13 10])); + + if isfield(rez, 'est_contam_rate') + fprintf(fileIDCP, '%d%s%.1f', j-1, char(9), estContam(j)*100); + fprintf(fileIDCP, char([13 10])); + end + + fprintf(fileIDA, '%d%s%.1f', j-1, char(9), tempAmps(j)); + fprintf(fileIDA, char([13 10])); + + end + fclose(fileID); + fclose(fileIDCP); + fclose(fileIDA); + + + writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy')); + writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing + + writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy')); + writeNPY(templates, fullfile(savePath, 'templates.npy')); + writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy')); + + %chanMap0ind = int32(chanMap0ind); + chanMap0ind = int32([1:Nchan]-1); + writeNPY(chanMap0ind, fullfile(savePath, 'channel_map.npy')); + writeNPY([xcoords ycoords], fullfile(savePath, 'channel_positions.npy')); + + % % Template projections may be salvagable, but exclude for now + % if exist('templateFeatures','var') + % writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy')); + % writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy'));% -1 for zero indexing + % end + + % % Feature projections excluded from rez merge...must be fully recomputed, & beyond scope of this bandaid + % if exist('pcFeatures','var') + % writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy')); + % writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy'));% -1 for zero indexing + % end + + writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy')); + writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy')); + + if isfield(rez, 'simScore') + % similarTemplates = cell2mat(arrayfun(@(x) x.simScore, rez, 'uni',0)'); + similarTemplates = zeros(ntemps(end)); + sims = arrayfun(@(x) x.simScore, rez, 'uni',0); + for i = 1:length(sims) + nt = size(sims{i},1); + ii = (1:nt)+ntemps(i); + similarTemplates(ii,ii) = sims{i}; + end + writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy')); + end + + + % Duplicate "KSLabel" as "group", a special metadata ID for Phy, so that + % filtering works as expected in the cluster view + KSLabelFilename = fullfile(savePath, 'cluster_KSLabel.tsv'); + copyfile(KSLabelFilename, fullfile(savePath, 'cluster_group.tsv')); + + %make params file + if ~exist(fullfile(savePath,'params.py'),'file') + fid = fopen(fullfile(savePath,'params.py'), 'w'); + + % use relative path name for preprocessed data file in params.py + % - defaults to preprocessed file of last rez struct + % - assuming they're in order + % - **** AND that the preprocessed data file was created with the [ks25] branch + % - which include everything in the preprocessed file from t0 to tend + [~, fname, ext] = fileparts(rez(end).ops.fproc); + copyfile(rez(end).ops.fproc, fullfile(savePath, [fname,ext])); + fprintf(fid, 'dat_path = ''%s''\n', fullfile('.',[fname,ext])); + fprintf(fid,'n_channels_dat = %i\n', Nchan); + fprintf(fid,'dtype = ''int16''\n'); + fprintf(fid,'offset = 0\n'); + if mod(rez(1).ops.fs,1) + fprintf(fid,'sample_rate = %i\n', rez(1).ops.fs); + else + fprintf(fid,'sample_rate = %i.\n', rez(1).ops.fs); + end + fprintf(fid,'hp_filtered = True\n'); + fprintf(fid,'template_scaling = 5.0\n'); + + fclose(fid); + end +end + +end %main function + diff --git a/sorting/Kilosort-3.0/utils/rezToPhy.m b/sorting/Kilosort-3.0/utils/rezToPhy.m deleted file mode 100644 index b1a0fcd6..00000000 --- a/sorting/Kilosort-3.0/utils/rezToPhy.m +++ /dev/null @@ -1,217 +0,0 @@ - -function rezToPhy(rez, savePath, varargin) -% pull out results from kilosort's rez to either return to workspace or to -% save in the appropriate format for the phy GUI to run on. If you provide -% a savePath it should be a folder, and you will need to have npy-matlab -% available (https://github.com/kwikteam/npy-matlab) - - -[~, Nfilt, Nrank] = size(rez.W); -rez.Wphy = cat(1, zeros(1+rez.ops.nt0min, Nfilt, Nrank), rez.W); % for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of the window - -% spikeTimes will be in samples, not seconds -rez.W = gather(single(rez.Wphy)); -rez.U = gather(single(rez.U)); -rez.mu = gather(single(rez.mu)); - -if size(rez.st3,2)>4 - rez.st3 = rez.st3(:,1:4); -end - -[~, isort] = sort(rez.st3(:,1), 'ascend'); -rez.st3 = rez.st3(isort, :); -if ~isempty(rez.cProj) - rez.cProj = rez.cProj(isort, :); - rez.cProjPC = rez.cProjPC(isort, :, :); -end - -% ix = rez.st3(:,4)>12; -% rez.st3 = rez.st3(ix, :); -% rez.cProj = rez.cProj(ix, :); -% rez.cProjPC = rez.cProjPC(ix, :,:); - -fs = dir(fullfile(savePath, '*.npy')); -for i = 1:length(fs) - delete(fullfile(savePath, fs(i).name)); -end -if exist(fullfile(savePath, '.phy'), 'dir') - rmdir(fullfile(savePath, '.phy'), 's'); -end - -spikeTimes = uint64(rez.st3(:,1)); -% account for ops.trange(1) to accomodate real time -spikeTimes = spikeTimes - rez.ops.trange(1)*rez.ops.fs; -% [spikeTimes, ii] = sort(spikeTimes); -spikeTemplates = uint32(rez.st3(:,2)); -if size(rez.st3,2)>4 - spikeClusters = uint32(1+rez.st3(:,5)); -end -amplitudes = rez.st3(:,3); - -Nchan = rez.ops.Nchan; - -xcoords = rez.xcoords(:); -ycoords = rez.ycoords(:); -chanMap = rez.ops.chanMap(:); -chanMap0ind = chanMap - 1; - -nt0 = size(rez.W,1); -U = rez.U; -W = rez.W; - -Nfilt = size(W,2); - -templates = zeros(Nchan, nt0, Nfilt, 'single'); -for iNN = 1:size(templates,3) - templates(:,:,iNN) = rez.mu(iNN,1) * squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))'; -end -templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels -templatesInds = repmat([0:size(templates,3)-1], size(templates,1), 1); % we include all channels so this is trivial - -templateFeatures = rez.cProj; -templateFeatureInds = uint32(rez.iNeigh); -pcFeatures = rez.cProjPC; -pcFeatureInds = uint32(rez.iNeighPC); - -% whiteningMatrix = rez.Wrot/rez.ops.scaleproc; -whiteningMatrix = eye(size(rez.Wrot)) / rez.ops.scaleproc; -whiteningMatrixInv = whiteningMatrix^-1; - -% here we compute the amplitude of every template... - -% unwhiten all the templates -tempsUnW = zeros(size(templates)); -for t = 1:size(templates,1) - tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv; -end - -% The amplitude on each channel is the positive peak minus the negative -tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2)); - -% The template amplitude is the amplitude of its largest channel -tempAmpsUnscaled = max(tempChanAmps,[],2); - -% assign all spikes the amplitude of their template multiplied by their -% scaling amplitudes -spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes; - -% take the average of all spike amps to get actual template amps (since -% tempScalingAmps are equal mean for all templates) -ta = clusterAverage(spikeTemplates, spikeAmps); -tids = unique(spikeTemplates); -tempAmps = zeros(numel(rez.mu),1); -tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike -gain = getOr(rez.ops, 'gain', 1); -tempAmps = gain*tempAmps'; % for consistency, make first dimension template number - - -templateFeatures = []; -if ~isempty(savePath) - fileID = fopen(fullfile(savePath, 'cluster_KSLabel.tsv'),'w'); - fprintf(fileID, 'cluster_id%sKSLabel', char(9)); - fprintf(fileID, char([13 10])); - - fileIDCP = fopen(fullfile(savePath, 'cluster_ContamPct.tsv'),'w'); - fprintf(fileIDCP, 'cluster_id%sContamPct', char(9)); - fprintf(fileIDCP, char([13 10])); - - fileIDA = fopen(fullfile(savePath, 'cluster_Amplitude.tsv'),'w'); - fprintf(fileIDA, 'cluster_id%sAmplitude', char(9)); - fprintf(fileIDA, char([13 10])); - - rez.est_contam_rate(isnan(rez.est_contam_rate)) = 1; - for j = 1:length(rez.good) - if rez.good(j) - fprintf(fileID, '%d%sgood', j-1, char(9)); - else - fprintf(fileID, '%d%smua', j-1, char(9)); - end - fprintf(fileID, char([13 10])); - - fprintf(fileIDCP, '%d%s%.1f', j-1, char(9), rez.est_contam_rate(j)*100); - fprintf(fileIDCP, char([13 10])); - - fprintf(fileIDA, '%d%s%.1f', j-1, char(9), tempAmps(j)); - fprintf(fileIDA, char([13 10])); - - end - fclose(fileID); - fclose(fileIDCP); - fclose(fileIDA); - - - writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy')); - writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing - if size(rez.st3,2)>4 - writeNPY(uint32(spikeClusters-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing - else - writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing - end - writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy')); - writeNPY(templates, fullfile(savePath, 'templates.npy')); - writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy')); - - %chanMap0ind = int32(chanMap0ind); - chanMap0ind = int32([1:rez.ops.Nchan]-1); - writeNPY(chanMap0ind, fullfile(savePath, 'channel_map.npy')); - writeNPY([xcoords ycoords], fullfile(savePath, 'channel_positions.npy')); - - if ~isempty(templateFeatures) - writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy')); - writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy'));% -1 for zero indexing - writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy')); - writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy'));% -1 for zero indexing - end - - writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy')); - writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy')); - - if isfield(rez, 'simScore') - similarTemplates = rez.simScore; - writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy')); - end - - % save a list of "good" clusters for Phy -% fileID = fopen(fullfile(savePath, 'channel_names.tsv'), 'w'); -% fprintf(fileID, 'cluster_id%sKSLabel', char(9)); -% for j = 1:Nchan -% fprintf(fileID, '%d%s%d', j-1,char(9),chanMap0ind(j)); -% fprintf(fileID, char([13 10])); -% end -% fclose(fileID); - - % Duplicate "KSLabel" as "group", a special metadata ID for Phy, so that - % filtering works as expected in the cluster view - KSLabelFilename = fullfile(savePath, 'cluster_KSLabel.tsv'); - copyfile(KSLabelFilename, fullfile(savePath, 'cluster_group.tsv')); - - %make params file - if ~exist(fullfile(savePath,'params.py'),'file') - fid = fopen(fullfile(savePath,'params.py'), 'w'); - -% [~, fname, ext] = fileparts(rez.ops.fbinary); -% fprintf(fid,['dat_path = ''',fname ext '''\n']); -% fprintf(fid,'n_channels_dat = %i\n',rez.ops.NchanTOT); - if ~isempty(varargin) - [root, fname, ext] = fileparts(rez.ops.fbinary); - else - [root, fname, ext] = fileparts(rez.ops.fproc); - end -% fprintf(fid,['dat_path = ''',fname ext '''\n']); - fprintf(fid,['dat_path = ''', strrep(rez.ops.fproc, '\', '/') '''\n']); - - fprintf(fid,'n_channels_dat = %i\n',rez.ops.Nchan); - - fprintf(fid,'dtype = ''int16''\n'); - fprintf(fid,'offset = 0\n'); - if mod(rez.ops.fs,1) - fprintf(fid,'sample_rate = %i\n',rez.ops.fs); - else - fprintf(fid,'sample_rate = %i.\n',rez.ops.fs); - end -% fprintf(fid,'hp_filtered = False'); - fprintf(fid,'hp_filtered = True'); - - fclose(fid); - end -end diff --git a/sorting/Kilosort-3.0/utils/rezToPhy2.m b/sorting/Kilosort-3.0/utils/rezToPhy2.m index 0d2cb451..64a63359 100644 --- a/sorting/Kilosort-3.0/utils/rezToPhy2.m +++ b/sorting/Kilosort-3.0/utils/rezToPhy2.m @@ -1,213 +1,213 @@ function rezToPhy2(rez, savePath, varargin) -% pull out results from kilosort's rez to either return to workspace or to -% save in the appropriate format for the phy GUI to run on. If you provide -% a savePath it should be a folder, and you will need to have npy-matlab -% available (https://github.com/kwikteam/npy-matlab) - - -[~, Nfilt, Nrank] = size(rez.W); -%rez.Wphy = cat(1, zeros(1+rez.ops.nt0min, Nfilt, Nrank), rez.W); % for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of the window -rez.Wphy = rez.W; % if nt0min is centered, we don't need to pad.. - -% spikeTimes will be in samples, not seconds -rez.W = gather(single(rez.Wphy)); -rez.U = gather(single(rez.U)); -rez.mu = gather(single(rez.mu)); - -if size(rez.st3,2)>4 - rez.st3 = rez.st3(:,1:4); -end - -[~, isort] = sort(rez.st3(:,1), 'ascend'); -rez.st3 = rez.st3(isort, :); -if ~isempty(rez.cProj) - rez.cProj = rez.cProj(isort, :); - rez.cProjPC = rez.cProjPC(isort, :, :); -end - -% ix = rez.st3(:,4)>12; -% rez.st3 = rez.st3(ix, :); -% rez.cProj = rez.cProj(ix, :); -% rez.cProjPC = rez.cProjPC(ix, :,:); - -fs = dir(fullfile(savePath, '*.npy')); -for i = 1:length(fs) - delete(fullfile(savePath, fs(i).name)); -end -if exist(fullfile(savePath, '.phy'), 'dir') - rmdir(fullfile(savePath, '.phy'), 's'); -end - -spikeTimes = uint64(rez.st3(:,1)); -% account for ops.trange(1) to accomodate real time -spikeTimes = spikeTimes - rez.ops.trange(1)*rez.ops.fs; -% [spikeTimes, ii] = sort(spikeTimes); -spikeTemplates = uint32(rez.st3(:,2)); -if size(rez.st3,2)>4 - spikeClusters = uint32(1+rez.st3(:,5)); -end -amplitudes = rez.st3(:,3); - -Nchan = rez.ops.Nchan; - -xcoords = rez.xcoords(:); -ycoords = rez.ycoords(:); -chanMap = rez.ops.chanMap(:); -chanMap0ind = chanMap - 1; - -nt0 = size(rez.W,1); -U = rez.U; -W = rez.W; - -Nfilt = size(W,2); - -templates = zeros(Nchan, nt0, Nfilt, 'single'); -for iNN = 1:size(templates,3) - templates(:,:,iNN) = squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))'; -end -templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels -templatesInds = repmat([0:size(templates,3)-1], size(templates,1), 1); % we include all channels so this is trivial - -templateFeatures = rez.cProj; -templateFeatureInds = uint32(rez.iNeigh); -pcFeatures = rez.cProjPC; -pcFeatureInds = uint32(rez.iNeighPC); - -whiteningMatrix = rez.Wrot; -whiteningMatrixInv = whiteningMatrix^-1; - -% here we compute the amplitude of every template... - -% unwhiten all the templates -tempsUnW = zeros(size(templates)); -for t = 1:size(templates,1) - tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv; -end - -% The amplitude on each channel is the positive peak minus the negative -tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2)); - -% The template amplitude is the amplitude of its largest channel -tempAmpsUnscaled = max(tempChanAmps,[],2); - -% assign all spikes the amplitude of their template multiplied by their -% scaling amplitudes -spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes; - -% take the average of all spike amps to get actual template amps (since -% tempScalingAmps are equal mean for all templates) -ta = clusterAverage(spikeTemplates, spikeAmps); -tids = unique(spikeTemplates); -tempAmps = zeros(numel(rez.mu),1); -tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike -gain = getOr(rez.ops, 'gain', 1); -tempAmps = gain*tempAmps'; % for consistency, make first dimension template number - -if ~isempty(savePath) - fileID = fopen(fullfile(savePath, 'cluster_KSLabel.tsv'),'w'); - fprintf(fileID, 'cluster_id%sKSLabel', char(9)); - fprintf(fileID, char([13 10])); - - fileIDCP = fopen(fullfile(savePath, 'cluster_ContamPct.tsv'),'w'); - fprintf(fileIDCP, 'cluster_id%sContamPct', char(9)); - fprintf(fileIDCP, char([13 10])); - - fileIDA = fopen(fullfile(savePath, 'cluster_Amplitude.tsv'),'w'); - fprintf(fileIDA, 'cluster_id%sAmplitude', char(9)); - fprintf(fileIDA, char([13 10])); - - rez.est_contam_rate(isnan(rez.est_contam_rate)) = 1; - for j = 1:length(rez.good) - if rez.good(j) - fprintf(fileID, '%d%sgood', j-1, char(9)); - else - fprintf(fileID, '%d%smua', j-1, char(9)); - end - fprintf(fileID, char([13 10])); - - fprintf(fileIDCP, '%d%s%.1f', j-1, char(9), rez.est_contam_rate(j)*100); - fprintf(fileIDCP, char([13 10])); - - fprintf(fileIDA, '%d%s%.1f', j-1, char(9), tempAmps(j)); - fprintf(fileIDA, char([13 10])); - + % pull out results from kilosort's rez to either return to workspace or to + % save in the appropriate format for the phy GUI to run on. If you provide + % a savePath it should be a folder, and you will need to have npy-matlab + % available (https://github.com/kwikteam/npy-matlab) + + + [~, Nfilt, Nrank] = size(rez.W); + %rez.Wphy = cat(1, zeros(1+rez.ops.nt0min, Nfilt, Nrank), rez.W); % for Phy, we need to pad the spikes with zeros so the spikes are aligned to the center of the window + rez.Wphy = rez.W; % if nt0min is centered, we don't need to pad.. + + % spikeTimes will be in samples, not seconds + rez.W = gather(single(rez.Wphy)); + rez.U = gather(single(rez.U)); + rez.mu = gather(single(rez.mu)); + + if size(rez.st3,2)>4 + rez.st3 = rez.st3(:,1:4); end - fclose(fileID); - fclose(fileIDCP); - fclose(fileIDA); + [~, isort] = sort(rez.st3(:,1), 'ascend'); + rez.st3 = rez.st3(isort, :); + if ~isempty(rez.cProj) + rez.cProj = rez.cProj(isort, :); + rez.cProjPC = rez.cProjPC(isort, :, :); + end - writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy')); - writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing + % ix = rez.st3(:,4)>12; + % rez.st3 = rez.st3(ix, :); + % rez.cProj = rez.cProj(ix, :); + % rez.cProjPC = rez.cProjPC(ix, :,:); + + fs = dir(fullfile(savePath, '*.npy')); + for i = 1:length(fs) + delete(fullfile(savePath, fs(i).name)); + end + if exist(fullfile(savePath, '.phy'), 'dir') + rmdir(fullfile(savePath, '.phy'), 's'); + end + + spikeTimes = uint64(rez.st3(:,1)); + % account for ops.trange(1) to accomodate real time + spikeTimes = spikeTimes - rez.ops.trange(1)*rez.ops.fs; + % [spikeTimes, ii] = sort(spikeTimes); + spikeTemplates = uint32(rez.st3(:,2)); if size(rez.st3,2)>4 - writeNPY(uint32(spikeClusters-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing - else - writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing + spikeClusters = uint32(1+rez.st3(:,5)); end - writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy')); - writeNPY(templates, fullfile(savePath, 'templates.npy')); - writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy')); - - %chanMap0ind = int32(chanMap0ind); - chanMap0ind = int32([1:rez.ops.Nchan]-1); - writeNPY(chanMap0ind, fullfile(savePath, 'channel_map.npy')); - writeNPY([xcoords ycoords], fullfile(savePath, 'channel_positions.npy')); - - if ~isempty(templateFeatures) - writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy')); - writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy'));% -1 for zero indexing - writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy')); - writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy'));% -1 for zero indexing + amplitudes = rez.st3(:,3); + + Nchan = rez.ops.Nchan; + + xcoords = rez.xcoords(:); + ycoords = rez.ycoords(:); + chanMap = rez.ops.chanMap(:); + chanMap0ind = chanMap - 1; + + nt0 = size(rez.W,1); + U = rez.U; + W = rez.W; + + Nfilt = size(W,2); + + templates = zeros(Nchan, nt0, Nfilt, 'single'); + for iNN = 1:size(templates,3) + templates(:,:,iNN) = squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))'; end - - writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy')); - writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy')); - - if isfield(rez, 'simScore') - similarTemplates = rez.simScore; - writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy')); + templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels + templatesInds = repmat([0:size(templates,3)-1], size(templates,1), 1); % we include all channels so this is trivial + + templateFeatures = rez.cProj; + templateFeatureInds = uint32(rez.iNeigh); + pcFeatures = rez.cProjPC; + pcFeatureInds = uint32(rez.iNeighPC); + + whiteningMatrix = rez.Wrot; + whiteningMatrixInv = whiteningMatrix^-1; + + % here we compute the amplitude of every template... + + % unwhiten all the templates + tempsUnW = zeros(size(templates)); + for t = 1:size(templates,1) + tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv; end - - % save a list of "good" clusters for Phy -% fileID = fopen(fullfile(savePath, 'channel_names.tsv'), 'w'); -% fprintf(fileID, 'cluster_id%sKSLabel', char(9)); -% for j = 1:Nchan -% fprintf(fileID, '%d%s%d', j-1,char(9),chanMap0ind(j)); -% fprintf(fileID, char([13 10])); -% end -% fclose(fileID); - - % Duplicate "KSLabel" as "group", a special metadata ID for Phy, so that - % filtering works as expected in the cluster view - KSLabelFilename = fullfile(savePath, 'cluster_KSLabel.tsv'); - copyfile(KSLabelFilename, fullfile(savePath, 'cluster_group.tsv')); - - %make params file - if ~exist(fullfile(savePath,'params.py'),'file') - fid = fopen(fullfile(savePath,'params.py'), 'w'); - -% [~, fname, ext] = fileparts(rez.ops.fbinary); -% fprintf(fid,['dat_path = ''',fname ext '''\n']); -% fprintf(fid,'n_channels_dat = %i\n',rez.ops.NchanTOT); - if ~isempty(varargin) - [root, fname, ext] = fileparts(rez.ops.fbinary); - else - [root, fname, ext] = fileparts(rez.ops.fproc); + + % The amplitude on each channel is the positive peak minus the negative + tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2)); + + % The template amplitude is the amplitude of its largest channel + tempAmpsUnscaled = max(tempChanAmps,[],2); + + % assign all spikes the amplitude of their template multiplied by their + % scaling amplitudes + spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes; + + % take the average of all spike amps to get actual template amps (since + % tempScalingAmps are equal mean for all templates) + ta = clusterAverage(spikeTemplates, spikeAmps); + tids = unique(spikeTemplates); + tempAmps = zeros(numel(rez.mu),1); + tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike + gain = getOr(rez.ops, 'gain', 1); + tempAmps = gain*tempAmps'; % for consistency, make first dimension template number + + if ~isempty(savePath) + fileID = fopen(fullfile(savePath, 'cluster_KSLabel.tsv'),'w'); + fprintf(fileID, 'cluster_id%sKSLabel', char(9)); + fprintf(fileID, char([13 10])); + + fileIDCP = fopen(fullfile(savePath, 'cluster_ContamPct.tsv'),'w'); + fprintf(fileIDCP, 'cluster_id%sContamPct', char(9)); + fprintf(fileIDCP, char([13 10])); + + fileIDA = fopen(fullfile(savePath, 'cluster_Amplitude.tsv'),'w'); + fprintf(fileIDA, 'cluster_id%sAmplitude', char(9)); + fprintf(fileIDA, char([13 10])); + + rez.est_contam_rate(isnan(rez.est_contam_rate)) = 1; + for j = 1:length(rez.good) + if rez.good(j) + fprintf(fileID, '%d%sgood', j-1, char(9)); + else + fprintf(fileID, '%d%smua', j-1, char(9)); + end + fprintf(fileID, char([13 10])); + + fprintf(fileIDCP, '%d%s%.1f', j-1, char(9), rez.est_contam_rate(j)*100); + fprintf(fileIDCP, char([13 10])); + + fprintf(fileIDA, '%d%s%.1f', j-1, char(9), tempAmps(j)); + fprintf(fileIDA, char([13 10])); + end -% fprintf(fid,['dat_path = ''',fname ext '''\n']); - % make path flexible to final location of results for Phy - fprintf(fid,'import pathlib\ndat_path = f"{pathlib.Path().resolve()}/proc.dat"\n'); %['dat_path = ''', strrep(rez.ops.fproc, '\', '/') '''\n']); - fprintf(fid,'n_channels_dat = %i\n',rez.ops.Nchan); - fprintf(fid,'dtype = ''int16''\n'); - fprintf(fid,'offset = 0\n'); - if mod(rez.ops.fs,1) - fprintf(fid,'sample_rate = %i\n',rez.ops.fs); + fclose(fileID); + fclose(fileIDCP); + fclose(fileIDA); + + + writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy')); + writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing + if size(rez.st3,2)>4 + writeNPY(uint32(spikeClusters-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing else - fprintf(fid,'sample_rate = %i.\n',rez.ops.fs); + writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_clusters.npy')); % -1 for zero indexing end -% fprintf(fid,'hp_filtered = False'); - fprintf(fid,'hp_filtered = True'); - - fclose(fid); - end -end + writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy')); + writeNPY(templates, fullfile(savePath, 'templates.npy')); + writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy')); + + %chanMap0ind = int32(chanMap0ind); + chanMap0ind = int32([1:rez.ops.Nchan]-1); + writeNPY(chanMap0ind, fullfile(savePath, 'channel_map.npy')); + writeNPY([xcoords ycoords], fullfile(savePath, 'channel_positions.npy')); + + if ~isempty(templateFeatures) + writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy')); + writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy'));% -1 for zero indexing + writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy')); + writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy'));% -1 for zero indexing + end + + writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy')); + writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy')); + + if isfield(rez, 'simScore') + similarTemplates = rez.simScore; + writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy')); + end + + % save a list of "good" clusters for Phy + % fileID = fopen(fullfile(savePath, 'channel_names.tsv'), 'w'); + % fprintf(fileID, 'cluster_id%sKSLabel', char(9)); + % for j = 1:Nchan + % fprintf(fileID, '%d%s%d', j-1,char(9),chanMap0ind(j)); + % fprintf(fileID, char([13 10])); + % end + % fclose(fileID); + + % Duplicate "KSLabel" as "group", a special metadata ID for Phy, so that + % filtering works as expected in the cluster view + KSLabelFilename = fullfile(savePath, 'cluster_KSLabel.tsv'); + copyfile(KSLabelFilename, fullfile(savePath, 'cluster_group.tsv')); + + %make params file + if ~exist(fullfile(savePath,'params.py'),'file') + fid = fopen(fullfile(savePath,'params.py'), 'w'); + + % [~, fname, ext] = fileparts(rez.ops.fbinary); + % fprintf(fid,['dat_path = ''',fname ext '''\n']); + % fprintf(fid,'n_channels_dat = %i\n',rez.ops.NchanTOT); + if ~isempty(varargin) + [root, fname, ext] = fileparts(rez.ops.fbinary); + else + [root, fname, ext] = fileparts(rez.ops.fproc); + end + % fprintf(fid,['dat_path = ''',fname ext '''\n']); + % make path flexible to final location of results for Phy + fprintf(fid,'import pathlib\ndat_path = f"{pathlib.Path().resolve()}/proc.dat"\n'); %['dat_path = ''', strrep(rez.ops.fproc, '\', '/') '''\n']); + fprintf(fid,'n_channels_dat = %i\n',rez.ops.Nchan); + fprintf(fid,'dtype = ''int16''\n'); + fprintf(fid,'offset = 0\n'); + if mod(rez.ops.fs,1) + fprintf(fid,'sample_rate = %i\n',rez.ops.fs); + else + fprintf(fid,'sample_rate = %i.\n',rez.ops.fs); + end + % fprintf(fid,'hp_filtered = False'); + fprintf(fid,'hp_filtered = True'); + + fclose(fid); + end + end \ No newline at end of file diff --git a/sorting/Kilosort-3.0/utils/saveFigTriplet.m b/sorting/Kilosort-3.0/utils/saveFigTriplet.m new file mode 100644 index 00000000..6386a696 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/saveFigTriplet.m @@ -0,0 +1,161 @@ +function saveFigTriplet(withdate, infostr, fileflags, figSubDir, varargin) +% function saveFigTriplet(withdate, infostr, fileflags, figSubDir, varargin) +% fileflags is logical array for saving: [tiff, png, eps, mat, pdf] +% +% Saves current figure as .fig, .pdf, and .eps (by default) +% 3rd input selects filetypes as logical index: [tiff, png, eps, mat, pdf] +% File name comes from figure name: get(gcf, 'name') +% --by default adds current date to file name +% Mines calling workspace for path to figure directory ('figDir') +% --creates dir structure within figDir if not already present +% +% +% 2014-04-16 TBC Wrote it. (czuba@utexas.edu) +% 2021-05-18 TBC Cleaned & commented +% + +% NOTE: Rasterized formats (jpg, png, tiff) *consistently* come out rotated +% by 90 deg, but content spacing and aspect ratio are correct & unclipped (!?@!!) +% Vector formats (pdf & eps) of the same figures are oriented correctly...go figure. +% -- TBC + +H = gcf; + +%% Parse inputs +if ~exist('withdate','var') || isempty(withdate) + % default: append date to filename as '_yyyymmmdd' + withdate = 1; +end + +% file save types (default [.mat, .pdf, .eps]) ...convoluted, but backwards compatible and extensible +filetypes = struct('tiff',0, 'png',0, 'eps',1, 'matlab',1, 'pdf',1); +filefn = fieldnames(filetypes); + +if nargin>2 && ~isempty(fileflags) + if ischar(fileflags) + fileflags = {fileflags}; + end + if iscell(fileflags) + % string input + for i = 1:length(filefn) + filetypes.(filefn{i}) = contains(filefn{i},fileflags); + end + else + for i = 1:length(fileflags) + filetypes.(filefn{i}) = fileflags(i); + end + end +end + +if nargin <4 || isempty(figSubDir) + figSubDir = []; +end + +% get vars from figure tag or caller wkspc +Htag = get(H, 'tag'); +if ~isempty(Htag) && contains(Htag, filesep) + figDir = Htag; + +elseif evalin('caller','exist(''figDir'',''var'')') + figDir = evalin('caller','figDir'); +end + + +%% Determine filename +fname = get(H,'name'); +if withdate + datefmt = 'yyyymmmdd'; + if withdate>1 + datefmt = [datefmt,'-HHMMSS']; + end + fname = [fname,'_',datestr(now, datefmt)]; +end + +fname(fname==32)='_'; % no spaces in figure name! +fname(fname==44)='_'; % no commas either! + +% default name as date w/time to prevent crash for saving w/o name +if isempty(fname) + fname = datestr(now,'yyyymmmdd_HH-MM-ss'); +end + +if ~exist('figDir','var') || isempty(figDir) + figDir = fullfile(pwd,'figs'); +end + + +%% Append info string to bottom of figure +% - if provided as input or if exist in the workspace +ax = gca; +fsz = 10; +% check for infostr in calling directory if not passed as input +if ~exist('infostr','var') || isempty(infostr) + if evalin('caller','exist(''infostr'',''var'')') + infostr = evalin('caller','infostr'); + end +end +% now apply if one exists +if exist('infostr','var') && ~isempty(infostr) + axes('position',[0,.002,1,.02],'visible','off'); + % shrinking text if multiple lines + if contains(infostr, {sprintf('\n'),sprintf('\n\r')}), fsz = 8; end + text(0,0, infostr, 'verticalAlignment','bottom', 'interpreter','none', 'fontsize',fsz); +end +axes(ax); + + +%% Create dest directories if don't already exist +t = logical(struct2array(filetypes)); filefn = fieldnames(filetypes); +t = filefn(t); + +% Make dirs +for i = 1:length(t) + if ~exist(fullfile(figDir,t{i}),'dir') + mkdir(fullfile(figDir,t{i})); + end + if ~isempty('figSubDir') && ~exist(fullfile(figDir,t{i},figSubDir),'dir') + mkdir(fullfile(figDir,t{i},figSubDir)); + end +end + + +%% Do the saving + +% Matlab figure +if filetypes.matlab + savefig(H, fullfile(figDir,'matlab',figSubDir,[fname,'.fig']), 'compact') % smaller +end + +if filetypes.tiff + saveas(H,fullfile(figDir,'tiff',figSubDir,[fname,'.tif']),'tiff') +end + +% PNG +if filetypes.png + saveas(H,fullfile(figDir,'png',figSubDir,[fname,'.png']),'png') +end + + +% PDF +if filetypes.pdf + try + eval(['print ','-dpdf -r400',' -painters ',fullfile(figDir,'pdf',figSubDir,[fname,'.','pdf'])]); + % eval(['print ','-dpdf -r400',' -opengl ',fullfile(figDir,'pdf',[fname,'.','pdf'])]); + catch, fprintf(2,'\t\tErrored trying to save %s file: %s\n', 'pdf', fname); + end +end + +% EPS +if filetypes.eps + try + eval(['print ','-depsc',' -painters ',fullfile(figDir,'eps',figSubDir,[fname,'.','eps'])]); + % eval(['print ','-depsc',' -opengl ',fullfile(figDir,'eps',[fname,'.','eps'])]); + catch, fprintf(2,'\t\tErrored trying to save %s file: %s\n', 'eps', fname); + end +end + +% Tell em what you did +fprintf(2, '\b\tSaved to: %s\n ', figDir); + + +end %main function diff --git a/sorting/Kilosort-3.0/utils/svdecon.m b/sorting/Kilosort-3.0/utils/svdecon.m index b9eb832f..574fc053 100644 --- a/sorting/Kilosort-3.0/utils/svdecon.m +++ b/sorting/Kilosort-3.0/utils/svdecon.m @@ -51,6 +51,6 @@ U = X*V; % convert evecs from X'*X to X*X'. the evals are the same. %s = sqrt(sum(U.^2,1))'; s = sqrt(d); - U = bsxfun(@(x,c)x./c, gather(U), s'); + U = bsxfun(@(x,c)x./c, U, s'); % U = bsxfun(@(x,c)x./c, gather(U), s'); S = diag(s); end diff --git a/sorting/Kilosort-3.0/utils/updateProgressMessage.m b/sorting/Kilosort-3.0/utils/updateProgressMessage.m new file mode 100644 index 00000000..1d3b4276 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/updateProgressMessage.m @@ -0,0 +1,53 @@ +function updateProgressMessage(n, ntot, tbase, len, freq) + +%% Defaults +if nargin<5 || isempty(freq) + freq = 1; % freq of full message updates; else just print '.' on each call +end + +if nargin<4 || isempty(len) + len = 100; % padded string length +end + +if nargin<3 || isempty(tbase) + t = toc; +else + t = toc(tbase); +end + + +%% Make string + +if mod(n, freq)>0 && n~=ntot + % do nothing...just passing through + fprintf('.') + return +else + % Create progress message & print to command window + % update times + if t<90 + tstr = sprintf('%2.2f sec elapsed', t); + else + tstr = sprintf('%2.2f min elapsed', t/60); + end + secPerN = t/n; + tRemEstimate = secPerN * (ntot-n); + if tRemEstimate<90 + rstr = sprintf('%2.2f sec remaining', tRemEstimate); + else + rstr = sprintf('%2.2f min remaining', tRemEstimate/60); + end + % update message + msg = sprintf('\nfinished %4i of %i. (%s, ~%s; %2.2f sec/each)..',n, ntot, tstr, rstr, secPerN); + + % clear previous message + if n>freq + fprintf(repmat('\b',1,len + freq-1)); + end + % print message + fprintf(pad(msg, len, '.')); + +end + +end %main function + diff --git a/sorting/Kilosort-3.0/utils/vec2tick.m b/sorting/Kilosort-3.0/utils/vec2tick.m new file mode 100644 index 00000000..96455a19 --- /dev/null +++ b/sorting/Kilosort-3.0/utils/vec2tick.m @@ -0,0 +1,22 @@ +function st = vec2tick(vec, fmt) +% function st = vec2tick(vec, fmt) +% +% convert vector[vec] of numbers to a cell of strings. +% [fmt] = text format string (def= '%2.2f ') +% only outputs single row of stringified numbers (no matrix possible...) +% +% 2107-01-04 TBC Wrote it. + +if nargin<2 || isempty('fmt') || ~ischar(fmt) + % must be a formatting string + fmt = '%2.2f '; +elseif (fmt(end)*1) ~= (1*' ') + % format must end in space + fmt = [fmt, ' ']; +end + +st = textscan( num2str(vec(:)', fmt), '%s'); +st = st{:}; + +end %main function + diff --git a/sorting/Kilosort_config_3.m b/sorting/Kilosort_config_3.m index e60c5dd0..3aa07aa8 100644 --- a/sorting/Kilosort_config_3.m +++ b/sorting/Kilosort_config_3.m @@ -11,13 +11,13 @@ ops.minfr_goodchannels = 0; % this does nothing in kilosort 3 % threshold on projections (like in Kilosort1, can be different for last pass like [10 4]) -ops.Th = [9 8]; +ops.Th = [10 4]; % how important is the amplitude penalty (like in Kilosort1, 0 means not used, 10 is average, 50 is a lot) ops.lam = 10; % splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) -ops.AUCsplit = 0.9; +ops.AUCsplit = 0.9; % this measures area of the gaussian mixture component minus the area of overlap with other bimodal component % minimum spike rate (Hz), if a cluster falls below this for too long it gets removed ops.minFR = 1/50; @@ -28,12 +28,12 @@ % spatial constant in um for computing residual variance of spike ops.sigmaMask = 30; -% threshold crossings for pre-clustering (in PCA projection space) -ops.ThPre = 8; +% threshold crossings for pre-clustering (in PCA projection space) (not used) +% ops.ThPre = 8; %% danger, changing these settings can lead to fatal errors % options for determining PCs -ops.spkTh = -6; % spike threshold in standard deviations (-6) +ops.spkTh = -6; % spike threshold in standard deviations (-6) (not used) ops.reorder = 1; % whether to reorder batches for drift correction. ops.nskip = 25; % how many batches to skip for determining spike PCs @@ -52,4 +52,8 @@ ops.nPCs = 3; % how many PCs to project the spikes into ops.useRAM = 0; % not yet available +%% additional settings +ops.CAR = 1; % whether to perform CAR +ops.loc_range = [5 4]; % area to detect peaks; plus/minus for both time and channel +ops.long_range = [30 6]; % range to detect isolated peaks: [timepoints channels] %% \ No newline at end of file diff --git a/sorting/Kilosort_config_czuba.m b/sorting/Kilosort_config_czuba.m index fb00ac5d..c9a560e2 100644 --- a/sorting/Kilosort_config_czuba.m +++ b/sorting/Kilosort_config_czuba.m @@ -1,4 +1,3 @@ - ops = []; ops.fs = 30000; @@ -67,7 +66,8 @@ %% Apply changes to standard ops -ops.fshigh = 300; % map system has hardware high pass filters at 300 +load(fullfile(script_dir, '/tmp/config.mat'), 'myo_data_passband') +ops.fshigh = myo_data_passband(1); % map system has hardware high pass filters at 300 % make waveform length independent of sampling rate % ops.nt0 = ceil( 0.002 * ops.fs); % width of waveform templates (makes consistent N-ms either side of spike, regardless of sampling rate) @@ -98,7 +98,7 @@ % threshold(s) used when establishing baseline templates from raw data % - standard codebase tends to [frustratingly] overwrite this param, but working to straighten out those instances ops.spkTh = -6; % [def= -6] -ops.ThPre = 8; % [def= 8] +% ops.ThPre = 8; % [def= 8] Not used % splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) % - only relevant if post-hoc merges & splits are used (which is not recommended, see flags above) diff --git a/sorting/Kilosort_gridsearch_config.py b/sorting/Kilosort_gridsearch_config.py new file mode 100644 index 00000000..df4290d3 --- /dev/null +++ b/sorting/Kilosort_gridsearch_config.py @@ -0,0 +1,24 @@ +from sklearn.model_selection import ParameterGrid + + +# Define the parameter grid to be searched over using ops variables from Kilosort_config_3.m +# All parameter combinations are tried, so be careful to consider the total number of combinations, +# which is the product of the numbers of elements in each dictionary element. +def get_KS_params_grid(): + grid = dict( + # Th=[[12, 6], [10, 5], [8, 4], [7, 3], [5, 2], [4, 2], [2, 1], [1, 0.5]], + # Th=[[10, 4], [7, 3], [5, 2.5], [4, 2], [3, 1.5], [2, 1], [1.5, 0.75], [1, 0.5]], + # Th=[[10, 2], [9, 2], [8, 2], [7, 2], [6, 2], [5, 2], [4, 2]], + # Th=[[10, 4], [5, 2], [2, 1], [1, 0.5]], + # Th=[[3, 1.5], [2, 1], [1.5, 0.75], [1, 0.5]], + ## v + Th=[[10, 4], [7, 3], [5, 2], [2, 1]], + spkTh=[[-6], [-3, -6, -9]], + ## ^ + # long_range=[[30, 3], [30, 1]], + # lam=[10, 15], + # nfilt_factor=[12, 4], + # AUCsplit=[0.8, 0.9], + # momentum=[[20, 400], [60, 600]], + ) + return ParameterGrid(grid) diff --git a/sorting/Kilosort_run_czuba.m b/sorting/Kilosort_run_czuba.m index fbc0ca8c..e59207a3 100644 --- a/sorting/Kilosort_run_czuba.m +++ b/sorting/Kilosort_run_czuba.m @@ -38,7 +38,7 @@ rez = datashift2(rez, 1); -rez.W = []; rez.U = [];, rez.mu = []; +rez.W = []; rez.U = []; rez.mu = []; rez = learnAndSolve8b(rez, now); % OPTIONAL: remove double-counted spikes - solves issue in which individual spikes are assigned to multiple templates. diff --git a/sorting/Kilosort_run_myo_3.m b/sorting/Kilosort_run_myo_3.m index 263f4ec1..d7dee32c 100644 --- a/sorting/Kilosort_run_myo_3.m +++ b/sorting/Kilosort_run_myo_3.m @@ -1,71 +1,121 @@ -script_dir = pwd; % get directory where repo exists -load(fullfile(script_dir, '/tmp/config.mat')) -load(fullfile(myo_sorted_dir, 'brokenChan.mat')) -chanMapFile = myo_chan_map_file -disp(['Using this channel map: ' chanMapFile]) +function rez = Kilosort_run_myo_3(ops_input_params) + dbstop if error + script_dir = pwd; % get directory where repo exists + load(fullfile(script_dir, '/tmp/config.mat')) + load(fullfile(myo_sorted_dir, 'brokenChan.mat')) -% load and modify channel map variables to remove broken channel elements, if desired -if length(brokenChan) > 0 && remove_bad_myo_chans(1) ~= false - load(chanMapFile) - disp('Broken channels were just removed from that channel map') - load(myo_chan_map_file) - chanMap(end-length(brokenChan)+1:end) = []; % take off end to save indexing - chanMap0ind(end-length(brokenChan)+1:end) = []; % take off end to save indexing - connected(brokenChan) = []; - kcoords(brokenChan) = []; - xcoords(brokenChan) = []; - ycoords(brokenChan) = []; - save(fullfile(myo_sorted_dir, 'chanMap_minus_brokenChans.mat'), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords', 'chanMap0ind', 'fs', 'name') - chanMapFile = fullfile(myo_sorted_dir, 'chanMap_minus_brokenChans.mat'); -end + % set GPU to use + disp(strcat("Setting GPU device to use: ", num2str(GPU_to_use))) + gpuDevice(GPU_to_use); -try - restoredefaultpath -end -dbstop if error + % get and set channel map + if ~isempty(brokenChan) && remove_bad_myo_chans(1) ~= false + chanMapFile = fullfile(myo_sorted_dir, 'chanMapAdjusted.mat'); + else + chanMapFile = myo_chan_map_file; + end + disp(['Using this channel map: ' chanMapFile]) -addpath(genpath([script_dir '/sorting/Kilosort-3.0'])) -addpath(genpath([script_dir '/sorting/npy-matlab'])) + try + restoredefaultpath + end -run([script_dir '/sorting/Kilosort_config_3.m']); -ops.fbinary = fullfile(myo_sorted_dir, 'data.bin'); -ops.fproc = fullfile(myo_sorted_dir, 'proc.dat'); -ops.brokenChan = fullfile(myo_sorted_dir, 'brokenChan.mat'); -ops.chanMap = fullfile(chanMapFile); -ops.NchanTOT = double(num_chans - length(brokenChan)); -ops.nt0 = 201; -ops.NT = 2 * 64 * 1024 + ops.ntbuff; -ops.sigmaMask = Inf; % we don't want a distance-dependant decay -ops.Th = [9 8]; -ops.nfilt_factor = 4; -ops.nblocks = 0; -ops.nt0min = ceil(ops.nt0 / 2); -ops.nPCs = 6; -ops.nEig = 3; -ops.lam = 10; % amplitude penalty (0 means not used, 10 is average, 50 is a lot) -ops.ThPre = 8; % threshold crossings for pre-clustering (in PCA projection space) + addpath(genpath([script_dir '/sorting/Kilosort-3.0'])) + addpath(genpath([script_dir '/sorting/npy-matlab'])) -if trange(2) == 0 - ops.trange = [0 Inf]; -else - ops.trange = trange; -end + run([script_dir '/sorting/Kilosort_config_3.m']); -ops + ops.fbinary = fullfile(myo_sorted_dir, 'data.bin'); + ops.fproc = fullfile(myo_sorted_dir, 'proc.dat'); + ops.brokenChan = fullfile(myo_sorted_dir, 'brokenChan.mat'); + ops.chanMap = fullfile(chanMapFile); + ops.NchanTOT = double(num_chans); %double(max(num_chans - length(brokenChan), 9)); + ops.nt0 = 61; + ops.ntbuff = 64; % defined as 64; + ops.NT = 2048 * 32 + ops.ntbuff; % convert to 32 count increments of samples % defined as 2048 * 32 + ops.ntbuff; + ops.sigmaMask = Inf; % we don't want a distance-dependant decay + ops.nPCs = 9; % how many PCs to project the spikes into (also used as number of template prototypes) + ops.nEig = ops.nPCs; % rank of svd for templates, % keep same as nPCs to avoid error + ops.Th = [6 2]; % threshold crossings for pre-clustering (in PCA projection space) + ops.spkTh = -2; % spike threshold in standard deviations (-6 default) (only used in isolated_peaks_new) + ops.nfilt_factor = 12; % max number of clusters per good channel (even temporary ones) + ops.nblocks = 0; + ops.nt0min = ceil(ops.nt0 / 2); % peak of template match will be this many points away from beginning + ops.nskip = 1; % how many batches to skip for determining spike PCs + ops.nSkipCov = 1; % compute whitening matrix and prototype templates using every N-th batch + ops.lam = 15; % amplitude penalty (0 means not used, 10 is average, 50 is a lot) + ops.CAR = 0; % whether to perform CAR + ops.loc_range = [5 1]; % [timepoints channels], area to detect peaks; plus/minus for both time and channel. Doing abs() of data during peak isolation, so using 4 instead of default 5. Only 1 channel to avoid elimination of waves + ops.long_range = [ops.nt0min 1]; % [timepoints channels], range within to use only the largest peak + ops.fig = 1; % whether to plot figures + ops.recordings = recordings; -rez = preprocessDataSub(ops); -rez = datashift2(rez, 1); -[rez, st3, tF] = extract_spikes(rez); -rez = template_learning(rez, tF, st3); -[rez, st3, tF] = trackAndSort(rez); -rez = final_clustering(rez, tF, st3); -rez = find_merges(rez, 1); + %% gridsearch section + % only try to use gridsearch values if ops_input_params is a struct and fields are present + if isa(ops_input_params, 'struct') && ~isempty(fieldnames(ops_input_params)) + % Combine input ops into the existing ops struct + fields = fieldnames(ops_input_params); + for iField = 1:size(fields, 1) + ops.(fields{iField}) = ops_input_params.(fields{iField}); + end + % ops.NT = ops.nt0 * 32 + ops.ntbuff; % 2*87040 % 1024*(32+ops.ntbuff); + end + %% end gridsearch section -% write to Phy -fprintf('Saving results to Phy \n') -rezToPhy2(rez, myo_sorted_dir); -save(fullfile(myo_sorted_dir, '/ops.mat'), 'ops') + if trange(2) == 0 + ops.trange = [0 Inf]; + else + ops.trange = trange; + end -% delete(ops.fproc); + rez = preprocessDataSub(ops); + ops.channelDelays = rez.ops.channelDelays; + rez = datashift2(rez, 1); + [rez, st3, tF] = extract_spikes(rez); + %%% plots + % figure(5); + % plot(st3(:, 1), '.') + % title('Spike times versus spike ID') + % figure(6); + % plot(st3(:, 2), '.') + % title('Upsampled grid location of best template match spike ID') + % figure(7); + % plot(st3(:, 3), '.') + % title('Amplitude of template match for each spike ID') + % figure(8); hold on; + % plot(st3(:, 4), 'g.') + % for kSpatialDecay = 1:6 + % less_than_idx = find(st3(:, 4) < 6 * kSpatialDecay); + % more_than_idx = find(st3(:, 4) >= 6 * (kSpatialDecay - 1)); + % idx = intersect(less_than_idx, more_than_idx); + % bit_idx = bitand(st3(:, 4) < 6 * kSpatialDecay, st3(:, 4) >= 6 * (kSpatialDecay - 1)); + % plot(idx, st3(bit_idx, 4), '.') + % end + % title('Prototype templates for each spatial decay value (1:6:30) resulting in each best match spike ID') + % figure(9); + % plot(st3(:, 5), '.') + % title('Amplitude of template match for each spike ID (Duplicate of st3(:,3))') + % figure(10); + % plot(st3(:, 6), '.') + % title('Batch ID versus spike ID') + % figure(11); + % for iTemp = 1:size(tF, 2) + % subplot(size(tF, 2), 1, iTemp) + % plot(squeeze(tF(:, iTemp, :)), '.') + % end + %%% end plots + [rez, ~] = template_learning(rez, tF, st3); + [rez, st3, tF] = trackAndSort(rez); + % plot_templates_on_raw_data_fast(rez, st3); + rez = final_clustering(rez, tF, st3); + rez = find_merges(rez, 1); -quit; + % write to Phy + fprintf('Saving results to Phy \n') + rezToPhy2(rez, myo_sorted_dir); + save(fullfile(myo_sorted_dir, '/ops.mat'), '-struct', 'ops'); + + ops + + quit; +end diff --git a/sorting/Kilosort_run_myo_3_czuba.m b/sorting/Kilosort_run_myo_3_czuba.m new file mode 100644 index 00000000..651042d0 --- /dev/null +++ b/sorting/Kilosort_run_myo_3_czuba.m @@ -0,0 +1,176 @@ +function rez = Kilosort_run_myo_3_czuba(ops_input_params, worker_id, worker_dir) + script_dir = pwd; % get directory where repo exists + load(fullfile(script_dir, '/tmp/config.mat')) + load(fullfile(myo_sorted_dir, 'brokenChan.mat')) + + if num_KS_jobs > 1 + myo_sorted_dir = [myo_sorted_dir num2str(worker_id)]; + else + dbstop if error % stop if error, if only one job + end + + % get and set channel map + if ~isempty(brokenChan) && remove_bad_myo_chans(1) ~= false + chanMapFile = fullfile(myo_sorted_dir, 'chanMapAdjusted.mat'); + else + chanMapFile = myo_chan_map_file; + end + disp(['Using this channel map: ' chanMapFile]) + + % set paths + try + restoredefaultpath + end + addpath(genpath([script_dir '/sorting/Kilosort-3.0'])) + addpath(genpath([script_dir '/sorting/npy-matlab'])) + + % phyDir = 'sorted-czuba'; + % rootZ = [neuropixel_folder '/']; + % rootH = [rootZ phyDir '/']; + % mkdir(rootH); + + run([script_dir '/sorting/Kilosort_config_czuba.m']); + % ops.fbinary = fullfile(neuropixel); + % ops.fproc = fullfile(rootH, 'proc.dat'); + % ops.chanMap = fullfile(chanMapFile); + % ops.NchanTOT = 385; + % ops.saveDir = rootH; + ops.saveDir = myo_sorted_dir; % set directory for writes + ops.fbinary = fullfile(ops.saveDir, 'data.bin'); + ops.fproc = fullfile(ops.saveDir, 'proc.dat'); + ops.brokenChan = fullfile(ops.saveDir, 'brokenChan.mat'); + ops.chanMap = fullfile(chanMapFile); + ops.nt0 = 61; + ops.ntbuff = 1024; %ceil(bufferSec * ops.fs / 64) * 64; % ceil(batchSec/4*ops.fs/64)*64; % (def=64) + ops.NT = 2048 * 32 + ops.ntbuff; %ceil(batchSec * ops.fs / 32) * 32; % convert to 32 count increments of samples + ops.sigmaMask = Inf; % we don't want a distance-dependant decay + ops.nEig = double(num_KS_components); % rank of svd for templates, % keep same as nPCs to avoid error + ops.nPCs = ops.nEig; % how many PCs to project the spikes into (also used as number of template prototypes) + ops.NchanTOT = double(max(num_chans - length(brokenChan), ops.nEig)); + ops.Th = [10 4]; % threshold crossings for pre-clustering (in PCA projection space) + ops.spkTh = [-6]; % spike threshold in standard deviations (-6 default) (used in isolated_peaks_new/buffered and spikedetector3PC.cu) + ops.nfilt_factor = 12; % max number of clusters per good channel in a batch (even temporary ones) + ops.nblocks = 0; + ops.nt0min = ceil(ops.nt0 / 2); % peak of template match will be this many points away from beginning + ops.nskip = 1; % how many batches to skip for determining spike PCs + ops.nSkipCov = 1; % compute whitening matrix and prototype templates using every N-th batch + ops.lam = 10; % amplitude penalty (0 means not used, 10 is average, 50 is a lot) + ops.CAR = 0; % whether to perform CAR + ops.loc_range = [5 4]; % [timepoints channels], area to detect peaks; plus/minus for both time and channel. Doing abs() of data during peak isolation, so using 4 instead of default 5. Only 1 channel to avoid elimination of waves + ops.long_range = [ops.nt0min - 1 6]; % [timepoints channels], range within to use only the largest peak + ops.fig = 1; % whether to plot figures + ops.recordings = recordings; + ops.momentum = [20 400]; + ops.clipMin = 200; % clip template updating to a minimum number of contributing spikes + ops.clipMinFit = .8; + %batchSec = 10; % number of seconds in each batch (TBC: 8:10 seems good for 1-2 hr files and/or 32 channels) + %bufferSec = 2; % define number of seconds of data for buffer + % sample from batches more sparsely (in certain circumstances/analyses) + % batchSkips = ceil(60 / batchSec); % do high-level assessments at least once every minute of data + + %% gridsearch section + % will override the above ops struct values, if specified in Kilosort_gridsearch_config.py + + % make sure ops_input_params is a struct and fields are present + if isa(ops_input_params, 'struct') && ~isempty(fieldnames(ops_input_params)) + % Combine input ops into the existing ops struct + fields = fieldnames(ops_input_params); + for iField = 1:size(fields, 1) + ops.(fields{iField}) = ops_input_params.(fields{iField}); + end + % ops.NT = ops.nt0 * 32 + ops.ntbuff; % 2*87040 % 1024*(32+ops.ntbuff); + end + %% end gridsearch section + + disp(['Using ' ops.fbinary]) + + if trange(2) == 0 + ops.trange = [0 Inf]; + else + ops.trange = trange; + end + + % create parallel pool for all downstream parallel processing + pc = parcluster('local'); + pc.JobStorageLocation = worker_dir; + % ensure the number of processes across all workers does not exceed number of CPU cores + % num_processes = 2*round(feature('numcores')/num_KS_jobs); + % poolobj = parpool(pc, num_processes); + poolobj = parpool(pc); % let matlab decide how many workers to use + % ensure all parallel workers queues are cleared in the event of an error + cleanup_worker_obj = onCleanup(@() cleanup_worker(poolobj)); + + rez = preprocessDataSub(ops); + ops.channelDelays = rez.ops.channelDelays; + rez = datashift2(rez, 1); + [rez, st3, tF] = extract_spikes(rez); + if ops.fig + %%% plots + figure(5); + plot(st3(:, 1), '.') + title('Spike times versus spike ID') + figure(6); + plot(st3(:, 2), '.') + title('Upsampled grid location of best template match spike ID') + figure(7); + plot(st3(:, 3), '.') + title('Amplitude of template match for each spike ID') + figure(8); hold on; + plot(st3(:, 4), 'g.') + for kSpatialDecay = 1:6 + less_than_idx = find(st3(:, 4) < ops.nPCs * kSpatialDecay); + more_than_idx = find(st3(:, 4) >= ops.nPCs * (kSpatialDecay - 1)); + idx = intersect(less_than_idx, more_than_idx); + bit_idx = bitand(st3(:, 4) < ops.nPCs * kSpatialDecay, st3(:, 4) >= ops.nPCs * (kSpatialDecay - 1)); + plot(idx, st3(bit_idx, 4), '.') + end + title('Prototype templates for each spatial decay value (1:6:30) resulting in each best match spike ID') + figure(9); + plot(st3(:, 5), '.') + title('Amplitude of template match for each spike ID (Duplicate of st3(:,3))') + figure(10); + plot(st3(:, 6), '.') + title('Batch ID versus spike ID') + figure(11); + for iPC = 1:size(tF, 2) + subplot(size(tF, 2), 1, iPC) + plot(squeeze(tF(:, iPC, :)), '.') + end + title('PC Weights for each Spike Example, Colored by Channel') + xlabel('Spike Examples') + ylabel('Principal Component Weight') + %%% end plots + end + [rez, ~] = template_learning(rez, tF, st3); + [rez, st3, tF] = trackAndSort(rez); + % keyboard + % plot_templates_on_raw_data_fast(rez, st3); + rez = final_clustering(rez, tF, st3); + rez = find_merges(rez, 1); + + % write to Phy + disp(['Saving sorting results to Phy in', ops.saveDir]) + rezToPhy2(rez, ops.saveDir); + + disp(['Saving rez and ops structs to', ops.saveDir]) + ops % show final ops struct in command window + rez % show final rez struct in command window + + % save variables as full struct, for MATLAB + save(fullfile(ops.saveDir, '/ops_struct.mat'), 'ops'); + save(fullfile(ops.saveDir, '/rez_struct.mat'), 'rez'); + + % save variables without struct, for python + save(fullfile(ops.saveDir, '/ops.mat'), '-struct', 'ops'); + rez.ops = []; rez.temp = []; % remove substructs from rez struct before saving + save(fullfile(ops.saveDir, '/rez.mat'), '-struct', 'rez'); +end + +% cleanup function to ensure all parallel workers queues are cleared +function cleanup_worker(poolobj) + % check if parallel pool processes exist + if ~isempty(poolobj) + delete(poolobj) + end + quit; % exit matlab to return to python +end diff --git a/sorting/myomatrix/concatenate_myo_data.m b/sorting/myomatrix/concatenate_myo_data.m index b3dd09c2..877c96ad 100644 --- a/sorting/myomatrix/concatenate_myo_data.m +++ b/sorting/myomatrix/concatenate_myo_data.m @@ -1,48 +1,66 @@ -function concatenate_myo_data(myomatrix_folder) +function concatenate_myo_data(myomatrix_folder, recordings_to_concatenate) listing = struct2cell(dir(myomatrix_folder)); subdir = listing(1, :); - recordNodeFiles = []; + recordNodeFolders = []; dbstop if error % Determine Record Node folder for i = 1:length(subdir) if (contains(subdir(i), 'Record Node')) - recordNodeFiles = [recordNodeFiles, subdir(i)]; + recordNodeFolders = [recordNodeFolders, subdir(i)]; end end - for i = 1:length(recordNodeFiles) - currNode = strcat(myomatrix_folder, '/', recordNodeFiles{i}); + if length(recordNodeFolders) > 1 + error("Multiple 'Record Node' folders found in the myomatrix folder. Please remove all but one.") + end + for i = 1:length(recordNodeFolders) + currNode = strcat(myomatrix_folder, '/', recordNodeFolders{i}); folders = struct2cell(dir(currNode)); subdir = folders(1, :); - experimentFiles = []; + experimentFolders = []; % Determine experiment folders for j = 1:length(subdir) if (startsWith(subdir(j), 'experiment')) - experimentFiles = [experimentFiles, subdir(j)]; + experimentFolders = [experimentFolders, subdir(j)]; end end - for k = 1:length(experimentFiles) - currExp = strcat(currNode, '/', experimentFiles{k}); + if length(experimentFolders) > 1 + error("Multiple 'experiment' folders found in the Record Node folder. Please remove all but one.") + end + for k = 1:length(experimentFolders) + currExp = strcat(currNode, '/', experimentFolders{k}); d = dir(currExp); - d = d(~ismember({d.name}, {'.', '..'})); + d = d(~ismember({d.name}, {'.', '..', 'concatenated_data'})); folders = struct2cell(d); subdir = folders(1, :); - disp("Concatenating: ") - disp(subdir) - recordingFiles = []; + disp("Concatenating recordings: ") + disp(recordings_to_concatenate{1}) + recordingFolders = []; % Determine recording folders - for l = 1:length(subdir) - if (startsWith(subdir(l), 'recording') && ~contains(subdir(l), '99')) - recordingFiles = [recordingFiles, subdir(l)]; + if class(recordings_to_concatenate{1}) == "char" && recordings_to_concatenate{1} == "all" + for iRec = 1:length(subdir) + if subdir(iRec)==strcat('recording', iRec) + recordingFolders = [recordingFolders, subdir(iRec)]; + end end + elseif class(recordings_to_concatenate{1}) == "double" + rep_str=repmat('recording',length(recordings_to_concatenate),1); + recording_str_array = string(cellstr(strcat(rep_str, num2str(recordings_to_concatenate{1}')))); + for iRec = 1:length(subdir) + if ismember(subdir(iRec),recording_str_array) + recordingFolders = [recordingFolders, subdir(iRec)]; + end + end + else + error("Recordings to concatenate must be either 'all' or a double array.") end continuousFiles = []; - for m = 1:length(recordingFiles) - currRecording = strcat(currExp, '/', recordingFiles{m}, '/continuous', '/Acquisition_Board-100.Rhythm Data'); + for m = 1:length(recordingFolders) + currRecording = strcat(currExp, '/', recordingFolders{m}, '/continuous', '/Acquisition_Board-100.Rhythm Data'); if ~isfolder(currRecording) - currRecording = strcat(currExp, '/', recordingFiles{m}, '/continuous', '/Rhythm_FPGA-100.0'); + currRecording = strcat(currExp, '/', recordingFolders{m}, '/continuous', '/Rhythm_FPGA-100.0'); if ~isfolder(currRecording) error('Folder %s does not exist.', currRecording) end @@ -68,8 +86,8 @@ function concatenate_myo_data(myomatrix_folder) % open continuous.dat file for writing rhythmFolderNameCellArray = split(currRecording, '/'); rhythmFolderName = string(rhythmFolderNameCellArray(end)); % get last array element - concatenated_data_dir = strcat(currExp, '/concatenated_data/'); - continuous_folder = strcat(concatenated_data_dir, 'continuous/', rhythmFolderName); + concatenated_data_dir = strcat(currExp, '/concatenated_data/', join(string(recordings_to_concatenate{1}),',')); + continuous_folder = strcat(concatenated_data_dir, '/continuous/', rhythmFolderName); [~, ~, ~] = mkdir(continuous_folder); fid2 = fopen(strcat(continuous_folder, '/continuous.dat'), 'w'); end @@ -78,9 +96,9 @@ function concatenate_myo_data(myomatrix_folder) fwrite(fid2, outputDat, 'int16'); fclose(fid2); % copy a structure.oebin into concatenated_data folder - lastRecordingFolder = strcat(currExp, '/', recordingFiles{m}); - copyfile(strcat(lastRecordingFolder, '/structure.oebin'), concatenated_data_dir) - disp("Data from " + length(recordingFiles) + " files concatenated together"); - quit + lastRecordingFolder = strcat(currExp, '/', recordingFolders{m}); + copyfile(strcat(lastRecordingFolder, '/structure.oebin'), strcat(concatenated_data_dir,'/structure.oebin')); + disp("Data from " + length(recordingFolders) + " files concatenated together"); end + quit end diff --git a/sorting/myomatrix/myomatrix_binary.m b/sorting/myomatrix/myomatrix_binary.m index 759b4e90..504c180b 100644 --- a/sorting/myomatrix/myomatrix_binary.m +++ b/sorting/myomatrix/myomatrix_binary.m @@ -10,12 +10,11 @@ chanList = chans(1):chans(2); disp(['Starting with these channels: ' num2str(chanList)]) -chanMapFile = myo_chan_map_file; -disp(['Using this channel map: ' chanMapFile]) +disp(['Using this channel map: ' myo_chan_map_file]) dataChan = chanList; -if not(isfolder([myomatrix '/sorted' num2str(myomatrix_num) '/'])) - mkdir([myomatrix '/sorted' num2str(myomatrix_num) '/']); +if not(isfolder([myo_sorted_dir '/'])) + mkdir([myo_sorted_dir '/']); end % Check if we're dealing with .dat or .continuous @@ -45,37 +44,56 @@ ops.trange = trange * myo_data_sampling_rate + 1; end data = tempdata.Data.Data(1).mapped(dataChan, ops.trange(1):ops.trange(2))'; - analogData = tempdata.Data.Data(1).mapped(sync_chan, ops.trange(1):ops.trange(2))'; - analogData(analogData < 10000) = 0; - analogData(analogData >= 10000) = 1; + try + analogData = tempdata.Data.Data(1).mapped(sync_chan, ops.trange(1):ops.trange(2))'; + catch ME % to avoid "Index in position 1 exceeds array bounds (must not exceed XX)." + if strcmp(ME.identifier, 'MATLAB:badsubscript') + disp("No sync channel found, cannot save sync data") + analogData = []; + else + rethrow(ME) + end + end + if ~isempty(analogData) + analogData(analogData < 10000) = 0; + analogData(analogData >= 10000) = 1; + end clear tempdata end if length(dataChan) == 32 data = data(:, channelRemap); end +if ~isempty(analogData) + analogData(analogData > 5) = 5; + sync = logical(round(analogData / max(analogData))); + clear analogData -analogData(analogData > 5) = 5; -sync = logical(round(analogData / max(analogData))); -clear analogData - -save([myomatrix '/sync'], 'sync') -clear sync -disp('Saved sync data') + save([myomatrix '/sync'], 'sync') + clear sync + disp('Saved sync data') +end disp(['Total recording time: ' num2str(size(data, 1) / myo_data_sampling_rate / 60) ' minutes']) clf -S = zeros(size(data, 2), 2); +S = zeros(size(data, 2), 3); bipolarThresh = 90; unipolarThresh = 120; lowThresh = 0.1; -bipolar = length(chanList) == 16; -for q = 1:2 +% bipolar = length(chanList) == 16; +% when q is 1, we will compute count the number of spikes in the channel and compare to a threshold +% when q is 2, we will compute the std of the low freq noise in the channel +% when q is 3, we will compute the SNR of the channel +for q = 1:4 if q == 1 [b, a] = butter(2, [250 4400] / (myo_data_sampling_rate / 2), 'bandpass'); elseif q == 2 - [b, a] = butter(2, [5 70] / (myo_data_sampling_rate / 2), 'bandpass'); + [b, a] = butter(2, [5 100] / (myo_data_sampling_rate / 2), 'bandpass'); + elseif q == 3 + [b, a] = butter(2, 10000 / (myo_data_sampling_rate / 2), 'high'); + elseif q == 4 + [b, a] = butter(2, [300 1000] / (myo_data_sampling_rate / 2), 'bandpass'); end useSeconds = 600; if size(data, 1) < useSeconds * 2 * myo_data_sampling_rate @@ -83,128 +101,241 @@ end tRange = size(data, 1) - round(size(data, 1) / 2) - round(myo_data_sampling_rate * useSeconds / 2):size(data, 1) ... - round(size(data, 1) / 2) + round(myo_data_sampling_rate * useSeconds / 2); + data_norm = zeros(length(tRange), size(data, 2), 'single'); data_filt = zeros(length(tRange), size(data, 2), 'single'); for i = 1:size(data, 2) - data_filt(:, i) = single(filtfilt(b, a, double(data(tRange, i)))); - end - - if q == 2 - S(:, q) = std(data_filt, [], 1); - else - data_norm = data_filt ./ repmat(std(data_filt, [], 1), [size(data_filt, 1) 1]); - spk = sum(data_norm < -7, 1); - S(:, q) = spk / size(data_norm, 1) * myo_data_sampling_rate; + % standardize this data channel before filtering, but make sure not to divide by zero + chan_std = std(single(data(tRange, i))); + if chan_std == 0 + data_norm(:, i) = single(data(tRange, i)); + else + data_norm(:, i) = single(data(tRange, i)) ./ chan_std; + end + % data_norm(:, i) = single(data(tRange, i)) ./ std(single(data(tRange, i))); + % filter this data channel + data_filt(:, i) = single(filtfilt(b, a, double(data_norm(:, i)))); end - subplot(1, 2, q) if q == 1 - title('Filtered Signal Snippet (250-4400Hz)') - else - title('Filtered Noise Snippet (5-70Hz)') - end - hold on - for i = 1:size(data, 2) - cmap = [0 0 0]; - if q == 1 - if S(i, 1) < lowThresh - cmap = [1 0.2 0.2]; - end + % normalize channels by std + data_filt_norm = data_filt ./ repmat(std(data_filt, [], 1), [size(data_filt, 1) 1]); + spk = sum(data_filt_norm < -7, 1); % check for spikes crossing 7 std below mean + S(:, q) = spk / size(data_filt_norm, 1) * myo_data_sampling_rate; + elseif q == 2 + S(:, q) = std(data_filt, [], 1); % get the std of the low freq noise + % data_filt_norm = data_filt ./ repmat(S(:, q)', [size(data_filt, 1) 1]); % standardize + low_band_power = rms(data_filt, 1) .^ 2; + elseif q == 3 + S(:, q) = std(data_filt, [], 1); % get the std of the high freq noise + % data_filt_norm = data_filt ./ repmat(S(:, q)', [size(data_filt, 1) 1]); % standardize + high_band_power = rms(data_filt, 1) .^ 2; + elseif q == 4 + % data_filt_norm = data_filt ./ repmat(std(data_filt, [], 1), [size(data_filt, 1) 1]); % standardize + spike_band_power = rms(data_filt, 1) .^ 2; + SNR = spike_band_power ./ (low_band_power + high_band_power); + % replace any NaNs with 0 + SNR(isnan(SNR)) = 0; + [~, idx] = sort(SNR, 'ascend'); + mean_SNR = mean(SNR); + std_SNR = std(SNR); + median_SNR = median(SNR); + % get a MAD value for each channel + MAD = median(abs(data_filt-mean(data_filt, 1)), 1); + Gaussian_STDs = MAD / 0.6745; + disp("Gaussian STDs: " + num2str(Gaussian_STDs)) + if isa(remove_bad_myo_chans, "char") + rejection_criteria = remove_bad_myo_chans; else - if (bipolar && S(i, 2) > bipolarThresh) || (~bipolar && S(i, 2) > unipolarThresh) - cmap = [1 0.2 0.2]; + rejection_criteria = 'median'; + end + disp("Using " + rejection_criteria + " threshold for SNR rejection criteria") + + % check by what criteria we should reject channels + if strcmp(rejection_criteria, 'median') + % reject channels with SNR < median + SNR_reject_chans = SNR < median_SNR; + elseif strcmp(rejection_criteria, 'mean') + % reject channels with SNR < mean + SNR_reject_chans = SNR < mean_SNR; + elseif strcmp(rejection_criteria, 'mean-1std') + % reject channels with SNR < mean - std + SNR_reject_chans = SNR < mean_SNR - std_SNR; + elseif startsWith(rejection_criteria, 'percentile') + % ensure that the percentile is numeric and between 0 and 100 + percentile = str2double(rejection_criteria(11:end)); + if isnan(percentile) || percentile < 0 || percentile > 100 + error("Error with 'remove_bad_myo_chans' setting in config.yaml. Numeric value after 'percentile' must be between 0 and 100") + end + % reject channels with SNR < Nth percentile + percentile_SNR = prctile(SNR, percentile); + SNR_reject_chans = SNR < percentile_SNR; + elseif startsWith(rejection_criteria, 'lowest') + % ensure that the number of channels to reject is numeric and less than the number of channels + N_reject = str2double(rejection_criteria(7:end)); + if isnan(N_reject) || N_reject < 0 || N_reject > length(chanList) + error("Error with 'remove_bad_myo_chans' setting in config.yaml. Numeric value after 'lowest' must be between 0 and " + length(chanList)) end + % reject N_reject lowest SNR channels + SNR_reject_chans = idx(1:N_reject); end - plot(data_filt(:, i) + i * 1600, 'Color', cmap) + + % [~, idx] = sort(SNR, 'ascend'); + % idx = idx(1:floor(length(idx) / 2)); + % bitmask = zeros(length(chanList), 1); + % bitmask(idx) = 1; + % SNR_reject_chans = chanList(bitmask == 1); + disp("SNRs: " + num2str(SNR)) + disp("Mean +/- Std. SNR: " + num2str(mean_SNR) + " +/- " + num2str(std_SNR)) + disp("Median SNR: " + num2str(median_SNR)) + disp("Channels with rejectable SNRs: " + num2str(SNR_reject_chans)) end - set(gca, 'YTick', (1:size(data, 2)) * 1600, 'YTickLabels', 1:size(data, 2)) - axis([1 size(data_filt, 1) 0 (size(data, 2) + 1) * 1600]) -end -print([myomatrix '/sorted' num2str(myomatrix_num) '/brokenChan.png'], '-dpng') -S -if length(chanList) == 16 - brokenChan = int64(find(S(:, 2) > bipolarThresh | S(:, 1) < lowThresh)); -else - brokenChan = int64(find(S(:, 2) > unipolarThresh | S(:, 1) < lowThresh)); + % subplot(1, 4, q) + % if q == 1 + % title('Filtered Signal Snippet (250-4400Hz)') + % elseif q == 2 + % title('Filtered Noise Snippet (5-70Hz)') + % end + % hold on + % for i = 1:size(data, 2) + % cmap = [0 0 0]; + % if q == 1 + % if S(i, 1) < lowThresh + % cmap = [1 0.2 0.2]; + % end + % elseif q == 2 + % if (bipolar && S(i, 2) > bipolarThresh) || (~bipolar && S(i, 2) > unipolarThresh) + % cmap = [1 0.2 0.2]; + % end + % end + % plot(data_filt(:, i) + i * 1600, 'Color', cmap) + % end + % set(gca, 'YTick', (1:size(data, 2)) * 1600, 'YTickLabels', 1:size(data, 2)) + % axis([1 size(data_filt, 1) 0 (size(data, 2) + 1) * 1600]) end -disp(['Automatically detected broken/inactive channels are: ' num2str(brokenChan')]) +print([myo_sorted_dir '/brokenChan.png'], '-dpng') + +%if length(chanList) == 16 +% % check for broken channels if meeting various criteria, including: high std, low spike rate, low SNR. Eliminate if any true +% brokenChan = int64(union(find(S(:, 2) > bipolarThresh | S(:, 1) < lowThresh), SNR_reject_chans)); %S(:, 3) > bipolarThresh +%else +% brokenChan = int64(union(find(S(:, 2) > unipolarThresh | S(:, 1) < lowThresh), SNR_reject_chans)); %S(:, 3) > unipolarThresh +%end +brokenChan = SNR_reject_chans'; +disp(['Automatically detected rejectable channels are: ' num2str(brokenChan')]) % now actually remove the detected broken channels if True % if a list of broken channels is provided, use that instead % if false, just continue -if isa(remove_bad_myo_chans(1), 'logical') +if isa(remove_bad_myo_chans(1), 'logical') || isa(remove_bad_myo_chans, 'char') if remove_bad_myo_chans(1) == false - if length(brokenChan) > 0 - disp('Broken/inactive channels detected, but not removing them, because remove_bad_myo_chans is false') - elseif length(brokenChan) == 0 - disp('No broken/inactive channels detected, not removing any, because remove_bad_myo_chans is false') - end - disp(['Keeping channel list: ' num2str(chanList)]) - elseif remove_bad_myo_chans(1) == true - disp('Removing automatically detected broken/inactive channels') + brokenChan = []; + disp('Not removing any broken/noisy channels, because remove_bad_myo_chans is false') + % disp(['Keeping channel list: ' num2str(chanList)]) + elseif remove_bad_myo_chans(1) == true || isa(remove_bad_myo_chans, 'char') data(:, brokenChan) = []; chanList(brokenChan) = []; + disp(['Just removed automatically detected broken/noisy channels: ' num2str(brokenChan')]) disp(['New channel list is: ' num2str(chanList)]) end elseif isa(remove_bad_myo_chans, 'integer') - brokenChan = remove_bad_myo_chans; - disp(['Removing manually provided broken/inactive channels: ' num2str(brokenChan)]) + brokenChan = remove_bad_myo_chans; % overwrite brokenChan with manually provided list data(:, brokenChan) = []; chanList(brokenChan) = []; + disp(['Just removed manually provided broken/noisy channels: ' num2str(brokenChan)]) disp(['New channel list is: ' num2str(chanList)]) else - error('remove_bad_myo_chans must be a boolean or an integer list of broken channels') + error('remove_bad_myo_chans must be a boolean, string with SNR rejection method, or an integer list of channels to remove') end -save([myomatrix '/sorted' num2str(myomatrix_num) '/chanList.mat'], 'chanList') -save([myomatrix '/sorted' num2str(myomatrix_num) '/brokenChan.mat'], 'brokenChan'); -clear data_filt data_norm +save([myo_sorted_dir '/chanList.mat'], 'chanList') +save([myo_sorted_dir '/brokenChan.mat'], 'brokenChan'); -fileID = fopen([myomatrix '/sorted' num2str(myomatrix_num) '/data.bin'], 'w'); -if true - disp("Filtering raw data with passband:") - disp(strcat(string(myo_data_passband(1)), "-", string(myo_data_passband(2)), " Hz")) - mean_data = mean(data, 1); - [b, a] = butter(4, myo_data_passband / (myo_data_sampling_rate / 2), 'bandpass'); - intervals = round(linspace(1, size(data, 1), round(size(data, 1) / (myo_data_sampling_rate * 5)))); - buffer = 128; - for t = 1:length(intervals) - 1 - preBuff = buffer; postBuff = buffer; - if t == 1 - preBuff = 0; - elseif t == length(intervals) - 1 - postBuff = 0; - end - tRange = intervals(t) - preBuff:intervals(t + 1) + postBuff; - fdata = double(data(tRange, :)) - mean_data; - fdata = fdata - median(fdata, 2); - fdata = filtfilt(b, a, fdata); - fdata = fdata(preBuff + 1:end - postBuff - 1, :); - % fdata(:, brokenChan) = randn(size(fdata(:, brokenChan))) * 5; - fwrite(fileID, int16(fdata'), 'int16'); - end +% load and modify channel map variables to remove broken channel elements, if desired +if ~isempty(brokenChan) && remove_bad_myo_chans(1) ~= false + load(myo_chan_map_file) + % if size(data, 2) >= num_KS_components + % chanMap(brokenChan) = []; % take off end to save indexing + % chanMap0ind(brokenChan) = []; % take off end to save indexing + % connected(brokenChan) = []; + % kcoords(brokenChan) = []; + % xcoords(brokenChan) = []; + % ycoords(brokenChan) = []; + % else + numDummy = max(0, num_KS_components - size(data, 2)); % make sure it's not negative + dummyData = zeros(size(data, 1), numDummy, 'int16'); + data = [data dummyData]; % add dummy channels to make size larger than num_KS_components + chanMap = 1:size(data, 2); + chanMap0ind = chanMap - 1; + connected = true(size(data, 2), 1); + kcoords = ones(size(data, 2), 1); + xcoords = zeros(size(data, 2), 1); + ycoords = (size(data, 2):-1:1)'; + % end + disp('Broken channels were just removed from that channel map') + save(fullfile(myo_sorted_dir, 'chanMapAdjusted.mat'), 'chanMap', 'connected', 'xcoords', ... + 'ycoords', 'kcoords', 'chanMap0ind', 'fs', 'name', 'numDummy', 'Gaussian_STDs') else - data(:, brokenChan) = randn(size(data(:, brokenChan))) * 5; - fwrite(fileID, int16(data'), 'int16'); + copyfile(myo_chan_map_file, fullfile(myo_sorted_dir, 'chanMapAdjusted.mat')) + % add numDummy to chanMapAdjusted.mat + load(fullfile(myo_sorted_dir, 'chanMapAdjusted.mat')) + numDummy = 0; + save(fullfile(myo_sorted_dir, 'chanMapAdjusted.mat'), 'chanMap', 'connected', 'xcoords', ... + 'ycoords', 'kcoords', 'chanMap0ind', 'fs', 'name', 'numDummy', 'Gaussian_STDs') end -fclose(fileID); -if false - % Generate "Bulk EMG" dataset - notBroken = 1:size(data, 2); - notBroken(brokenChan) = []; - if length(dataChan) == 32 - bottomHalf = [9:16 25:32]; - topHalf = [1:8 17:24]; - bottomHalf(ismember(bottomHalf, brokenChan)) = []; - topHalf(ismember(topHalf, brokenChan)) = []; - bEMG = int16(mean(data(:, bottomHalf), 2)) - int16(mean(data(:, topHalf), 2)); - else - bEMG = int16(mean(data(:, notBroken), 2)); +clear data_filt data_norm + +fileID = fopen([myo_sorted_dir '/data.bin'], 'w'); +% if true +disp("Filtering raw data with passband:") +disp(strcat(string(myo_data_passband(1)), "-", string(myo_data_passband(2)), " Hz")) +mean_data = mean(data, 1); +[b, a] = butter(4, myo_data_passband / (myo_data_sampling_rate / 2), 'bandpass'); +intervals = round(linspace(1, size(data, 1), round(size(data, 1) / (myo_data_sampling_rate * 5)))); +if numDummy > 0 + chanIdxsToFilter = 1:num_KS_components-numDummy; +else + chanIdxsToFilter = 1:size(data, 2); +end +buffer = 128; +% now write the data to binary file in chunks of 5 seconds, but exclude dummy channels +for t = 1:length(intervals) - 1 + preBuff = buffer; postBuff = buffer; + if t == 1 + preBuff = 0; + elseif t == length(intervals) - 1 + postBuff = 0; end - save([myomatrix '/sorted' num2str(myomatrix_num) '/bulkEMG'], 'bEMG', 'notBroken', 'dataChan') - clear bEMG - disp('Saved generated bulk EMG') + tRange = intervals(t) - preBuff:intervals(t + 1) + postBuff; + fdata = double(data(tRange, :)) - mean_data; + fdata(:, chanIdxsToFilter) = fdata(:, chanIdxsToFilter) - median(fdata(:, chanIdxsToFilter), 1); + fdata = filtfilt(b, a, fdata); + fdata = fdata(preBuff + 1:end - postBuff - 1, :); + % fdata(:, brokenChan) = randn(size(fdata(:, brokenChan))) * 5; + fwrite(fileID, int16(fdata'), 'int16'); end +% else +% data(:, brokenChan) = randn(size(data(:, brokenChan))) * 5; +% fwrite(fileID, int16(data'), 'int16'); +% end +fclose(fileID); +% if false +% % Generate "Bulk EMG" dataset +% notBroken = 1:size(data, 2); +% notBroken(brokenChan) = []; +% if length(dataChan) == 32 +% bottomHalf = [9:16 25:32]; +% topHalf = [1:8 17:24]; +% bottomHalf(ismember(bottomHalf, brokenChan)) = []; +% topHalf(ismember(topHalf, brokenChan)) = []; +% bEMG = int16(mean(data(:, bottomHalf), 2)) - int16(mean(data(:, topHalf), 2)); +% else +% bEMG = int16(mean(data(:, notBroken), 2)); +% end +% save([myo_sorted_dir '/bulkEMG'], 'bEMG', 'notBroken', 'dataChan') +% clear bEMG +% disp('Saved generated bulk EMG') +% end disp('Saved myomatrix data binary') quit diff --git a/sorting/resorter/myomatrix_call.m b/sorting/resorter/myomatrix_call.m index 079305df..e0b6f56b 100644 --- a/sorting/resorter/myomatrix_call.m +++ b/sorting/resorter/myomatrix_call.m @@ -2,9 +2,13 @@ load(fullfile(script_dir, '/tmp/config.mat')) load(fullfile(myo_sorted_dir, 'brokenChan.mat')) +% % set GPU to use +% disp(strcat("Setting GPU device to use: ", num2str(GPU_to_use))) +% gpuDevice(GPU_to_use); + % load channel map with broken channels removed if chosen by user if length(brokenChan) > 0 && remove_bad_myo_chans(1) ~= false - load(fullfile(myo_sorted_dir, 'chanMap_minus_brokenChans.mat')) + load(fullfile(myo_sorted_dir, 'chanMapAdjusted.mat')) else load(myo_chan_map_file) end @@ -19,8 +23,8 @@ params.SNRThresh = 2.0; params.corrThresh = 0.9; % minimum correlation to be considered as originating from one cluster params.consistencyThresh = 0.6; % minimum consistency to be considered as originating from one cluster -params.spikeCountLim = 10; % minimum spike count to be included in output -params.refractoryLim = 1; % spikes below this refractory time limit will be considered duplicates +params.spikeCountLim = 100; % minimum spike count to be included in output +params.refractoryLim = 0.5; % spikes below this refractory time limit will be considered duplicates % make sure a sorting exists if isfile([myo_sorted_dir '/spike_times.npy']) diff --git a/sorting/resorter/resorter.m b/sorting/resorter/resorter.m index 7b9226c3..1f13e93f 100644 --- a/sorting/resorter/resorter.m +++ b/sorting/resorter/resorter.m @@ -271,7 +271,7 @@ function resorter(params) % Remove clusters that don't meet inclusion criteria mdata_orig = mdata; if keepGoing % save intermediate merges - save_dir_for_merges = ['/intermediate_merge' num2str(loopCount)]; + % save_dir_for_merges = ['/intermediate_merge' num2str(loopCount)]; [T, ascending_idxs] = sort(T); % sort to make times monotonic I = I(ascending_idxs); C = sort(unique(I)); @@ -294,31 +294,37 @@ function resorter(params) consistency.wave = consistency.wave(:, :, :, saveUnits); consistency.channel = consistency.channel(:, saveUnits); end - if not(isfolder([params.kiloDir save_dir_for_merges])) - mkdir([params.kiloDir save_dir_for_merges]) - end templates = permute(mdata_orig, [3 1 2]); % now it's nTemplates x nSamples x nChannels templatesInds = repmat([0:size(templates, 3) - 1], size(templates, 1), 1); % we include all channels so this is trivial - disp(['Number of clusters: ' num2str(length(C))]) - disp(['Number of spikes: ' num2str(length(I))]) - disp(['Saving custom-merged data for Phy to: ' params.kiloDir save_dir_for_merges]) - - % write all files to save_dir_for_merges - save([params.kiloDir save_dir_for_merges '/custom_merge.mat'], 'T', 'I', 'C', 'mdata', 'SNR', 'consistency'); - writeNPY(uint64(T), [params.kiloDir save_dir_for_merges '/spike_times.npy']); - writeNPY(uint32(I - 1), [params.kiloDir save_dir_for_merges '/spike_templates.npy']); % -1 for zero indexing - writeNPY(single(templates), [params.kiloDir save_dir_for_merges '/templates.npy']); - writeNPY(double(templatesInds), [params.kiloDir save_dir_for_merges '/templates_ind.npy']); - copyfile([params.kiloDir '/../whitening_mat.npy'], [params.kiloDir save_dir_for_merges '/whitening_mat.npy']) - copyfile([params.kiloDir '/../whitening_mat_inv.npy'], [params.kiloDir save_dir_for_merges '/whitening_mat_inv.npy']) - copyfile([params.kiloDir '/../channel_map.npy'], [params.kiloDir save_dir_for_merges '/channel_map.npy']) - copyfile([params.kiloDir '/../channel_positions.npy'], [params.kiloDir save_dir_for_merges '/channel_positions.npy']) - copyfile([params.kiloDir '/../params.py'], [params.kiloDir save_dir_for_merges '/params.py']) - - % count the number of intermediate merges - loopCount = loopCount + 1; + if keepGoing + % count the number of intermediate merges + loopCount = loopCount + 1; + else + if isfolder([params.kiloDir save_dir_for_merges]) + rmdir([params.kiloDir save_dir_for_merges], 's') + mkdir([params.kiloDir save_dir_for_merges]) + else + mkdir([params.kiloDir save_dir_for_merges]) + end + + disp(['Number of clusters: ' num2str(length(C))]) + disp(['Number of spikes: ' num2str(length(I))]) + disp(['Saving custom-merged data for Phy to: ' params.kiloDir save_dir_for_merges]) + + % write all files to save_dir_for_merges + save([params.kiloDir save_dir_for_merges '/custom_merge.mat'], 'T', 'I', 'C', 'mdata', 'SNR', 'consistency'); + writeNPY(uint64(T), [params.kiloDir save_dir_for_merges '/spike_times.npy']); + writeNPY(uint32(I - 1), [params.kiloDir save_dir_for_merges '/spike_templates.npy']); % -1 for zero indexing + writeNPY(single(templates), [params.kiloDir save_dir_for_merges '/templates.npy']); + writeNPY(double(templatesInds), [params.kiloDir save_dir_for_merges '/templates_ind.npy']); + copyfile([params.kiloDir '/../whitening_mat.npy'], [params.kiloDir save_dir_for_merges '/whitening_mat.npy']) + copyfile([params.kiloDir '/../whitening_mat_inv.npy'], [params.kiloDir save_dir_for_merges '/whitening_mat_inv.npy']) + copyfile([params.kiloDir '/../channel_map.npy'], [params.kiloDir save_dir_for_merges '/channel_map.npy']) + copyfile([params.kiloDir '/../channel_positions.npy'], [params.kiloDir save_dir_for_merges '/channel_positions.npy']) + copyfile([params.kiloDir '/../params.py'], [params.kiloDir save_dir_for_merges '/params.py']) + end end disp('Finished merging clusters') @@ -348,13 +354,22 @@ function resorter(params) set(gca, 'YLim', [min(ycoords) * yScale - inc max(ycoords) * yScale + inc]) if params.savePlots - if ~exist([params.kiloDir '/Plots'], 'dir') + if exist([params.kiloDir '/Plots'], 'dir') + rmdir([params.kiloDir '/Plots'], 's') + mkdir([params.kiloDir '/Plots']) + else mkdir([params.kiloDir '/Plots']) end - if ~exist([params.kiloDir '/Plots/svg/'], 'dir') + if exist([params.kiloDir '/Plots/svg/'], 'dir') + rmdir([params.kiloDir '/Plots/svg/']) + mkdir([params.kiloDir '/Plots/svg/']) + else mkdir([params.kiloDir '/Plots/svg/']) end - if ~exist([params.kiloDir '/Plots/png/'], 'dir') + if exist([params.kiloDir '/Plots/png/'], 'dir') + rmdir([params.kiloDir '/Plots/png/']) + mkdir([params.kiloDir '/Plots/png/']) + else mkdir([params.kiloDir '/Plots/png/']) end print([params.kiloDir '/Plots/png/' num2str(j) '.png'], '-dpng') @@ -503,7 +518,7 @@ function resorter(params) elseif nChan == 16 grabChannels = 8; else - grabChannels = 8; + grabChannels = nChan; end tempm = squeeze(nanmean(tempdata, 3)); diff --git a/sorting/spike_validation_plot.m b/sorting/spike_validation_plot.m index 9a7ce89f..4cba5755 100644 --- a/sorting/spike_validation_plot.m +++ b/sorting/spike_validation_plot.m @@ -5,7 +5,7 @@ function spike_validation_plot(chunk, clusters) disp(['Using this channel map: ' myo_chan_map_file]) % load channel map with broken channels removed if chosen by user if length(brokenChan) > 0 && remove_bad_myo_chans(1) ~= false - load(fullfile(myo_sorted_dir, 'chanMap_minus_brokenChans.mat')) + load(fullfile(myo_sorted_dir, 'chanMapAdjusted.mat')) else load(myo_chan_map_file) end @@ -21,29 +21,29 @@ function spike_validation_plot(chunk, clusters) channels = chanMap; % Input: provide the path to the custom_merge.mat file. - fid = fopen(processed_ephys_data_path,'r'); + fid = fopen(processed_ephys_data_path, 'r'); data_1D = fread(fid, 'int16'); fclose(fid); - load(final_sort_path,'C','I','T','mdata'); + load(final_sort_path, 'C', 'I', 'T', 'mdata'); spike_times = T; cluster_ID = I; mdata_full = mdata; if isa(clusters, 'logical') && clusters == true disp("Showing all clusters.") else - C = intersect(C, clusters+1); - disp("Showing clusters: " + num2str(clusters+1)) + C = intersect(C, clusters + 1); + disp("Showing clusters: " + num2str(clusters + 1)) end - mdata = mdata_full(:,:,C); + mdata = mdata_full(:, :, C); mdata_size = size(mdata_full); template_width = mdata_size(1); num_chans = length(channels); dbstop if errors - data = reshape(data_1D,num_chans,length(data_1D)/num_chans)'; - data = data(:,channels); + data = reshape(data_1D, num_chans, length(data_1D) / num_chans)'; + data = data(:, channels); chan_cmap = gray(32); - chan_cmap = repmat(chan_cmap(16:32)',3); % get lighter half + chan_cmap = repmat(chan_cmap(16:32)', 3); % get lighter half clust_cmap = prism(length(C)); % get full rainbow % figure(1) @@ -51,54 +51,54 @@ function spike_validation_plot(chunk, clusters) % title('Spike Counts for Each Template ID') % xlabel('Template ID') % ylabel('Spike Counts') - - figure('CloseRequestFcn',@my_closereq); hold on - data_mins = min(data(chunk_index_range,:)); - data_maxs = max(data(chunk_index_range,:)); + + figure('CloseRequestFcn', @my_closereq); hold on + data_mins = min(data(chunk_index_range, :)); + data_maxs = max(data(chunk_index_range, :)); data_ranges = data_maxs - data_mins; - norm_data = (data(chunk_index_range,:)).*(1./data_ranges); + norm_data = (data(chunk_index_range, :)) .* (1 ./ data_ranges); - temp_mins = min(mdata(:,:,:)); - temp_maxs = max(mdata(:,:,:)); + temp_mins = min(mdata(:, :, :)); + temp_maxs = max(mdata(:, :, :)); temp_ranges = temp_maxs - temp_mins; - norm_temp_rngs = repmat(max(temp_ranges),1,length(channels)); - norm_temp = (mdata(:,:,:)).*(1./norm_temp_rngs); + norm_temp_rngs = repmat(max(temp_ranges), 1, length(channels)); + norm_temp = (mdata(:, :, :)) .* (1 ./ norm_temp_rngs); data_amount_size = length(chunk_index_range); - % disp_ch = + % disp_ch = for jj = 1:length(channels) - % ch = channels(jj); - plot(chunk_index_range, norm_data(:,jj)+2*jj*ones(data_amount_size,1), ... - 'color',chan_cmap(jj,:), ... - 'LineWidth',1.2) + % ch = channels(jj); + plot(chunk_index_range, norm_data(:, jj) + 2 * jj * ones(data_amount_size, 1), ... + 'color', chan_cmap(jj, :), ... + 'LineWidth', 1.2) end for ii = 1:length(C) cc = C(ii); - bitmask = ismember(cluster_ID,cc); + bitmask = ismember(cluster_ID, cc); spikes_for_cluster = spike_times(bitmask); - trunc_idxs = spikes_for_clustermin(chunk_index_range); + trunc_idxs = spikes_for_cluster < max(chunk_index_range) & spikes_for_cluster > min(chunk_index_range); trunc_spike_times = spikes_for_cluster(trunc_idxs); - % trunc_cluster_ID = cc.*uint32(ones(sum(trunc_idxs),1)); - % s = scatter(trunc_spike_times,trunc_cluster_ID,'|'); - clust_template = norm_temp(:,:,ii); + % trunc_cluster_ID = cc.*uint32(ones(sum(trunc_idxs),1)); + % s = scatter(trunc_spike_times,trunc_cluster_ID,'|'); + clust_template = norm_temp(:, :, ii); for kk = 1:length(channels) - % ch = channels(kk); - for iT=1:length(trunc_spike_times)-1 + % ch = channels(kk); + for iT = 1:length(trunc_spike_times) - 1 plot( ... - (trunc_spike_times(iT)-floor(template_width/2)+1):(trunc_spike_times(iT)+floor(template_width/2)), ... - clust_template(:,kk)+1+kk*2, ... - 'color',[clust_cmap(ii,:) 0.5],... - 'LineWidth',1.2); - % t = plot(trunc_spike_times,trunc_cluster_ID,'|'); + (trunc_spike_times(iT) - floor(template_width / 2) + 1):(trunc_spike_times(iT) + floor(template_width / 2)), ... + clust_template(:, kk) + 1 + kk * 2, ... + 'color', [clust_cmap(ii, :) 0.5], ... + 'LineWidth', 1.2); + % t = plot(trunc_spike_times,trunc_cluster_ID,'|'); end end alpha(0.2) - % set(s(1), ... - % 'SizeData', 500, ... - % 'LineWidth',1.5, ... - % 'MarkerEdgeColor', clust_cmap(ii,:), ... - % 'MarkerEdgeAlpha', 0.5) + % set(s(1), ... + % 'SizeData', 500, ... + % 'LineWidth',1.5, ... + % 'MarkerEdgeColor', clust_cmap(ii,:), ... + % 'MarkerEdgeAlpha', 0.5) end title('Template Matches for Each Channel') xlabel('Time (s)') @@ -107,27 +107,27 @@ function spike_validation_plot(chunk, clusters) % ax.XAxis.Exponent = 0; set(gcf, 'WindowState', 'fullscreen'); % set fullscreen outerpos = ax.OuterPosition; - ti = ax.TightInset; + ti = ax.TightInset; left = outerpos(1) + ti(1); bottom = outerpos(2) + ti(2); ax_width = outerpos(3) - ti(1) - ti(3); ax_height = outerpos(4) - ti(2) - ti(4); ax.Position = [left bottom ax_width ax_height]; - set(ax,'color',[0 0 0]) + set(ax, 'color', [0 0 0]) set(ax, 'YTick', []); set(ax, 'XTick', chunk_index_range(1):30000:chunk_index_range(end)); - set(ax, 'XTickLabel', chunk_index_range(1)/30000:1:chunk_index_range(end)/30000); + set(ax, 'XTickLabel', chunk_index_range(1) / 30000:1:chunk_index_range(end) / 30000); end % get 10 second chunks of data function data_amount = get_data_amount(chunk) - data_amount = (chunk-1)*300000+1:chunk*300000; + data_amount = (chunk - 1) * 300000 + 1:chunk * 300000; end -function my_closereq(src,event) - % Close request function +function my_closereq(src, event) + % Close request function % to quit MATLAB when plot is closed - disp('Plot closed. Quitting MATLAB.') - delete(gcf) - quit -end \ No newline at end of file + disp('Closing figure and quitting MATLAB.') + delete(gcf); + quit; +end