diff --git a/brainwidemap/encoding/Dockerfile b/brainwidemap/encoding/Dockerfile index 9ff29672..88305d6f 100644 --- a/brainwidemap/encoding/Dockerfile +++ b/brainwidemap/encoding/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:11.7.1-devel-ubuntu22.04 +FROM ubuntu:latest # This can optionally be built with just ubuntu, rather than the nvidia cuda container. # If saving space is a concern, this is the way to go. LABEL description="Core container which has the basic necessities to run analyses in the\ @@ -15,24 +15,17 @@ COPY ./environment.yaml /data/environment.yaml SHELL ["/bin/bash", "-c"] # For some reason ibllib.io.video needs opencv which requires libgl1-mesa-dev ¯\_(ツ)_/¯ RUN apt update && apt install -y wget git libgl1-mesa-dev -RUN wget -O Mambaforge.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" -RUN bash Mambaforge.sh -b -p /opt/conda && rm Mambaforge.sh +RUN wget -O Miniforge3.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" +RUN bash Miniforge3.sh -b -p /opt/conda && rm Miniforge3.sh +RUN wget -O iblreq.txt "https://raw.githubusercontent.com/int-brain-lab/ibllib/master/requirements.txt" +RUN head -n -1 iblreq.txt > requirements.txt +RUN rm iblreq.txt RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \ mamba install --yes conda-build &&\ mamba env create -n iblenv --file=environment.yaml" -RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh &&\ - conda activate iblenv &&\ - mamba install --yes pytorch pytorch-cuda=11.7 -c pytorch -c nvidia &&\ - conda clean --all -f -y" -RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh &&\ - conda activate iblenv &&\ - pip install globus-sdk iblutil ibl-neuropixel ONE-api phylib pynrrd slidingRP &&\ - git clone https://github.com/int-brain-lab/iblapps.git &&\ - conda develop ./iblapps &&\ - git clone https://github.com/int-brain-lab/ibllib &&\ - conda develop ./ibllib &&\ - git clone https://github.com/berkgercek/neurencoding &&\ - conda develop ./neurencoding" +RUN /bin/bash -c "source /opt/conda/etc/profile.d/conda.sh && \ + conda activate iblenv && pip install -r requirements.txt && pip install ibllib --no-deps" +RUN rm requirements.txt # The below allows interactively running the container with the correct environment, but be warned # that this will not work with commands passed to the container in a non-interactive shell. # In the case of e.g. `docker run thiscontainer python myscript.py`, the environment will not diff --git a/brainwidemap/encoding/README.md b/brainwidemap/encoding/README.md index af0c1ea8..78b60f2f 100644 --- a/brainwidemap/encoding/README.md +++ b/brainwidemap/encoding/README.md @@ -38,7 +38,7 @@ The `scripts/` folder contains small scripts that either run plotting or simple ### Cluster worker -`cluster_worker.py` implements a mother script for cluster workers to process individual probe insertions. This relies on a cached dataset, produced using the `pipelines/01_cache_regressors.py` script, as well as several files specifying the identity and parameters of a cached dataset and the parameters of the current run of the model. +`cluster_worker.py` implements a mother script for cluster workers to process individual probe insertions. This relies on a cached dataset, produced using the `pipelines/01_cache_regressors.py` script, as well as several files specifying the identity and parameters of a cached dataset and the parameters of the current run of the model. 
Note that the params.py file will need to point to the appropriate cache locations as well for the worker to function.

### Design matrix

diff --git a/brainwidemap/encoding/cluster_worker.py b/brainwidemap/encoding/cluster_worker.py
index a2d4b455..5a01b4ba 100644
--- a/brainwidemap/encoding/cluster_worker.py
+++ b/brainwidemap/encoding/cluster_worker.py
@@ -14,6 +14,7 @@
 
 # Third party libraries
 import numpy as np
+from pandas import read_pickle
 
 # Brainwidemap repo imports
 from brainwidemap.encoding.design import generate_design
@@ -23,7 +24,7 @@
 def get_cached_regressors(fpath):
     with open(fpath, "rb") as fo:
-        d = pickle.load(fo)
+        d = read_pickle(fo)
     return d["trialsdf"], d["spk_times"], d["spk_clu"], d["clu_regions"], d["clu_df"]
 
 
@@ -37,9 +38,9 @@
     return sesspath
 
 
-def save_stepwise(subject, session_id, fitout, params, probes, input_fn, clu_reg, clu_df, fitdate):
+def save_stepwise(subject, session_id, fitout, params, probes, input_fn, clu_reg, clu_df, fitdate, splitstr=""):
     sesspath = _create_sub_sess_path(GLM_FIT_PATH, subject, session_id)
-    fn = sesspath.joinpath(f"{fitdate}_{probes}_stepwise_regression.pkl")
+    fn = sesspath.joinpath(f"{fitdate}_{probes}{splitstr}_stepwise_regression.pkl")
     outdict = {
         "params": params,
         "probes": probes,
@@ -81,14 +82,41 @@
     t_before,
     fitdate,
     null=None,
+    earlyrts=False,
+    laterts=False,
 ):
     stdf, sspkt, sspkclu, sclureg, scluqc = get_cached_regressors(eidfn)
     sessprior = stdf["probabilityLeft"]
-    sessdesign = generate_design(stdf, sessprior, t_before, **params)
+    match (earlyrts, laterts):
+        case (False, False):
+            splitstr = ""
+        case (True, False):
+            splitstr = "_earlyrt"
+        case (False, True):
+            splitstr = "_latert"
+    if not earlyrts and not laterts:
+        sessdesign = generate_design(stdf, sessprior, t_before, **params)
+    else:
+        # Handle early and late RT flags, compute median for session if necessary
+        if "rt_thresh" not in params:
+            raise ValueError("Must specify rt_thresh if fitting early or late RTs")
+        if laterts and earlyrts:
+            raise ValueError(
+                "Cannot fit both early and late RTs. Disable both flags to fit all trials."
+            )
+        if params["rt_thresh"] == "session_median":
+            params["rt_thresh"] = np.median(stdf["firstMovement_times"] - stdf["trial_start"])
+
+        if earlyrts:
+            mask = (stdf["firstMovement_times"] - stdf["trial_start"]) < params["rt_thresh"]
+        elif laterts:
+            mask = (stdf["firstMovement_times"] - stdf["trial_start"]) >= params["rt_thresh"]
+        stdf = stdf[mask]
+        sessdesign = generate_design(stdf, sessprior, t_before, **params)
     if null is None:
         sessfit = fit_stepwise(sessdesign, sspkt, sspkclu, **params)
         outputfn = save_stepwise(
-            subject, eid, sessfit, params, probes, eidfn, sclureg, scluqc, fitdate
+            subject, eid, sessfit, params, probes, eidfn, sclureg, scluqc, fitdate, splitstr
         )
     elif null == "pseudosession_pleft_iti":
         sessfit, nullfits = fit_stepwise_with_pseudoblocks(
@@ -114,11 +142,13 @@
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Cluster GLM fitter. This script is called by"
-                                     "the batch script generated in "
-                                     "pipelines/02_fit_sessions.py and should in most "
-                                     "cases beyond debugging not be used in a "
-                                     "standalone fashion.")
+    parser = argparse.ArgumentParser(
+        description="Cluster GLM fitter. This script is called by "
+        "the batch script generated in "
+        "pipelines/02_fit_sessions.py and should in most "
+        "cases beyond debugging not be used in a "
+        "standalone fashion."
+ ) parser.add_argument( "datafile", type=Path, @@ -131,6 +161,16 @@ def fit_save_inputs( ) parser.add_argument("fitdate", help="Date of fit for output file") parser.add_argument("--impostor_path", type=Path, help="Path to main impostor df file") + parser.add_argument( + "--earlyrt", + action="store_true", + help="Whether to fit separate movement kernels to early trials", + ) + parser.add_argument( + "--latert", + action="store_true", + help="Whether to fit separate movement kernels to late trials", + ) args = parser.parse_args() with open(args.datafile, "rb") as fo: @@ -154,6 +194,8 @@ def fit_save_inputs( t_before, args.fitdate, null=params["null"], + earlyrts=args.earlyrt, + laterts=args.latert, ) print("Fitting completed successfully!") print(outputfn) diff --git a/brainwidemap/encoding/design.py b/brainwidemap/encoding/design.py index 37ecd0b5..989517eb 100644 --- a/brainwidemap/encoding/design.py +++ b/brainwidemap/encoding/design.py @@ -6,13 +6,12 @@ # Standard library import logging +# IBL libraries +import neurencoding.design_matrix as dm + # Third party libraries import numpy as np import pandas as pd -from scipy.stats import norm - -# IBL libraries -import neurencoding.design_matrix as dm _logger = logging.getLogger("brainwide") diff --git a/brainwidemap/encoding/environment.yaml b/brainwidemap/encoding/environment.yaml index 22c49eef..490877f2 100644 --- a/brainwidemap/encoding/environment.yaml +++ b/brainwidemap/encoding/environment.yaml @@ -1,31 +1,14 @@ name: iblenv dependencies: - - python=3.9 - - apptools >= 4.5.0 - - boto3 - - click - - colorcet - - colorlog - - cython - - dataclasses - - flake8 - - graphviz - - h5py + - python=3.10 - ipython - matplotlib - numba - - numpy - pandas - - plotly - - pyarrow - - pyflakes >= 2.4.0 - - pytest - - requests - scikit-learn - scipy >=1.4.1 - seaborn - statsmodels - tqdm - pip - - pip: - - opencv-python + - pyqt<6 diff --git a/brainwidemap/encoding/fit.py b/brainwidemap/encoding/fit.py index 1d68a0ce..9eaad747 100644 --- a/brainwidemap/encoding/fit.py +++ b/brainwidemap/encoding/fit.py @@ -13,12 +13,25 @@ from brainwidemap.encoding.design import generate_design -def fit(design, spk_t, spk_clu, binwidth, model, estimator, n_folds=5, contiguous=False, **kwargs): +def fit( + design, + spk_t, + spk_clu, + binwidth, + model, + estimator, + n_folds=5, + contiguous=False, + mintrials=100, + **kwargs +): """ Function to fit a model using a cross-validated design matrix. """ trials_idx = design.trialsdf.index - nglm = model(design, spk_t, spk_clu, binwidth=binwidth, estimator=estimator, mintrials=0) + nglm = model( + design, spk_t, spk_clu, binwidth=binwidth, estimator=estimator, mintrials=mintrials + ) splitter = KFold(n_folds, shuffle=not contiguous) scores, weights, intercepts, alphas, splits = [], [], [], [], [] for test, train in splitter.split(trials_idx): @@ -52,6 +65,7 @@ def fit_stepwise( estimator, n_folds=5, contiguous=False, + mintrials=100, seqsel_kwargs={}, seqselfit_kwargs={}, **kwargs @@ -107,7 +121,9 @@ def fit_stepwise( splits: list of dicts containing the test and train indices for each fold. 
""" trials_idx = design.trialsdf.index - nglm = model(design, spk_t, spk_clu, binwidth=binwidth, estimator=estimator, mintrials=0) + nglm = model( + design, spk_t, spk_clu, binwidth=binwidth, estimator=estimator, mintrials=mintrials + ) splitter = KFold(n_folds, shuffle=not contiguous) sequences, scores, deltas, splits = [], [], [], [] for test, train in tqdm(splitter.split(trials_idx), desc="Fold", leave=False): diff --git a/brainwidemap/encoding/glm_predict.py b/brainwidemap/encoding/glm_predict.py index 72e72ff1..9d94839e 100644 --- a/brainwidemap/encoding/glm_predict.py +++ b/brainwidemap/encoding/glm_predict.py @@ -204,23 +204,7 @@ def psth_summary(self, align_time, unit, t_before=0.1, t_after=0.6, trials=None, ax=ax[0], smoothing=0.01, ) - keytuple = (align_time, t_before, t_after, tuple(trials)) - if keytuple not in self.full_psths: - self.full_psths[keytuple] = pred_psth( - self.nglm, align_time, t_before, t_after, trials=trials - ) - self.cov_psths[keytuple] = {} - tmp = self.cov_psths[keytuple] - for cov in self.covar: - tmp[cov] = pred_psth( - self.nglm, - align_time, - t_before, - t_after, - targ_regressors=[cov], - trials=trials, - incl_bias=False, - ) + keytuple = self.compute_model_psth(align_time, t_before, t_after, trials) for cov in self.covar: ax[2].plot(self.combweights[cov].loc[unit]) ax[2].set_title("Individual kernels (not PSTH contrib)") @@ -244,3 +228,45 @@ def psth_summary(self, align_time, unit, t_before=0.1, t_after=0.6, trials=None, plt.suptitle(f"Unit {unit}") plt.tight_layout() return ax + + def compute_model_psth(self, align_time, t_before, t_after, trials): + """ + Generate and store internally model PSTHs for a given alignment time and trials. + + Parameters + ---------- + align_time : str + Column in the design matrix to align PSTH to + t_before : float + Time before the align time to compute PSTH for + t_after : _type_ + Time after the align time to compute PSTH for + trials : array-like of int + List of trials on which to compute the PSTH + + Returns + ------- + tuple + 4-tuple with the alignment time, time before, time after, and trials used to compute, + can be used as a key in the internal self.full_psths and self.cov_psths dictionaries, + which contain the full PSTH and the PSTH for each regressor, respectively. + """ + keytuple = (align_time, t_before, t_after, tuple(trials)) + if keytuple not in self.full_psths: + self.full_psths[keytuple] = pred_psth( + self.nglm, align_time, t_before, t_after, trials=trials + ) + self.cov_psths[keytuple] = {} + tmp = self.cov_psths[keytuple] + for cov in self.covar: + tmp[cov] = pred_psth( + self.nglm, + align_time, + t_before, + t_after, + targ_regressors=[cov], + trials=trials, + incl_bias=False, + ) + + return keytuple diff --git a/brainwidemap/encoding/params.py b/brainwidemap/encoding/params.py index 90cb28fc..1d8705ee 100644 --- a/brainwidemap/encoding/params.py +++ b/brainwidemap/encoding/params.py @@ -4,5 +4,5 @@ work. 
""" -GLM_CACHE = "/mnt/Storage/glm_cache/" -GLM_FIT_PATH = "/mnt/Storage/results/glms/" +GLM_CACHE = "/home/gercek/Projects/glm_cache/" +GLM_FIT_PATH = "/home/gercek/Projects/results/glms/" diff --git a/brainwidemap/encoding/pipelines/01_cache_regressors.py b/brainwidemap/encoding/pipelines/01_cache_regressors.py index 5fd66225..96f45522 100644 --- a/brainwidemap/encoding/pipelines/01_cache_regressors.py +++ b/brainwidemap/encoding/pipelines/01_cache_regressors.py @@ -7,14 +7,11 @@ # Third party libraries import dask -import numpy as np import pandas as pd from dask.distributed import Client from dask_jobqueue import SLURMCluster # IBL libraries -import brainbox.io.one as bbone -from iblutil.numerical import ismember from one.api import ONE from brainwidemap.encoding.params import GLM_CACHE from brainwidemap.encoding.utils import load_regressors @@ -68,7 +65,7 @@ def delayed_loadsave(subject, session_id, pid, params): T_BEF = 0.6 # Time before stimulus onset to include in the definition of the trial T_AFT = 0.6 # Time after feedback to include in the definition of a trial BINWIDTH = 0.02 # Size of binwidth for wheel velocity traces, in seconds -ABSWHEEL = False # Whether to return wheel velocity (False) or speed (True) +ABSWHEEL = True # Whether to return wheel velocity (False) or speed (True) CLU_CRITERIA = "bwm" # Criteria on cluster inclusion in cache # End parameters @@ -79,13 +76,15 @@ def delayed_loadsave(subject, session_id, pid, params): "binwidth": BINWIDTH, "abswheel": ABSWHEEL, "clu_criteria": CLU_CRITERIA, + "one_url": "https://alyx.internationalbrainlab.org", + "one_pw": "international", } -pw = 'international' -one = ONE(base_url='https://openalyx.internationalbrainlab.org', password=pw, silent=True) +one = ONE(base_url=params["one_url"], silent=True) dataset_futures = [] -sessdf = bwm_query().set_index("pid") +freeze = "2023_12_bwm_release" if CLU_CRITERIA == "bwm" else None +sessdf = bwm_query(freeze=freeze).set_index("pid") for pid, rec in sessdf.iterrows(): subject = rec.subject @@ -110,7 +109,7 @@ def delayed_loadsave(subject, session_id, pid, params): f"export OPENBLAS_NUM_THREADS={N_CORES}", ], ) -cluster.scale(40) +cluster.scale(20) client = Client(cluster) tmp_futures = [client.compute(future[3]) for future in dataset_futures] diff --git a/brainwidemap/encoding/pipelines/02_fit_sessions.py b/brainwidemap/encoding/pipelines/02_fit_sessions.py index d17820c7..826ce4bf 100644 --- a/brainwidemap/encoding/pipelines/02_fit_sessions.py +++ b/brainwidemap/encoding/pipelines/02_fit_sessions.py @@ -1,19 +1,18 @@ # Standard library -import os -import pickle import argparse +import pickle from datetime import date from pathlib import Path +# IBL libraries +import neurencoding.linear as lm +import neurencoding.utils as mut + # Third party libraries import numpy as np import sklearn.linear_model as skl from sklearn.model_selection import GridSearchCV -# IBL libraries -import neurencoding.linear as lm -import neurencoding.utils as mut - # Brainwide repo imports from brainwidemap.encoding.params import GLM_CACHE, GLM_FIT_PATH from brainwidemap.encoding.utils import make_batch_slurm_singularity @@ -27,12 +26,12 @@ " parameters for the actual GLM fitting are defined within the script itself." " The arguments passed to the script via this parser are only for cluster control." " If you would like to change parameters of the actual fit please adjust the contents" - " of the \"parameters\" section in the file." + ' of the "parameters" section in the file.' 
) parser.add_argument( "--basefilepath", type=Path, - default=Path("~/").expanduser().joinpath("bwm_stepwise_glm_leaveoneout"), + default=Path("~/").expanduser().joinpath("jobscripts/bwm_stepwise_glm_leaveoneout"), help="Base filename for batch scripts", ) parser.add_argument( @@ -54,7 +53,7 @@ "--singularity_modules", type=str, nargs="+", - default=["GCC/9.3.0", "Singularity/3.7.3-Go-1.14"], + default=[], help="Modules to load when using singularity containers.", ) parser.add_argument( @@ -85,21 +84,17 @@ "--job_cores", type=int, default=32, help="Number of cores to request per job." ) parser.add_argument("--mem", type=str, default="12GB", help="Memory to request per job.") -parser.add_argument( - "--submit_batch", - action="store_true", - default=False, - help="Submit batch jobs to SLURM cluster using the script.", -) args = parser.parse_args() + # Model parameters # The GLM constructor class requires a function that converts time to bin index, here we define it -# using the binwidth parameter created shortly. -def tmp_binf(t): +# using the binwidth parameter created shortly. +def tmp_binf(t): return np.ceil(t / params["binwidth"]).astype(int) + ######### PARAMETERS ######### params = { "binwidth": 0.02, @@ -108,7 +103,7 @@ def tmp_binf(t): "wheel_offset": -0.3, "contnorm": 5.0, "reduce_wheel_dim": False, - "dataset_fn": "2022-12-22_dataset_metadata.pkl", + "dataset_fn": "2024-08-12_dataset_metadata.pkl", "model": lm.LinearGLM, "alpha_grid": {"alpha": np.logspace(-3, 2, 50)}, "contiguous": False, @@ -118,6 +113,8 @@ def tmp_binf(t): "seqsel_kwargs": {"direction": "backward", "n_features_to_select": 8}, "seqselfit_kwargs": {"full_scores": True}, "seed": 0, + "rt_thresh": "session_median", + "mintrials": 50, } params["bases"] = { @@ -128,6 +125,14 @@ def tmp_binf(t): } # Estimator relies on alpha grid in case of GridSearchCV, needs to be defined after main params params["estimator"] = GridSearchCV(skl.Ridge(), params["alpha_grid"]) +if "rt_thresh" in params: + earlyrt_flag = "--earlyrt" + latert_flag = "--latert" + earlyrt_fn = "_early_rt" +else: + earlyrt_flag = "" + latert_flag = "" + earlyrt_fn = "" # Output parameters file for workers currdate = str(date.today()) @@ -142,7 +147,7 @@ def tmp_binf(t): # Generate batch script make_batch_slurm_singularity( - str(args.basefilepath), + str(args.basefilepath) + earlyrt_fn, str(Path(__file__).parents[1].joinpath("cluster_worker.py")), job_name=args.jobname, partition=args.partition, @@ -156,14 +161,29 @@ def tmp_binf(t): cores_per_job=args.job_cores, memory=args.mem, array_size=f"1-{njobs}", - f_args=[str(datapath), str(parpath), r"${SLURM_ARRAY_TASK_ID}", currdate], + f_args=[earlyrt_flag, str(datapath), str(parpath), r"${SLURM_ARRAY_TASK_ID}", currdate], ) - -# If SUBMIT_BATCH, then actually execute the batch job -if args.submit_batch: - os.system(f"sbatch {str(args.basefilepath) + '_batch.sh'}") -else: - print( - f"Batch file generated at {str(args.basefilepath) + '_batch.sh'};" - " user must submit it themselves. Good luck!" 
+if len(earlyrt_fn) > 0:
+    make_batch_slurm_singularity(
+        str(args.basefilepath) + "_late_rt",
+        str(Path(__file__).parents[1].joinpath("cluster_worker.py")),
+        job_name=args.jobname,
+        partition=args.partition,
+        time=args.timelimit,
+        singularity_modules=args.singularity_modules,
+        container_image=args.singularity_image,
+        img_condapath=args.singularity_conda,
+        img_envname=args.singularity_env,
+        local_pathadd=Path(__file__).parents[3],
+        logpath=args.logpath,
+        cores_per_job=args.job_cores,
+        memory=args.mem,
+        array_size=f"1-{njobs}",
+        f_args=[latert_flag, str(datapath), str(parpath), r"${SLURM_ARRAY_TASK_ID}", currdate],
     )
+
+# Batch scripts are no longer submitted automatically; print the path for manual submission.
+print(
+    f"Batch file generated at {str(args.basefilepath) + '_batch.sh'};"
+    " user must submit it themselves. Good luck!"
+)
diff --git a/brainwidemap/encoding/pipelines/03_gather_results.py b/brainwidemap/encoding/pipelines/03_gather_results.py
index c505081a..1dd32c77 100644
--- a/brainwidemap/encoding/pipelines/03_gather_results.py
+++ b/brainwidemap/encoding/pipelines/03_gather_results.py
@@ -1,6 +1,5 @@
 # Standard library
 from functools import cache
-from argparse import ArgumentParser
 
 # Third party libraries
 import numpy as np
@@ -8,6 +7,40 @@
 
 # IBL libraries
 from iblatlas.atlas import BrainRegions
+from iblutil.numerical import ismember
+
+
+def colrename(cname, suffix):
+    return str(cname + 1) + "cov" + suffix
+
+
+def remap(ids, source="Allen", dest="Beryl", output="acronym", br=BrainRegions()):
+    _, inds = ismember(ids, br.id[br.mappings[source]])
+    ids = br.id[br.mappings[dest][inds]]
+    if output == "id":
+        return ids
+    elif output == "acronym":
+        return br.get(ids)["acronym"]
+
+
+def get_id(acronym, brainregions=BrainRegions()):
+    return brainregions.id[np.argwhere(brainregions.acronym == acronym)[0, 0]]
+
+
+def get_name(acronym, brainregions=BrainRegions()):
+    if acronym == "void":
+        return acronym
+    reg_idxs = np.argwhere(brainregions.acronym == acronym).flat
+    return brainregions.name[reg_idxs[0]]
+
+
+def label_cerebellum(acronym, brainregions=BrainRegions()):
+    regid = brainregions.id[np.argwhere(brainregions.acronym == acronym).flat][0]
+    ancestors = brainregions.ancestors(regid)
+    if "Cerebellum" in ancestors.name or "Medulla" in ancestors.name:
+        return True
+    else:
+        return False
 
 
 if __name__ == "__main__":
@@ -20,33 +53,21 @@
 
     # Brainwide repo imports
     from brainwidemap.encoding.params import GLM_CACHE, GLM_FIT_PATH
 
-    parser = ArgumentParser(
-        description="Gather results from GLM fitting on a given date with given N covariates."
-    )
-    parser.add_argument(
-        "--fitdate",
-        type=str,
-        default="2023-04-09",
-        help="Date on which fit was run",
-    )
-    parser.add_argument(
-        "--n_cov",
-        type=int,
-        default=9,
-        help="Number of covariates in model",
-    )
-    args = parser.parse_args()
-    fitdate = args.fitdate
-    n_cov = args.n_cov
-    parpath = Path(GLM_FIT_PATH).joinpath(f"{fitdate}_glm_fit_pars.pkl")
+    currdate = "2024-09-15"  # Date on which fit was run
+    n_cov = 9  # Modify if you change the model!
+ parpath = Path(GLM_FIT_PATH).joinpath(f"{currdate}_glm_fit_pars.pkl") + early_split = False with open(parpath, "rb") as fo: params = pickle.load(fo) datapath = Path(GLM_CACHE).joinpath(params["dataset_fn"]) with open(datapath, "rb") as fo: dataset = pickle.load(fo) + subject_names = dataset["dataset_filenames"]["subject"].unique() filenames = [] for subj in os.listdir(Path(GLM_FIT_PATH)): + if subj not in subject_names: + continue subjdir = Path(GLM_FIT_PATH).joinpath(subj) if not os.path.isdir(subjdir): continue @@ -54,7 +75,7 @@ sessdir = subjdir.joinpath(sess) for file in os.listdir(sessdir): filepath = sessdir.joinpath(file) - if os.path.isfile(filepath) and filepath.match(f"*{fitdate}*"): + if os.path.isfile(filepath) and filepath.match(f"*{currdate}*"): filenames.append(filepath) # Process files after fitting @@ -65,51 +86,44 @@ folds = [] for i in range(len(tmpfile["scores"])): tmpdf = tmpfile["deltas"][i]["test"] + tmpdf.index.name = "clu_id" tmpdf["full_model"] = tmpfile["scores"][i]["basescores"]["test"] tmpdf["eid"] = fitname.parts[-2] tmpdf["pid"] = fitname.parts[-1].split("_")[1] - tmpdf["acronym"] = tmpfile["clu_regions"][tmpdf.index] - tmpdf["qc_label"] = tmpfile["clu_df"]["label"][tmpdf.index] + tmpdf["acronym"] = tmpfile["clu_regions"] + tmpdf["qc_label"] = tmpfile["clu_df"]["label"] tmpdf["fold"] = i + tmpdf.index = tmpfile["clu_df"].iloc[tmpdf.index].cluster_id tmpdf.index.set_names(["clu_id"], inplace=True) folds.append(tmpdf.reset_index()) sess_master = pd.concat(folds) sessdfs.append(sess_master) masterscores = pd.concat(sessdfs) - # Take the average score across the different folds of cross-validation for each unit for - # each of the model regressors + kernels = [ + "stimonR", + "stimonL", + "correct", + "incorrect", + "fmoveR", + "fmoveL", + "pLeft", + "pLeft_tr", + "wheel", + "full_model", + ] + meanmaster = ( masterscores.set_index(["eid", "pid", "clu_id", "acronym", "qc_label", "fold"]) .groupby(["eid", "pid", "clu_id", "acronym", "qc_label"]) - .agg( - { - k: "mean" - for k in [ - "stimonR", - "stimonL", - "correct", - "incorrect", - "fmoveR", - "fmoveL", - "pLeft", - "pLeft_tr", - "wheel", - "full_model", - ] - } - ) + .agg({k: "mean" for k in kernels}) ) - br = BrainRegions() - @cache def regmap(acr): - return br.acronym2acronym(acr, mapping="Beryl") + ids = get_id(acr) + return remap(ids, br=br) br = BrainRegions() - # Remap the existing acronyms, which use the Allen ontology, into the Beryl ontology - # Note that the groupby operation is to save time on computation so we don't need to - # recompute the region mapping for each unit, but rather each Allen acronym. 
grpby = masterscores.groupby("acronym") meanmaster.reset_index(["acronym", "qc_label"], inplace=True) masterscores["region"] = [regmap(ac)[0] for ac in masterscores["acronym"]] @@ -122,5 +136,5 @@ def regmap(acr): "mean_fit_results": meanmaster, "fit_files": filenames, } - with open(Path(GLM_FIT_PATH).joinpath(f"{fitdate}_glm_fit.pkl"), "wb") as fw: + with open(Path(GLM_FIT_PATH).joinpath(f"{currdate}_glm_fit.pkl"), "wb") as fw: pickle.dump(outdict, fw) diff --git a/brainwidemap/encoding/pipelines/04_plot_figures.py b/brainwidemap/encoding/pipelines/04_plot_figures.py index adec0555..c7f71cc2 100644 --- a/brainwidemap/encoding/pipelines/04_plot_figures.py +++ b/brainwidemap/encoding/pipelines/04_plot_figures.py @@ -3,7 +3,6 @@ # Third party libraries import matplotlib.pyplot as plt -from matplotlib import colors import numpy as np import pandas as pd import seaborn as sns @@ -11,11 +10,12 @@ # IBL libraries from iblatlas.atlas import BrainRegions from iblatlas.plots import plot_swanson +from matplotlib import colors # Brainwidemap repo imports from brainwidemap.encoding.params import GLM_FIT_PATH -FITDATE = "2023-03-02" +FITDATE = "2024-07-16" VARIABLES = [ "stimonR", "stimonL", @@ -23,6 +23,8 @@ "incorrect", "fmoveR", "fmoveL", + # "fmoveR_early", # Comment/uncomment if early RT split is used. + # "fmoveL_early", "pLeft", "pLeft_tr", "wheel", @@ -40,7 +42,7 @@ ABSDIFF = True # Whether to plot absolute value of difference or signed difference ANNOTATE = False # Whether to annotate brain regions IMGFMT = "png" # Format of output image -SAVEPATH = Path("/home/berk/Documents/Projects/results/plots/swanson_maps/") # Path to save plots +SAVEPATH = Path("/home/gercek/Projects/results/plots/swanson_maps/") # Path to save plots if not SAVEPATH.exists(): SAVEPATH.mkdir() @@ -83,31 +85,27 @@ br = BrainRegions() -def flatmap_variable(df, - cmap, - cmin=COLOR_RANGE[0], - cmax=COLOR_RANGE[1], - norm=None, - plswan_kwargs={}): +def flatmap_variable( + df, cmap, cmin=COLOR_RANGE[0], cmax=COLOR_RANGE[1], norm=None, plswan_kwargs={} +): fig = plt.figure(figsize=(8, 4) if not ANNOTATE else (16, 8)) ax = fig.add_subplot(111) if norm is not None: cmap_kwargs = {"norm": norm, **plswan_kwargs} elif GLOBAL_CMAP: - cmap_kwargs = { - "norm": colors.LogNorm(vmin=cmin, vmax=cmax, clip=True), - **plswan_kwargs - } + cmap_kwargs = {"norm": colors.LogNorm(vmin=cmin, vmax=cmax, clip=True), **plswan_kwargs} else: cmap_kwargs = {"vmin": cmin, "vmax": cmax, **plswan_kwargs} - ax = plot_swanson(df.index, - df.values, - hemisphere="left", - cmap=cmap, - br=br, - ax=ax, - annotate=ANNOTATE, - **cmap_kwargs) + ax = plot_swanson( + df.index, + df.values, + hemisphere="left", + cmap=cmap, + br=br, + ax=ax, + annotate=ANNOTATE, + **cmap_kwargs, + ) plt.colorbar(mappable=ax.images[0]) ax.set_xticks([]) ax.set_yticks([]) @@ -116,9 +114,9 @@ def flatmap_variable(df, def get_cmap(split): - ''' + """ for each split, get a colormap defined by Yanliang - ''' + """ varmaps = { "stimonR": "stim", "stimonL": "stim", @@ -128,14 +126,14 @@ def get_cmap(split): "incorrect": "fback", "pLeft": "block", "pLeft_tr": "block", - "wheel": "wheel" + "wheel": "wheel", } dc = { - 'stim': ["#ffffff", "#D5E1A0", "#A3C968", "#86AF40", "#517146"], - 'choice': ["#ffffff", "#F8E4AA", "#F9D766", "#E8AC22", "#DA4727"], - 'fback': ["#ffffff", "#F1D3D0", "#F5968A", "#E34335", "#A23535"], - 'block': ["#ffffff", "#D0CDE4", "#998DC3", "#6159A6", "#42328E"], - 'wheel': ["#ffffff", "#C2E1EA", "#95CBEE", "#5373B8", "#324BA0"] + "stim": ["#ffffff", "#D5E1A0", 
"#A3C968", "#86AF40", "#517146"], + "choice": ["#ffffff", "#F8E4AA", "#F9D766", "#E8AC22", "#DA4727"], + "fback": ["#ffffff", "#F1D3D0", "#F5968A", "#E34335", "#A23535"], + "block": ["#ffffff", "#D0CDE4", "#998DC3", "#6159A6", "#42328E"], + "wheel": ["#ffffff", "#C2E1EA", "#95CBEE", "#5373B8", "#324BA0"], } return colors.LinearSegmentedColormap.from_list("mycmap", dc[varmaps[split]]) @@ -144,9 +142,9 @@ def get_cmap(split): # Distribution of full model R^2 values, and std. dev between folds meanscores = fitdata["mean_fit_results"] full_model = meanscores["full_model"].copy() -full_model_std = (fitdata["fit_results"].groupby(["eid", "pid", "clu_id" - ]).agg({"full_model": - "std"})) +full_model_std = ( + fitdata["fit_results"].groupby(["eid", "pid", "clu_id"]).agg({"full_model": "std"}) +) joindf = full_model_std.join(full_model, how="inner", lsuffix="_std") if DISTPLOTS: @@ -168,8 +166,7 @@ def get_cmap(split): unitcounts = meanscores.groupby("region").size().astype(int) keepreg = unitcounts[unitcounts >= MIN_UNITS].index if GLOBAL_CMAP: - allmeans = meanscores.set_index( - "region", append=True)[VARIABLES].groupby("region").mean() + allmeans = meanscores.set_index("region", append=True)[VARIABLES].groupby("region").mean() cmin = np.percentile(allmeans.values.flatten(), COLOR_RANGE[0]) if cmin < 0: cmin = 1e-5 @@ -222,7 +219,8 @@ def get_cmap(split): fig.suptitle(f"{var1} $\Delta R^2$ - {var2} $\Delta R^2$") fig.savefig( DIFFPATH.joinpath( - f"{var1}_{var2}_{'abs' * ABSDIFF}diff{'_annotated' * ANNOTATE}.{IMGFMT}"), + f"{var1}_{var2}_{'abs' * ABSDIFF}diff{'_annotated' * ANNOTATE}.{IMGFMT}" + ), format=IMGFMT, dpi=450, ) diff --git a/brainwidemap/encoding/scripts/embed_saved_weights.py b/brainwidemap/encoding/scripts/embed_saved_weights.py new file mode 100644 index 00000000..3ecee171 --- /dev/null +++ b/brainwidemap/encoding/scripts/embed_saved_weights.py @@ -0,0 +1,193 @@ +# %% +import cuml +from ibllib.atlas import BrainRegions +from pathlib import Path +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +from functools import cache +import sklearn.preprocessing as pp +import sklearn.manifold as skma + + +@cache +def parentreg(region, level=3): + try: + return br.ancestors(br.acronym2id(region))["name"][level] + except IndexError: + return br.name[br.acronym2index(region)[1][0][0]] + + +@cache +def regcolor(region): + return br.rgba[br.acronym2index(region)[1]][0, 0] / 255 + + +def regionscatter(x, y, data, **kwargs): + return sns.scatterplot( + data.query("level6 in @level6_interesting_regions"), x=x, y=y, hue="level6", **kwargs + ) + + +COVARS = [ + "stimonR", + "stimonL", + "correct", + "incorrect", + "fmoveR", + "fmoveL", + "pLeft", + "pLeft_tr", + "wheel", +] + +covpairs = { + "stim": ["stimonR", "stimonL"], + "fmove": ["fmoveR", "fmoveL", "wheel"], + "feedback": ["correct", "incorrect"], + "pLeft": ["pLeft", "pLeft_tr"], + "all": COVARS, +} + +level6_interesting_regions = [ + "Somatomotor areas", + "Orbital area" "Anterior cingulate area", + "Visual areas", + "Hippocampal region", + "Caudoputamen", +] + +# %% +fitdata = pd.read_pickle("/home/berk/2023-10-02_glm_fit.pkl") +br = BrainRegions() +weights = fitdata["fit_weights"].query("region != 'root' & region != 'void'").copy() +weights["parent"] = weights.region.apply(parentreg) +weights["level6"] = weights.region.apply(parentreg, level=6) +weights["regcolor"] = weights.region.apply(regcolor) +mask = weights["parent"].apply(type) == str +weights = weights[mask] + +scaled_weights = 
weights.copy() +scaled_weights.loc[:, "stimonR_0":"wheel_2"] = pp.power_transform( + scaled_weights.loc[:, "stimonR_0":"wheel_2"], +) + + +methods = { + "spectr_umap": cuml.manifold.UMAP(n_components=2), + "lown_umap": cuml.manifold.UMAP(n_components=2, n_neighbors=5), + "highn_umap": cuml.manifold.UMAP(n_components=2, n_neighbors=100), + "pca": cuml.PCA(n_components=2), +} +# %% +# First single-variable embeddings +embeddings = [] + + +def create_emb_df(weights, scaled_weights, methname, method, cov, columns, scaling): + basedata = weights if not scaling else scaled_weights + emb = method.fit_transform(basedata[columns].values) + embdf = pd.DataFrame(emb, index=basedata.index, columns=["Dim 1", "Dim 2"]) + embdf["method"] = methname + embdf["covariate"] = cov + embdf["uniform_scaled"] = scaling + embdf["parent"] = basedata["parent"] + embdf["level6"] = basedata["level6"] + embdf["region"] = basedata["region"] + embdf["regcolor"] = basedata["regcolor"] + return embdf + + +for methname, method in methods.items(): + for cov in COVARS: + columns = weights.columns[weights.columns.str.match(cov)] + if len(columns) <= 2: + continue + for scaling in (True, False): + embdf = create_emb_df(weights, scaled_weights, methname, method, cov, columns, scaling) + embeddings.append(embdf) +embeddings = pd.concat(embeddings) +# %% +for cov in COVARS: + for scaling in (True, False): + fgdata = embeddings.query( + "covariate == @cov & uniform_scaled == @scaling &" + " level6 in @level6_interesting_regions" + ) + if len(fgdata) < 5: + continue + fg = sns.FacetGrid( + fgdata, + row="method", + col="level6", + sharex="row", + sharey="row", + ) + fg.map( + sns.histplot, + "Dim 1", + "Dim 2", + bins=25, + ) + fg.add_legend() + fg.set_titles("{col_name}\n{row_name}") + datatype_folder = "raw_weights" if not scaling else "normaldist_rescale" + fg.savefig( + Path("~/Documents/Projects/results/glms/").expanduser() + / "regions_of_interest" + / datatype_folder + / f"{cov}_weight_embedding_{scaling * 'rescaled_data_'}" + "bestmethods_interestingregions.png", + dpi=300, + ) + plt.close() +# %% +embeddings = [] +for methname, method in methods.items(): + for pname, pair in covpairs.items(): + columns = weights.columns[ + weights.columns.str.match(pair[0]) | weights.columns.str.match(pair[1]) + ] + if len(columns) <= 2: + continue + for scaling in (True, False): + embdf = create_emb_df(weights, scaled_weights, methname, method, pname, columns, scaling) + embeddings.append(embdf) +embeddings = pd.concat(embeddings) + +# %% +for pair in covpairs: + for scaling in (True, False): + fgdata = embeddings.query( + "covariate == @pair & uniform_scaled == @scaling &" + " level6 in @level6_interesting_regions" + ) + if len(fgdata) < 5: + continue + fg = sns.FacetGrid( + fgdata, + row="method", + col="level6", + sharex="row", + sharey="row", + ) + fg.map( + sns.histplot, + "Dim 1", + "Dim 2", + bins=25, + ) + fg.add_legend() + fg.set_titles("{col_name}\n{row_name}") + datatype_folder = "raw_weights" if not scaling else "normaldist_rescale" + fg.savefig( + Path("~/Documents/Projects/results/glms/").expanduser() + / "regions_of_interest" + / datatype_folder + / f"{pair}_pair_weight_embedding_{scaling * 'rescaled_data_'}" + "bestmethods_interestingregions.png", + dpi=300, + ) + plt.close() +# %% diff --git a/brainwidemap/encoding/scripts/glm_params.pkl b/brainwidemap/encoding/scripts/glm_params.pkl new file mode 100644 index 00000000..3089ed14 Binary files /dev/null and b/brainwidemap/encoding/scripts/glm_params.pkl differ diff --git 
a/brainwidemap/encoding/scripts/subpanel_plots.py b/brainwidemap/encoding/scripts/subpanel_plots.py new file mode 100644 index 00000000..a4c8f022 --- /dev/null +++ b/brainwidemap/encoding/scripts/subpanel_plots.py @@ -0,0 +1,314 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from one.api import ONE +from brainbox.plot import peri_event_time_histogram +from brainwidemap.encoding.design import generate_design +from brainwidemap.encoding.glm_predict import GLMPredictor, predict +from brainwidemap.encoding.utils import load_regressors, single_cluster_raster, find_trial_ids + +import neurencoding.linear as lm +from neurencoding.utils import remove_regressors + + +def plot_twocond( + eid, + pid, + clu_id, + align_time, + aligncol, + aligncond1, + aligncond2, + t_before, + t_after, + regressors, +): + # Load in data and fit model to particular cluster + stdf, sspkt, sspkclu, design, spkmask, nglm = load_unit_fit_model(eid, pid, clu_id) + # Construct GLM prediction object that does our model predictions + pred = GLMPredictor(stdf, nglm, sspkt, sspkclu) + # Construct design matrix without regressors of interest + noreg_dm = remove_regressors(design, regressors) + # Fit model without regressors of interest + nrnglm = lm.LinearGLM( + noreg_dm, sspkt[spkmask], sspkclu[spkmask], estimator=glm_params["estimator"], mintrials=0 + ) + nrnglm.fit() + # Construct GLM prediction object that does model predictions without regressors of interest + nrpred = GLMPredictor(stdf, nrnglm, sspkt, sspkclu) + + # Compute model predictions for each condition + keyset1 = pred.compute_model_psth( + align_time, + t_before, + t_after, + trials=stdf[aligncond1(stdf[aligncol])].index, + ) + cond1pred = pred.full_psths[keyset1][clu_id][0] + keyset2 = pred.compute_model_psth( + align_time, + t_before, + t_after, + trials=stdf[aligncond2(stdf[aligncol])].index, + ) + cond2pred = pred.full_psths[keyset2][clu_id][0] + nrkeyset1 = nrpred.compute_model_psth( + align_time, + t_before, + t_after, + trials=stdf[aligncond1(stdf[aligncol])].index, + ) + nrcond1pred = nrpred.full_psths[nrkeyset1][clu_id][0] + nrkeyset2 = nrpred.compute_model_psth( + align_time, + t_before, + t_after, + trials=stdf[aligncond2(stdf[aligncol])].index, + ) + nrcond2pred = nrpred.full_psths[nrkeyset2][clu_id][0] + + # Plot PSTH of original units and model predictions in both cases + fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharey="row") + x = np.arange(-t_before, t_after, nglm.binwidth) + for rem_regressor in [False, True]: + i = int(rem_regressor) + oldticks = [] + peri_event_time_histogram( + sspkt, + sspkclu, + stdf[aligncond1(stdf[aligncol])][align_time], + clu_id, + t_before, + t_after, + bin_size=nglm.binwidth, + error_bars="sem", + ax=ax[i], + smoothing=0.01, + pethline_kwargs={"color": "blue", "linewidth": 2}, + errbar_kwargs={"color": "blue", "alpha": 0.5}, + ) + oldticks.extend(ax[i].get_yticks()) + peri_event_time_histogram( + sspkt, + sspkclu, + stdf[aligncond2(stdf[aligncol])][align_time], + clu_id, + t_before, + t_after, + bin_size=nglm.binwidth, + error_bars="sem", + ax=ax[i], + smoothing=0.01, + pethline_kwargs={"color": "red", "linewidth": 2}, + errbar_kwargs={"color": "red", "alpha": 0.5}, + ) + oldticks.extend(ax[i].get_yticks()) + pred1 = cond1pred if not rem_regressor else nrcond1pred + pred2 = cond2pred if not rem_regressor else nrcond2pred + ax[i].step(x, pred1, color="darkblue", linewidth=2) + oldticks.extend(ax[i].get_yticks()) + ax[i].step(x, pred2, 
color="darkred", linewidth=2) + oldticks.extend(ax[i].get_yticks()) + ax[i].set_ylim([0, np.max(oldticks) * 1.1]) + return fig, ax, sspkt, sspkclu, stdf + + +def load_unit_fit_model(eid, pid, clu_id): + stdf, sspkt, sspkclu, _, __ = load_regressors( + eid, + pid, + one, + t_before=0.6, + t_after=0.6, + binwidth=glm_params["binwidth"], + abswheel=True, + ) + design = generate_design(stdf, stdf["probabilityLeft"], t_before=0.6, **glm_params) + spkmask = sspkclu == clu_id + nglm = lm.LinearGLM( + design, sspkt[spkmask], sspkclu[spkmask], estimator=glm_params["estimator"], mintrials=0 + ) + nglm.fit() + return stdf, sspkt, sspkclu, design, spkmask, nglm + + +# Please use the saved parameters dict from 02_fit_sessions.py as params +PLOTPATH = Path("/home/berk/Documents/Projects/results/plots/prediction_summaries") +N_TOP_UNITS = 20 +RAST_BINSIZE = 0.002 +one = ONE() +plt.rcParams["svg.fonttype"] = "none" +# Sets of align_time as key with aligncol, aligncond1/2 functions, t_before/t_after, +# and the name of the associated model regressors as values +alignsets = { + "stimOn_times": ( + "contrastRight", # Column name in df to use for filtering + lambda c: np.isnan(c), # Condition 1 function (left stim) + lambda c: np.isfinite(c), # Condition 2 function (right stim) + 0.1, # Time before align_time to include in trial psth/raster + 0.4, # Time after align_time to include in trial psth/raster + "stimonL", # Condition 1 label within the GLM design matrix + "stimonR", # Condition 2 label within the GLM design matrix + ), + "firstMovement_times": ( + "choice", + lambda c: c == 1, + lambda c: c == -1, + 0.2, + 0.05, + "fmoveL", + "fmoveR", + ), + "feedback_times": ( + "feedbackType", + lambda f: f == 1, + lambda f: f == -1, + 0.1, + 0.4, + "correct", + "incorrect", + ), +} + +glm_params = pd.read_pickle(Path(__file__).parent.joinpath("glm_params.pkl")) + +# Which units we're going to use for plotting +targetunits = { # eid, pid, clu_id, region, drsq, alignset key + "stim": ( + "e0928e11-2b86-4387-a203-80c77fab5d52", # EID + "799d899d-c398-4e81-abaf-1ef4b02d5475", # PID + 209, # clu_id + "VISp", # region + 0.04540706, # drsq (from 02_fit_sessions.py) + "stimOn_times", # Alignset key + ), + "choice": ( + "671c7ea7-6726-4fbe-adeb-f89c2c8e489b", + "04c9890f-2276-4c20-854f-305ff5c9b6cf", + 123, + "GRN", + 0.000992895, # drsq + "firstMovement_times", + ), + "feedback": ( + "a7763417-e0d6-4f2a-aa55-e382fd9b5fb8", + "57c5856a-c7bd-4d0f-87c6-37005b1484aa", + 98, + "IRN", + 0.3077195113, # drsq + "feedback_times", + ), + "block": ( + "7bee9f09-a238-42cf-b499-f51f765c6ded", + "26118c10-35dd-4ab1-9f0f-b9a89a1da070", + 207, + "MOp", + 0.0043285, # drsq + "stimOn_times", + ), +} + + +# This is a hack to use the "find trial IDs" function provided by Mayo to get the trial IDs +# corresponding to each condition for the different variables +sortlookup = {"stim": "side", "choice": "movement", "feedback": "fdbk", "wheel": "movement"} + +for variable, (eid, pid, clu_id, region, drsq, aligntime) in targetunits.items(): + # Skip block (separate logic) and make folder structure before continuing if folders don't exist + if variable == "block": + continue + varfolder = Path(PLOTPATH).joinpath(variable) # Variable gets its own folder + rasterfolder = varfolder.joinpath("rasters") # Rasters in a separate subfolder for variable + if not varfolder.exists(): + varfolder.mkdir() + if not rasterfolder.exists(): + rasterfolder.mkdir() + if not varfolder.joinpath("png").exists(): # PNGs separated too + varfolder.joinpath("png").mkdir() 
+ if not rasterfolder.joinpath("png").exists(): # PNGs separated too + rasterfolder.joinpath("png").mkdir() + aligncol, aligncond1, aligncond2, t_before, t_after, reg1, reg2 = alignsets[aligntime] + fig, ax, sspkt, sspkclu, stdf = plot_twocond( + eid, + pid, + clu_id, + aligntime, + aligncol, + aligncond1, + aligncond2, + t_before, + t_after, + [reg1, reg2] if variable != "wheel" else ["wheel"], + ) + # Wheel only has one regressor, unlike all the others. Handle this properly + if variable != "wheel": + remstr = f"\n[{reg1}, {reg2}] regressors rem." + else: + remstr = "\nwheel regressor rem." + names = [reg1, reg2, reg1 + remstr, reg2 + remstr] + for subax, title in zip(ax, names): + subax.set_title(title) + plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg")) + plt.savefig( + varfolder.joinpath(f"png/{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.png") + ) + plt.close() + + stdf["response_times"] = stdf["stimOn_times"] + trial_idx, dividers = find_trial_ids(stdf, sort=sortlookup[variable]) + fig, ax = single_cluster_raster( + sspkt[sspkclu == clu_id], + stdf[aligntime], + trial_idx, + dividers, + ["b", "r"], + [reg1, reg2], + pre_time=t_before, + post_time=t_after, + raster_cbar=True, + raster_bin=RAST_BINSIZE, + ) + ax.set_title("{} unit {} : $\log \Delta R^2$ = {:.2f}".format(region, clu_id, np.log(drsq))) + plt.savefig(rasterfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.svg")) + plt.savefig( + rasterfolder.joinpath(f"png/{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.png") + ) + plt.close() + +## Treat block separately since it's a different type of plot +variable = "block" +varfolder = Path(PLOTPATH).joinpath(variable) +if not varfolder.exists(): + varfolder.mkdir() + +eid, pid, clu_id, region, drsq, aligntime = targetunits[variable] +stdf, sspkt, sspkclu, design, spkmask, nglm = load_unit_fit_model(eid, pid, clu_id) +pred, trlabels = predict(nglm, glm_type="linear", retlab=True) +mask = design.dm[:, design.covar["pLeft"]["dmcol_idx"]] != 0 +itipred = pred[clu_id][mask] +iticounts = nglm.binnedspikes[mask, :] +labels = trlabels[mask] +rates = pd.DataFrame( + index=stdf.index[stdf.probabilityLeft != 0.5], + columns=["firing_rate", "pred_rate", "pLeft"], + dtype=float, +) +for p_val in [0.2, 0.8]: + trials = stdf.index[stdf.probabilityLeft == p_val] + for trial in trials: + trialmask = labels == trial + rates.loc[trial, "firing_rate"] = np.mean(iticounts[trialmask]) / design.binwidth + rates.loc[trial, "pred_rate"] = np.mean(itipred[trialmask]) / design.binwidth + rates.loc[trial, "pLeft"] = p_val +fig, ax = plt.subplots(1, 2, figsize=(6, 6), sharey=True) +sns.boxplot(rates, x="pLeft", y="firing_rate", ax=ax[0]) +sns.boxplot(rates, x="pLeft", y="pred_rate", ax=ax[1]) +ax[0].set_title(f"{region} {clu_id} firing rate by block") +ax[1].set_title(f"{region} {clu_id} predicted rate by block") +ax[0].set_ylabel("Firing rate (spikes/s)") +plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg")) +plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.png")) +plt.close() diff --git a/brainwidemap/encoding/scripts/twocond_plots.py b/brainwidemap/encoding/scripts/twocond_plots.py index c86755d4..c1ff8eec 100644 --- a/brainwidemap/encoding/scripts/twocond_plots.py +++ b/brainwidemap/encoding/scripts/twocond_plots.py @@ -1,63 +1,19 @@ from pathlib import Path -import brainwidemap.encoding.cluster_worker as cw import matplotlib.pyplot as plt import 
numpy as np import pandas as pd import seaborn as sns +from one.api import ONE from brainwidemap.bwm_loading import bwm_query from brainwidemap.encoding.glm_predict import GLMPredictor, predict +from brainwidemap.encoding.design import generate_design from brainwidemap.encoding.params import GLM_CACHE, GLM_FIT_PATH -from brainwidemap.encoding.utils import single_cluster_raster, find_trial_ids +from brainwidemap.encoding.utils import single_cluster_raster, find_trial_ids, load_regressors import neurencoding.linear as lm from neurencoding.utils import remove_regressors -# Please use the saved parameters dict form 02_fit_sessions.py as params -PLOTPATH = Path("/home/berk/Documents/Projects/results/plots/prediction_summaries") -N_TOP_UNITS = 20 -plt.rcParams['svg.fonttype'] = 'none' -alignsets = { # Sets of align_time as key with aligncol, aligncond1/2 functions, and t_before/t_after as the values - "stimOn_times": ( - "contrastRight", - lambda c: np.isnan(c), - lambda c: np.isfinite(c), - 0.1, - 0.4, - "stimonL", - "stimonR", - ), - "firstMovement_times": ( - "choice", - lambda c: c == 1, - lambda c: c == -1, - 0.2, - 0.05, - "fmoveL", - "fmoveR", - ), - "feedback_times": ( - "feedbackType", - lambda f: f == 1, - lambda f: f == -1, - 0.1, - 0.4, - "correct", - "incorrect", - ), -} - -targetreg = { # Function to produce the target metric, the target regions, and alignset key for each plottable - "stim": (lambda df: df["stimonR"] - df["stimonL"], ["VISp"], "stimOn_times"), - "choice": (lambda df: df["fmoveR"] - df["fmoveL"], ["GRN"], "firstMovement_times"), - "feedback": (lambda df: df["correct"] - df["incorrect"], ["IRN"], "feedback_times"), - "wheel": (lambda df: df["wheel"], ["GRN"], "firstMovement_times"), - "block": (lambda df: df["pLeft"], ["PL"], "stimOn_times"), -} - -params = pd.read_pickle(GLM_FIT_PATH + "/2023-03-07_glm_fit_pars.pkl") -meanscores = pd.read_pickle(GLM_FIT_PATH + "/2023-03-02_glm_fit.pkl")["mean_fit_results"].set_index("region", append=True) - def plot_twocond( eid, @@ -71,14 +27,7 @@ def plot_twocond( t_after, regressors, ): - sessdf = bwm_query() - subject = sessdf[sessdf["eid"] == eid]["subject"].iloc[0] - eidfn = Path(GLM_CACHE).joinpath(Path(f"{subject}/{eid}/2022-12-22_{pid}_regressors.pkl")) - stdf, sspkt, sspkclu, sclureg, clu_df = cw.get_cached_regressors(eidfn) - design = cw.generate_design(stdf, stdf["probabilityLeft"], t_before=0.6, **params) - spkmask = sspkclu == clu_id - nglm = lm.LinearGLM(design, sspkt[spkmask], sspkclu[spkmask], estimator=params["estimator"], mintrials=0) - nglm.fit() + stdf, sspkt, sspkclu, design, spkmask, nglm = load_unit_fit_model(eid, pid, clu_id) pred = GLMPredictor(stdf, nglm, sspkt, sspkclu) fig, ax = plt.subplots(3, 4, figsize=(12, 12), sharey="row") oldticks = [] @@ -101,7 +50,9 @@ def plot_twocond( ) oldticks.extend(ax[0, 1].get_yticks()) noreg_dm = remove_regressors(design, regressors) - nrnglm = lm.LinearGLM(noreg_dm, sspkt[spkmask], sspkclu[spkmask], estimator=params["estimator"], mintrials=0) + nrnglm = lm.LinearGLM( + noreg_dm, sspkt[spkmask], sspkclu[spkmask], estimator=glm_params["estimator"], mintrials=0 + ) nrnglm.fit() nrpred = GLMPredictor(stdf, nrnglm, sspkt, sspkclu) nrpred.psth_summary( @@ -127,6 +78,76 @@ def plot_twocond( return fig, ax, sspkt, sspkclu, stdf +def load_unit_fit_model(eid, pid, clu_id): + stdf, sspkt, sspkclu, _, __ = load_regressors( + eid, + pid, + one, + t_before=0.6, + t_after=0.6, + binwidth=glm_params["binwidth"], + abswheel=True, + ) + design = generate_design(stdf, 
stdf["probabilityLeft"], t_before=0.6, **glm_params) + spkmask = sspkclu == clu_id + nglm = lm.LinearGLM( + design, sspkt[spkmask], sspkclu[spkmask], estimator=glm_params["estimator"], mintrials=0 + ) + nglm.fit() + return stdf, sspkt, sspkclu, design, spkmask, nglm + + +# Please use the saved parameters dict from 02_fit_sessions.py as params +PLOTPATH = Path("/home/berk/Documents/Projects/results/plots/prediction_summaries") +N_TOP_UNITS = 20 +RAST_BINSIZE = 0.002 +OVERWRITE = False +one = ONE() +plt.rcParams["svg.fonttype"] = "none" +alignsets = { # Sets of align_time as key with aligncol, aligncond1/2 functions, and t_before/t_after as the values + "stimOn_times": ( + "contrastRight", + lambda c: np.isnan(c), + lambda c: np.isfinite(c), + 0.1, + 0.4, + "stimonL", + "stimonR", + ), + "firstMovement_times": ( + "choice", + lambda c: c == 1, + lambda c: c == -1, + 0.2, + 0.05, + "fmoveL", + "fmoveR", + ), + "feedback_times": ( + "feedbackType", + lambda f: f == 1, + lambda f: f == -1, + 0.1, + 0.4, + "correct", + "incorrect", + ), +} + +targetreg = { # Function to produce the target metric, the target regions, and alignset key for each plottable + "stim": (lambda df: df["stimonR"] - df["stimonL"], ["VISp", "SSp-tr"], "stimOn_times"), + "choice": (lambda df: df["fmoveR"] - df["fmoveL"], ["GRN"], "firstMovement_times"), + "feedback": (lambda df: df["correct"] - df["incorrect"], ["IRN"], "feedback_times"), + "wheel": (lambda df: df["wheel"], ["GRN"], "firstMovement_times"), + "block": (lambda df: df["pLeft"], ["PL", "MOp"], "stimOn_times"), +} + +glm_params = pd.read_pickle(GLM_FIT_PATH + "/2024-07-16_glm_fit_pars.pkl") +meanscores = pd.read_pickle(GLM_FIT_PATH + "/2024-07-16_glm_fit.pkl")[ + "mean_fit_results" +].set_index("region", append=True) + + sortlookup = {"stim": "side", "choice": "movement", "feedback": "fdbk", "wheel": "movement"} for variable, (targetmetricfun, regions, aligntime) in targetreg.items(): @@ -141,20 +162,31 @@ def plot_twocond( targetmetric = targetmetricfun(meanscores) aligncol, aligncond1, aligncond2, t_before, t_after, reg1, reg2 = alignsets[aligntime] for region in regions: - topunits = targetmetric.loc[:, :, :, region].sort_values(ascending=False).iloc[:N_TOP_UNITS] + topunits = ( + targetmetric.loc[:, :, :, region].sort_values(ascending=False).iloc[:N_TOP_UNITS] + ) for (eid, pid, clu_id), drsq in topunits.items(): - fig, ax, sspkt, sspkclu, stdf = plot_twocond( - eid, - pid, - clu_id, - aligntime, - aligncol, - aligncond1, - aligncond2, - t_before, - t_after, - [reg1, reg2] if variable != "wheel" else ["wheel"], + twocond_path = varfolder.joinpath( + f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg" ) + if twocond_path.exists() and not OVERWRITE: + print("skipping as {} already exists".format(twocond_path)) + continue + try: + fig, ax, sspkt, sspkclu, stdf = plot_twocond( + eid, + pid, + clu_id, + aligntime, + aligncol, + aligncond1, + aligncond2, + t_before, + t_after, + [reg1, reg2] if variable != "wheel" else ["wheel"], + ) + except: + continue if variable != "wheel": remstr = f"\n[{reg1}, {reg2}] regressors rem." 
             else:
@@ -162,46 +194,125 @@ def plot_twocond(
             names = [reg1, reg2, reg1 + remstr, reg2 + remstr]
             for subax, title in zip(ax[0, :], names):
                 subax.set_title(title)
-            plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg"))
-            plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.png"))
+            plt.savefig(
+                varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg")
+            )
+            plt.savefig(
+                varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.png")
+            )
             plt.close()
             stdf["response_times"] = stdf["stimOn_times"]
             trial_idx, dividers = find_trial_ids(stdf, sort=sortlookup[variable])
             fig, ax = single_cluster_raster(
-                sspkt[sspkclu == clu_id], stdf[aligntime], trial_idx, dividers, ["b", "r"], [reg1, reg2],
-                pre_time=t_before, post_time=t_after,
+                sspkt[sspkclu == clu_id],
+                stdf[aligntime],
+                trial_idx,
+                dividers,
+                ["b", "r"],
+                [reg1, reg2],
+                pre_time=t_before,
+                post_time=t_after,
+                raster_cbar=True,
+                raster_bin=RAST_BINSIZE,
+            )
+            ax.set_title(
+                r"{} unit {} : $\log \Delta R^2$ = {:.2f}".format(region, clu_id, np.log(drsq))
+            )
+            plt.savefig(
+                rasterfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.svg")
+            )
+            plt.savefig(
+                rasterfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.png")
             )
-            plt.savefig(rasterfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.svg"))
-            plt.savefig(rasterfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.png"))
             plt.close()
-
 ## Treat block separately since it's a different type of plot
 variable = "block"
 targetmetricfun, regions, aligntime = targetreg["block"]
 varfolder = Path(PLOTPATH).joinpath(variable)
+rasterfolder = varfolder.joinpath("rasters")
 if not varfolder.exists():
     varfolder.mkdir()
+if not rasterfolder.exists():
+    rasterfolder.mkdir()
+if not varfolder.joinpath("png").exists():  # PNGs separated too
+    varfolder.joinpath("png").mkdir()
+if not rasterfolder.joinpath("png").exists():  # PNGs separated too
+    rasterfolder.joinpath("png").mkdir()
 targetmetric = targetmetricfun(meanscores)
+block_colors = {0.5: "gray", 0.8: "b", 0.2: "r"}
 for region in regions:
     topunits = targetmetric.loc[:, :, :, region].sort_values(ascending=False).iloc[:N_TOP_UNITS]
     for (eid, pid, clu_id), drsq in topunits.items():
+        twocond_path = varfolder.joinpath(
+            f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg"
+        )
+        if twocond_path.exists() and not OVERWRITE:
+            print("skipping as {} already exists".format(twocond_path))
+            continue
         sessdf = bwm_query()
         subject = sessdf[sessdf["eid"] == eid]["subject"].iloc[0]
         eidfn = Path(GLM_CACHE).joinpath(Path(f"{subject}/{eid}/2022-12-22_{pid}_regressors.pkl"))
-        stdf, sspkt, sspkclu, sclureg, clu_df = cw.get_cached_regressors(eidfn)
-        design = cw.generate_design(stdf, stdf["probabilityLeft"], t_before=0.6, **params)
+        stdf, sspkt, sspkclu, sclureg, clu_df = load_regressors(eid, pid, one, t_before=0.6)
+        design = generate_design(stdf, stdf["probabilityLeft"], t_before=0.6, **glm_params)
         spkmask = sspkclu == clu_id
-        nglm = lm.LinearGLM(design, sspkt[spkmask], sspkclu[spkmask], estimator=params["estimator"], mintrials=0)
+        if not spkmask.any():  # no spikes recorded for this cluster; nothing to fit
+            continue
+        nglm = lm.LinearGLM(
+            design, sspkt[spkmask], sspkclu[spkmask], estimator=glm_params["estimator"], mintrials=0
+        )
         nglm.fit()
+
+        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+        trial_idx, dividers = find_trial_ids(stdf, sort="block")
+
+        _, __ = single_cluster_raster(
+            sspkt[spkmask],
+            stdf[aligntime],
+            trial_idx,
+            dividers,
+            ["b", "r"],
+            ["P(Right) = 0.8", "P(Right) = 0.2"],
+            pre_time=0.5,
+            post_time=0,
+            raster_cbar=True,
+            raster_bin=RAST_BINSIZE,
+            axs=ax[0],
+        )
+        # ax[0].vlines(-0.5, 0, len(trial_idx), color="k", linestyle="--")
+        # ax[0].vlines(-0.1, 0, len(trial_idx), color="k", linestyle="--")
+        # Indices where P(left) switches; assumes the session contains at least one block change.
+        block_dividers = np.nonzero(np.diff(stdf["probabilityLeft"]))[0]
+        block_values = stdf["probabilityLeft"].iloc[[*block_dividers, block_dividers[-1] + 1]]
+        colors = [block_colors[val] for val in block_values]
+        _, __ = single_cluster_raster(
+            sspkt[spkmask],
+            stdf[aligntime],
+            range(len(stdf.index)),
+            list(block_dividers),
+            colors,
+            block_values.astype(str).to_list(),
+            pre_time=0.5,
+            post_time=0,
+            raster_cbar=True,
+            raster_bin=RAST_BINSIZE,
+            axs=ax[1],
+        )
+
+        plt.savefig(
+            rasterfolder.joinpath(f"png/{eid}_{pid}_clu{clu_id}_{region}_{variable}_raster.png")
+        )
+        plt.close()
+
         pred, trlabels = predict(nglm, glm_type="linear", retlab=True)
         mask = design.dm[:, design.covar["pLeft"]["dmcol_idx"]] != 0
         itipred = pred[clu_id][mask]
         iticounts = nglm.binnedspikes[mask, :]
         labels = trlabels[mask]
         rates = pd.DataFrame(
-            index=stdf.index[stdf.probabilityLeft != 0.5], columns=["firing_rate", "pred_rate", "pLeft"], dtype=float
+            index=stdf.index[stdf.probabilityLeft != 0.5],
+            columns=["firing_rate", "pred_rate", "pLeft"],
+            dtype=float,
         )
         for p_val in [0.2, 0.8]:
             trials = stdf.index[stdf.probabilityLeft == p_val]
@@ -216,5 +327,10 @@ def plot_twocond(
     ax[0].set_title(f"{region} {clu_id} firing rate by block")
     ax[1].set_title(f"{region} {clu_id} predicted rate by block")
     ax[0].set_ylabel("Firing rate (spikes/s)")
-    plt.savefig(varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.{IMGFMT}"))
+    plt.savefig(
+        varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.svg")
+    )
+    plt.savefig(
+        varfolder.joinpath(f"{eid}_{pid}_clu{clu_id}_{region}_{variable}_predsummary.png")
+    )
     plt.close()
diff --git a/brainwidemap/encoding/utils.py b/brainwidemap/encoding/utils.py
index 40b0af0d..6fa96dfa 100644
--- a/brainwidemap/encoding/utils.py
+++ b/brainwidemap/encoding/utils.py
@@ -11,24 +11,27 @@
 import matplotlib.pyplot as plt

 # IBL libraries
+from one.api import ONE
 from iblutil.util import Bunch
 import brainbox.io.one as bbone
 from brainbox.io.one import SessionLoader

 # Brainwidemap repo imports
-from brainwidemap.bwm_loading import load_trials_and_mask, bwm_units
+from brainwidemap.bwm_loading import load_good_units, load_trials_and_mask, bwm_units
 from brainwidemap.encoding.timeseries import TimeSeries, sync


 def load_regressors(
     session_id,
     pid,
-    one,
+    one=None,
     t_before=0.0,
     t_after=0.2,
     binwidth=0.02,
     abswheel=False,
     clu_criteria="bwm",
+    one_url="https://openalyx.internationalbrainlab.org",
+    one_pw="international",
 ):
     """
     Load in regressors for given session and probe.
Returns a dictionary with the following keys:
@@ -61,6 +64,9 @@ def load_regressors(
     trialsdf, spk_times, spk_clu, clu_regions, clu_qc, clu_df, clu_qc (optional)
         Output regressors for GLM
     """
+    if one is None:
+        one = ONE(base_url=one_url, password=one_pw, silent=True)
+
     _, mask = load_trials_and_mask(one=one, eid=session_id)
     mask = mask.index[np.nonzero(mask.values)]
     trialsdf = load_trials_df(
@@ -75,34 +81,12 @@ def load_regressors(
         trials_mask=mask,
     )

-    clusters = {}
-    ssl = bbone.SpikeSortingLoader(one=one, pid=pid, eid=session_id)
-    origspikes, tmpclu, channels = ssl.load_spike_sorting()
-    if "metrics" not in tmpclu:
-        tmpclu["metrics"] = np.ones(tmpclu["channels"].size)
-    clusters[pid] = ssl.merge_clusters(origspikes, tmpclu, channels)
-    clu_df = pd.DataFrame(clusters[pid]).set_index(["cluster_id"])
-    clu_df["pid"] = pid
-    if clu_criteria == "bwm":
-        allunits = (
-            bwm_units(one=one)
-            .rename(columns={"cluster_id": "clu_id"})
-            .set_index(["eid", "pid", "clu_id"])
-        )
-        keepclu = clu_df.index.intersection(allunits.loc[session_id, pid, :].index)
-    elif clu_criteria == "all":
-        keepclu = clu_df.index
-    else:
-        raise ValueError("clu_criteria must be 'bwm' or 'all'")
-
-    clu_df = clu_df.loc[keepclu]
-    keepmask = np.isin(origspikes.clusters, keepclu)
-    spikes = Bunch({k: v[keepmask] for k, v in origspikes.items()})
-    sortinds = np.argsort(spikes.times)
-    spk_times = spikes.times[sortinds]
-    spk_clu = spikes.clusters[sortinds]
-    clu_regions = clusters[pid].acronym
+    spikes, clu_df = load_good_units(one=one, pid=pid)
+    spk_times = spikes["times"]
+    spk_clu = spikes["clusters"]
+    clu_df["pid"] = pid
+    clu_regions = clu_df.acronym
     return trialsdf, spk_times, spk_clu, clu_regions, clu_df

@@ -188,7 +172,8 @@ def make_batch_slurm_singularity(
     fw.write(f"#SBATCH --cpus-per-task={cores_per_job}\n")
     fw.write(f"#SBATCH --mem={memory}\n")
     fw.write("\n")
-    fw.write(f"module load {' '.join(singularity_modules)}\n")
+    if singularity_modules:
+        fw.write(f"module load {' '.join(singularity_modules)}\n")
     bindstr = "" if len(mount_paths) == 0 else "-B "
     mountpairs = ",".join([f"{k}:{v}" for k, v in mount_paths.items()])
     fw.write(f"singularity run {bindstr} {mountpairs} {container_image} /bin/bash {workerscript}")
@@ -205,7 +190,6 @@ def make_batch_slurm_singularity(
 def load_trials_df(
     eid,
-    one,
     t_before=0.0,
     t_after=0.2,
     ret_wheel=False,
@@ -213,6 +197,7 @@
     wheel_binsize=0.02,
     addtl_types=[],
     trials_mask=None,
+    one=None,
 ):
     """
     Generate a pandas dataframe of per-trial timing information about a given session.
@@ -257,6 +242,8 @@
         have a monotonic index. Has special columns trial_start and trial_end which define start
         and end times via t_before and t_after
     """
+    if one is None:
+        raise ValueError("one must be defined.")
     if ret_wheel and ret_abswheel:
         raise ValueError("ret_wheel and ret_abswheel cannot both be true.")
@@ -341,6 +328,8 @@ def single_cluster_raster(
     pre_time=0.4,
     post_time=1.0,
     raster_bin=0.01,
+    raster_cbar=False,
+    raster_interp="none",
     show_psth=False,
     psth_bin=0.05,
     weights=None,
@@ -371,6 +360,10 @@
     Time after event to plot, by default 1.
     raster_bin : float, optional
         Time bin size for the raster, by default 0.01
+    raster_cbar : bool, optional
+        Whether to include a color bar of binned spike counts for the raster, by default False
+ raster_interp : str, optional + Passed to matplotlib.pyplot.imshow, by default "none" psth : bool, optional Whether to plot the PSTH, by default False psth_bin : float, optional @@ -392,7 +385,7 @@ def single_cluster_raster( """ raster, t_raster = bin_spikes( spike_times, - events, + events.values, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, @@ -400,7 +393,7 @@ def single_cluster_raster( ) psth, t_psth = bin_spikes( spike_times, - events, + events.values, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, @@ -466,6 +459,7 @@ def single_cluster_raster( cmap="binary", origin="lower", extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], + interpolation=raster_interp, aspect="auto", ) @@ -479,6 +473,9 @@ def single_cluster_raster( ) raster_ax.set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width]) + raster_ax.set_yticks(dividers) + if raster_cbar: + plt.colorbar(raster_ax.get_images()[0], ax=raster_ax, label="Spike count") secax = raster_ax.secondary_yaxis("right") secax.set_yticks(label_pos) diff --git a/pyproject.toml b/pyproject.toml index e352c2cb..7e7b64c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ import_heading_stdlib = "Standard library" import_heading_thirdparty = "Third party libraries" import_heading_firstparty = "IBL libraries" import_heading_localfolder = "Brainwidemap repo imports" -line_length = 130 -wrap_length = 130 +line_length = 99 +wrap_length = 99 [tool.black] line-length = 99 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8e748d4c..a4b8d8e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ ibllib -ipython \ No newline at end of file +ipython
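
Example (editor's sketch, not part of the patch): with the `one=None` default added to `load_regressors` above, the function can now build its own public ONE client instead of requiring the caller to pass one. The eid/pid values below are hypothetical placeholders.

    # Assumes the brainwidemap package is installed and openalyx is reachable.
    from brainwidemap.encoding.utils import load_regressors

    eid = "00000000-0000-0000-0000-000000000000"  # hypothetical session UUID
    pid = "11111111-1111-1111-1111-111111111111"  # hypothetical probe insertion UUID

    # With one=None, load_regressors falls back internally to
    # ONE(base_url="https://openalyx.internationalbrainlab.org",
    #     password="international", silent=True), per the defaults in this diff.
    trialsdf, spk_times, spk_clu, clu_regions, clu_df = load_regressors(eid, pid, one=None)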
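Example (editor's sketch): the new `raster_cbar` and `raster_interp` keywords of `single_cluster_raster`, mirroring the call pattern used in the plotting script above. `trial_idx` and `dividers` would come from `find_trial_ids` as in that script; the condition labels here are placeholders.

    from brainwidemap.encoding.utils import single_cluster_raster

    fig, ax = single_cluster_raster(
        spk_times[spk_clu == clu_id],  # spikes from a single cluster
        trialsdf["stimOn_times"],      # event Series; .values is now taken internally
        trial_idx,
        dividers,
        ["b", "r"],
        ["cond A", "cond B"],          # placeholder condition labels
        pre_time=0.1,
        post_time=0.4,
        raster_bin=0.002,
        raster_cbar=True,              # new: adds a "Spike count" colorbar
        raster_interp="none",          # new: forwarded to imshow's interpolation
    )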