diff --git a/constants.py b/constants.py deleted file mode 100644 index cc9d0cd..0000000 --- a/constants.py +++ /dev/null @@ -1,6 +0,0 @@ - -FEATURES = [ - 'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE', - 'CYAN', 'GREEN', 'MAGENTA', 'YELLOW', - 'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL' -] \ No newline at end of file diff --git a/format_beh.ipynb b/format_beh.ipynb deleted file mode 100644 index e64fb7f..0000000 --- a/format_beh.ipynb +++ /dev/null @@ -1,168 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "from spike_tools import (\n", - " general as spike_general,\n", - " analysis as spike_analysis,\n", - ")\n", - "from constants import FEATURES\n", - "\n", - "species = 'nhp'\n", - "subject = 'SA'\n", - "exp = 'WCST'\n", - "session = 20180802 # this is the session for which there are spikes at the moment. " - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "behavior_file = spike_general.get_behavior_path(subject, session)\n", - "behavior_data = pd.read_csv(\"/data/sub-SA_sess-20180802_object_features.csv\")\n", - "valid_beh = behavior_data[behavior_data.Response.isin([\"Correct\", \"Incorrect\"])]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 74, - "metadata": {}, - "outputs": [], - "source": [ - "def get_X_by_bins(bin_size, data):\n", - " max_time = np.max(valid_beh[\"TrialEnd\"].values)\n", - " max_bin_idx = int(max_time / bin_size) + 1\n", - " columns = FEATURES + [\"CORRECT\", \"INCORRECT\"]\n", - " types = [\"f4\" for _ in columns]\n", - " zipped = list(zip(columns, types))\n", - " dtype = np.dtype(zipped)\n", - " arr = np.zeros((max_bin_idx), dtype=dtype)\n", - "\n", - " for _, row in data.iterrows():\n", - " # grab features of item chosen\n", - " item_chosen = int(row[\"ItemChosen\"])\n", - " color = row[f\"Item{item_chosen}Color\"]\n", - " shape = 
row[f\"Item{item_chosen}Shape\"]\n", - " pattern = row[f\"Item{item_chosen}Pattern\"]\n", - "\n", - " chosen_time = row[\"FeedbackOnset\"] - 800\n", - " chosen_bin = int(chosen_time / bin_size)\n", - " arr[chosen_bin][color] = 1\n", - " arr[chosen_bin][shape] = 1\n", - " arr[chosen_bin][pattern] = 1\n", - "\n", - " feedback_bin = int(row[\"FeedbackOnset\"] / bin_size)\n", - " # print(feedback_bin)\n", - " if row[\"Response\"] == \"Correct\":\n", - " arr[feedback_bin][\"CORRECT\"] = 1\n", - " else:\n", - " arr[feedback_bin][\"INCORRECT\"] = 1\n", - " df = pd.DataFrame(arr)\n", - " df[\"bin_idx\"] = np.arange(len(df))\n", - " return df\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": {}, - "outputs": [], - "source": [ - "res = get_X_by_bins(50, valid_beh)" - ] - }, - { - "cell_type": "code", - "execution_count": 76, - "metadata": {}, - "outputs": [], - "source": [ - "res.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Grab bin idxs of interval around fb onset" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [], - "source": [ - "def get_trial_intervals(behavioral_data, event=\"FeedbackOnset\", pre_interval=0, post_interval=0, bin_size=50):\n", - " \"\"\"Per trial, finds time interval surrounding some event in the behavioral data\n", - "\n", - " Args:\n", - " behavioral_data: Dataframe describing each trial, must contain\n", - " columns: TrialNumber, whatever 'event' param describes\n", - " event: name of event to align around, must be present as a\n", - " column name in behavioral_data Dataframe\n", - " pre_interval: number of miliseconds before event\n", - " post_interval: number of miliseconds after event\n", - "\n", - " Returns:\n", - " DataFrame with num_trials length, columns: TrialNumber,\n", - " IntervalStartTime, IntervalEndTime\n", - " \"\"\"\n", - " 
trial_event_times = behavioral_data[[\"TrialNumber\", event]]\n", - "\n", - " intervals = np.empty((len(trial_event_times), 3))\n", - " intervals[:, 0] = trial_event_times[\"TrialNumber\"]\n", - " intervals[:, 1] = trial_event_times[event] - pre_interval\n", - " intervals[:, 2] = trial_event_times[event] + post_interval\n", - " intervals_df = pd.DataFrame(columns=[\"TrialNumber\", \"IntervalStartTime\", \"IntervalEndTime\"])\n", - " intervals_df[\"TrialNumber\"] = trial_event_times[\"TrialNumber\"].astype(int)\n", - " intervals_df[\"IntervalStartTime\"] = trial_event_times[event] - pre_interval\n", - " intervals_df[\"IntervalEndTime\"] = trial_event_times[event] + post_interval\n", - " intervals_df[\"IntervalStartBin\"] = (intervals_df[\"IntervalStartTime\"] / bin_size).astype(int)\n", - " intervals_df[\"IntervalEndBin\"] = (intervals_df[\"IntervalEndTime\"] / bin_size).astype(int)\n", - " return intervals_df\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/create_design_matrix.ipynb b/notebooks/create_design_matrix.ipynb new file mode 100644 index 0000000..a0f51ed --- /dev/null +++ b/notebooks/create_design_matrix.ipynb @@ -0,0 +1,91 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Notebook to create and store a design matrix of behavior and spikes " + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from spike_tools import (\n", + " general as spike_general,\n", + " analysis as spike_analysis,\n", + ")\n", + "import wcst_encode.data_utils as data_utils\n", + "from wcst_encode.constants import COLUMN_NAMES\n", + "\n", + "species = 'nhp'\n", + "subject = 'SA'\n", + "exp = 'WCST'\n", + "session = 20180802 # this is the session for which there are spikes at the moment. \n", + "\n", + "tau_pre = 20\n", + "tau_post = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spikes_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_spike_counts_binsize_50.pickle')\n", + "beh_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')\n", + "intervals = pd.read_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "column_names_w_units = COLUMN_NAMES + spikes_by_bins.columns[1:].tolist()\n", + "design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, column_names_w_units, tau_pre, tau_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "design_mat.to_pickle(\"/data/processed/sub-SA_sess-20180802_design_mat_taupre_20_taupost_0_binsize_50.pickle\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + 
"nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/format_beh.ipynb b/notebooks/format_beh.ipynb new file mode 100644 index 0000000..9c5a4f5 --- /dev/null +++ b/notebooks/format_beh.ipynb @@ -0,0 +1,114 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from spike_tools import (\n", + " general as spike_general,\n", + " analysis as spike_analysis,\n", + ")\n", + "import wcst_encode.data_utils as data_utils\n", + "from wcst_encode.constants import FEATURES\n", + "\n", + "species = 'nhp'\n", + "subject = 'SA'\n", + "exp = 'WCST'\n", + "session = 20180802 # this is the session for which there are spikes at the moment. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "behavior_file = spike_general.get_behavior_path(subject, session)\n", + "behavior_data = pd.read_csv(\"/data/sub-SA_sess-20180802_object_features.csv\")\n", + "valid_beh = behavior_data[behavior_data.Response.isin([\"Correct\", \"Incorrect\"])]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "behavior_by_bins = data_utils.get_behavior_by_bins(50, valid_beh)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "behavior_by_bins.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Grab bin idxs of interval around fb onset" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + 
"intervals = data_utils.get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [], + "source": [ + "intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/format_spikes.ipynb b/notebooks/format_spikes.ipynb similarity index 100% rename from format_spikes.ipynb rename to notebooks/format_spikes.ipynb diff --git a/wcst_encode/constants.py b/wcst_encode/constants.py new file mode 100644 index 0000000..13c99e9 --- /dev/null +++ b/wcst_encode/constants.py @@ -0,0 +1,13 @@ +# useful constants during analysis + +FEATURES = [ + 'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE', + 'CYAN', 'GREEN', 'MAGENTA', 'YELLOW', + 'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL' +] + +COLUMN_NAMES = FEATURES + ["CORRECT", "INCORRECT"] + +# time in miliseconds for required fixation on a card to register a choice +# also time between choice and feedback signals +CHOICE_FIXATION_TIME = 800 \ No newline at end of file diff --git a/wcst_encode/data_utils.py b/wcst_encode/data_utils.py new file mode 100644 index 0000000..2fead6b --- /dev/null +++ b/wcst_encode/data_utils.py @@ -0,0 +1,129 @@ +from .constants import FEATURES, CHOICE_FIXATION_TIME +import numpy as np +import math +import pandas as pd +from itertools import repeat + +def get_behavior_by_bins(bin_size, beh): + """ + bin_size: in miliseconds, bin size + data: dataframe for behavioral data from object features csv + Returns: new 
import math
from itertools import repeat

import numpy as np
import pandas as pd

try:
    from .constants import FEATURES, CHOICE_FIXATION_TIME
except ImportError:
    # Fallback so this module can also be imported standalone; values mirror
    # wcst_encode/constants.py introduced in this same change.
    FEATURES = [
        'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE',
        'CYAN', 'GREEN', 'MAGENTA', 'YELLOW',
        'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL'
    ]
    # Time in milliseconds of required fixation on a card to register a choice;
    # also the delay between the choice and feedback signals.
    CHOICE_FIXATION_TIME = 800


def get_behavior_by_bins(bin_size, beh):
    """One-hot encode chosen-card features and feedback into fixed-width time bins.

    Args:
        bin_size: bin width in milliseconds
        beh: behavioral dataframe from the object-features csv; must contain
            TrialEnd, ItemChosen, Item{i}Color/Shape/Pattern, FeedbackOnset,
            and Response columns

    Returns:
        DataFrame with one row per time bin and columns: each entry of
        FEATURES, CORRECT, INCORRECT (0/1 float32), plus bin_idx.

    Raises:
        ValueError: if a row's Response is neither "Correct" nor "Incorrect"
    """
    max_time = np.max(beh["TrialEnd"].values)
    max_bin_idx = math.ceil(max_time / bin_size)
    columns = FEATURES + ["CORRECT", "INCORRECT"]
    # Structured array: one float32 field per feature/feedback column.
    dtype = np.dtype(list(zip(columns, repeat("f4"))))
    arr = np.zeros((max_bin_idx), dtype=dtype)

    # NOTE(review): assumes FeedbackOnset < TrialEnd on every trial; a feedback
    # at exactly max_time on a bin boundary would index one past the array --
    # confirm the upstream data guarantees this.
    for _, row in beh.iterrows():
        # Grab the features of the item chosen on this trial.
        item_chosen = int(row["ItemChosen"])
        color = row[f"Item{item_chosen}Color"]
        shape = row[f"Item{item_chosen}Shape"]
        pattern = row[f"Item{item_chosen}Pattern"]

        # The choice registers CHOICE_FIXATION_TIME ms before feedback onset.
        chosen_time = row["FeedbackOnset"] - CHOICE_FIXATION_TIME
        chosen_bin = math.floor(chosen_time / bin_size)
        arr[chosen_bin][color] = 1
        arr[chosen_bin][shape] = 1
        arr[chosen_bin][pattern] = 1

        feedback_bin = int(np.floor(row["FeedbackOnset"] / bin_size))
        if row["Response"] == "Correct":
            arr[feedback_bin]["CORRECT"] = 1
        elif row["Response"] == "Incorrect":
            arr[feedback_bin]["INCORRECT"] = 1
        else:
            raise ValueError(f"{row['Response']} is undefined")
    df = pd.DataFrame(arr)
    df["bin_idx"] = np.arange(len(df))
    return df


def get_spikes_by_bins(bin_size, spike_times):
    """Given a bin_size and a series of spike times, return spike counts by bin.

    Args:
        bin_size: size of bins in milliseconds
        spike_times: dataframe with UnitID and SpikeTime columns

    Returns:
        df with bin_idx, unit_* as columns, filled with spike counts
    """
    units = np.unique(spike_times.UnitID.values)
    num_time_bins = int(spike_times.SpikeTime.max() / bin_size) + 1
    # num_time_bins edges define num_time_bins - 1 histogram bins.
    bin_edges = np.arange(num_time_bins) * bin_size

    df = pd.DataFrame(data={'bin_idx': np.arange(num_time_bins - 1)})
    for unit in units:
        unit_spike_times = spike_times[spike_times.UnitID == unit].SpikeTime.values
        # NOTE(review): spikes later than the last edge are silently dropped by
        # np.histogram (original behavior preserved) -- confirm that is intended.
        unit_spike_counts, _ = np.histogram(unit_spike_times, bins=bin_edges)
        df[f'unit_{unit}'] = unit_spike_counts
    return df


def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, post_interval=0, bin_size=50):
    """Per trial, finds time interval surrounding some event in the behavioral data.

    Args:
        behavioral_data: Dataframe describing each trial, must contain
            columns: TrialNumber, as well as the column corresponding to the
            `event` parameter
        event: name of event to align around, must be present as a
            column name in behavioral_data Dataframe
        pre_interval: number of milliseconds before the event to include.
            Must be >= 0
        post_interval: number of milliseconds after the event to include.
            Must be >= 0
        bin_size: bin width in milliseconds used to derive bin indices

    Returns:
        DataFrame with num_trials length, columns: TrialNumber,
        IntervalStartTime, IntervalEndTime, IntervalStartBin, IntervalEndBin

    Raises:
        ValueError: if pre_interval or post_interval is negative
    """
    # BUG FIX: the original guard was inverted (`>= 0`), so it raised on every
    # valid call and never on negative inputs; the message also lacked its
    # f-string prefix, so the placeholders were never interpolated.
    if pre_interval < 0 or post_interval < 0:
        raise ValueError(
            f"Neither pre_interval: {pre_interval} nor post_interval: {post_interval} should be negative"
        )

    trial_event_times = behavioral_data[["TrialNumber", event]]

    intervals_df = pd.DataFrame(columns=["TrialNumber", "IntervalStartTime", "IntervalEndTime"])
    intervals_df["TrialNumber"] = trial_event_times["TrialNumber"].astype(int)
    intervals_df["IntervalStartTime"] = trial_event_times[event] - pre_interval
    intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval
    intervals_df["IntervalStartBin"] = np.floor(intervals_df["IntervalStartTime"] / bin_size).astype(int)
    intervals_df["IntervalEndBin"] = np.floor(intervals_df["IntervalEndTime"] / bin_size).astype(int)
    return intervals_df


def get_design_matrix(spikes_by_bins, beh_by_bins, columns, tau_pre, tau_post):
    """Reformats data as a design matrix dataframe.

    For each of the specified columns, additional columns are added for each of
    the time shifts between -tau_pre and tau_post.

    Args:
        spikes_by_bins: df with bin_idx, unit_* as columns
        beh_by_bins: df with bin_idx, behavioral vars of interest as columns
        columns: columns to include, must be present in either spikes_by_bins
            or beh_by_bins
        tau_pre: number of bins to look in the past
        tau_post: number of bins to look in the future

    Returns:
        df with bin_idx, columns for each time shift between -tau_pre and
        tau_post; missing time-shift values are filled with NaNs
    """
    joint = pd.merge(spikes_by_bins, beh_by_bins, on="bin_idx", how="inner")
    res = pd.DataFrame()
    # NOTE(review): arange(-tau_pre, tau_post) excludes tau_post itself (and
    # the current bin when tau_post == 0) -- confirm that exclusive upper
    # bound is intended; behavior kept as-is.
    taus = np.arange(-tau_pre, tau_post)
    for tau in taus:
        # A positive tau looks into the future, so shift the frame backwards.
        shift_idx = -1 * tau
        column_names = [f"{x}_{tau}" for x in columns]
        res[column_names] = joint.shift(shift_idx)[columns]
    res["bin_idx"] = joint["bin_idx"]
    return res


def get_interval_bins(intervals):
    """Gets all the bins belonging to all the intervals.

    Args:
        intervals: df with TrialNumber, IntervalStartBin, IntervalEndBin

    Returns:
        np array of all bins for all trials in [IntervalStartBin,
        IntervalEndBin) -- note the end bin itself is excluded (np.arange).
    """
    interval_bins = intervals.apply(
        lambda x: np.arange(x.IntervalStartBin, x.IntervalEndBin).astype(int), axis=1)
    return np.concatenate(interval_bins.to_numpy())