Skip to content

Commit

Permalink
Refactored replay_session into Simulator object basic functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
sreekaroo committed Sep 17, 2023
1 parent f6f5bb0 commit eaf4bf7
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 0 deletions.
Empty file added bcipy/simulator/__init__.py
Empty file.
135 changes: 135 additions & 0 deletions bcipy/simulator/helpers/signal_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import json
import logging as logger
from dataclasses import dataclass
from typing import Tuple
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from bcipy.config import (
RAW_DATA_FILENAME,
TRIGGER_FILENAME,
DEFAULT_PARAMETER_FILENAME, SESSION_DATA_FILENAME,
DEFAULT_DEVICE_SPEC_FILENAME,
)
from bcipy.helpers.acquisition import analysis_channels
import bcipy.acquisition.devices as devices
from bcipy.helpers.list import grouper
from bcipy.helpers.load import load_json_parameters, load_raw_data, load_experimental_data
from bcipy.helpers.session import read_session, evidence_records
from bcipy.helpers.stimuli import update_inquiry_timing
from bcipy.helpers.triggers import TriggerType, trigger_decoder
from bcipy.helpers.symbols import alphabet
from bcipy.signal.model import PcaRdaKdeModel
from bcipy.signal.process import get_default_transform, filter_inquiries, ERPTransformParams

logger.getLogger().setLevel(logger.INFO)
log = logger.getLogger(__name__)


@dataclass()
class ExtractedExperimentData: # TODO clean up design
inquiries: np.ndarray
trials: np.ndarray
labels: list
inquiry_timing: list

decoded_triggers: tuple


def process_raw_data_for_model(data_folder, parameters, model_class=PcaRdaKdeModel) -> ExtractedExperimentData:

assert parameters, "Parameters are required for offline analysis."
if not data_folder:
data_folder = load_experimental_data()

# extract relevant session information from parameters file
trial_window = parameters.get("trial_window")
window_length = trial_window[1] - trial_window[0]

prestim_length = parameters.get("prestim_length")
trials_per_inquiry = parameters.get("stim_length")
# The task buffer length defines the min time between two inquiries
# We use half of that time here to buffer during transforms
buffer = int(parameters.get("task_buffer_length") / 2)
raw_data_file = f"{RAW_DATA_FILENAME}.csv"

# get signal filtering information
transform_params = parameters.instantiate(ERPTransformParams)
downsample_rate = transform_params.down_sampling_rate
static_offset = parameters.get("static_trigger_offset")

log.info(
f"\nData processing settings: \n"
f"{str(transform_params)} \n"
f"Trial Window: {trial_window[0]}-{trial_window[1]}s, "
f"Prestimulus Buffer: {prestim_length}s, Poststimulus Buffer: {buffer}s \n"
f"Static offset: {static_offset}"
)

# Load raw data
raw_data = load_raw_data(Path(data_folder, raw_data_file))
channels = raw_data.channels
type_amp = raw_data.daq_type
sample_rate = raw_data.sample_rate

devices.load(Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME))
device_spec = devices.preconfigured_device(raw_data.daq_type)

# setup filtering
default_transform = get_default_transform(
sample_rate_hz=sample_rate,
notch_freq_hz=transform_params.notch_filter_frequency,
bandpass_low=transform_params.filter_low,
bandpass_high=transform_params.filter_high,
bandpass_order=transform_params.filter_order,
downsample_factor=transform_params.down_sampling_rate,
)

log.info(f"Channels read from csv: {channels}")
log.info(f"Device type: {type_amp}, fs={sample_rate}")

k_folds = parameters.get("k_folds")
model = model_class(k_folds=k_folds)

# Process triggers.txt files
trigger_targetness, trigger_timing, trigger_symbols = trigger_decoder(
offset=static_offset,
trigger_path=f"{data_folder}/{TRIGGER_FILENAME}",
exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
)

# update the trigger timing list to account for the initial trial window
corrected_trigger_timing = [timing + trial_window[0] for timing in trigger_timing]

# Channel map can be checked from raw_data.csv file or the devices.json located in the acquisition module
# The timestamp column [0] is already excluded.
channel_map = analysis_channels(channels, device_spec)
channels_used = [channels[i] for i, keep in enumerate(channel_map) if keep == 1]
log.info(f'Channels used in analysis: {channels_used}')

data, fs = raw_data.by_channel()

inquiries, inquiry_labels, inquiry_timing = model.reshaper(
trial_targetness_label=trigger_targetness,
timing_info=corrected_trigger_timing,
eeg_data=data,
sample_rate=sample_rate,
trials_per_inquiry=trials_per_inquiry,
channel_map=channel_map,
poststimulus_length=window_length,
prestimulus_length=prestim_length,
transformation_buffer=buffer,
)

inquiries, fs = filter_inquiries(inquiries, default_transform, sample_rate)
inquiry_timing = update_inquiry_timing(inquiry_timing, downsample_rate)
trial_duration_samples = int(window_length * fs)
trials = model.reshaper.extract_trials(inquiries, trial_duration_samples, inquiry_timing)

# define the training classes using integers, where 0=nontargets/1=targets
# labels = inquiry_labels.flatten()

return ExtractedExperimentData(inquiries, trials, inquiry_labels, inquiry_timing, (trigger_targetness, trigger_timing, trigger_symbols))
68 changes: 68 additions & 0 deletions bcipy/simulator/helpers/sim_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging as logger

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from bcipy.helpers.parameters import Parameters

logger.getLogger().setLevel(logger.INFO)


def plot_comparison_records(records, outdir, title="response_values", y_scale="log"):
df = pd.DataFrame.from_records(records)
ax = sns.stripplot(
x="which_model",
y="response_value",
data=df,
order=["old_target", "new_target", "old_non_target", "new_non_target"],
)
sns.boxplot(
showmeans=True,
meanline=True,
meanprops={"color": "k", "ls": "-", "lw": 2},
medianprops={"visible": False},
whiskerprops={"visible": False},
zorder=10,
x="which_model",
y="response_value",
data=df,
showfliers=False,
showbox=False,
showcaps=False,
ax=ax,
order=["old_target", "new_target", "old_non_target", "new_non_target"],
)

ax.set(yscale=y_scale)
plt.savefig(outdir / f"{title}.stripplot.png", dpi=150, bbox_inches="tight")
plt.close()
ax = sns.boxplot(
x="which_model",
y="response_value",
data=df,
order=["old_target", "new_target", "old_non_target", "new_non_target"],
)
ax.set(yscale=y_scale)
plt.savefig(outdir / f"{title}.boxplot.png", dpi=150, bbox_inches="tight")


def plot_replay_comparison(new_target_eeg_evidence: np.ndarray,
new_non_target_eeg_evidence: np.ndarray,
old_target_eeg_evidence: np.ndarray,
old_non_target_eeg_evidence: np.ndarray,
outdir: str,
parameters: Parameters,
) -> None:
def convert_to_records(arr, key_value, key_name="which_model", value_name="response_value") -> [dict]:
return [{key_name: key_value, value_name: val} for val in arr]

records = []

records.extend(convert_to_records(new_target_eeg_evidence, "new_target"))
records.extend(convert_to_records(new_non_target_eeg_evidence, "new_non_target"))
records.extend(convert_to_records(old_target_eeg_evidence, "old_target"))
records.extend(convert_to_records(old_non_target_eeg_evidence, "old_non_target"))

plot_comparison_records(records, outdir, y_scale="log")
52 changes: 52 additions & 0 deletions bcipy/simulator/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import argparse
from pathlib import Path

from bcipy.helpers.load import load_json_parameters
from bcipy.simulator.sim_factory import SimulationFactory

if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--data_folders",
action="append",
type=Path,
required=True,
help="Session data folders to be processed. This argument can be repeated to accumulate sessions data.")
parser.add_argument(
"-sm",
"--smodel_files",
action="append",
type=Path,
required=True,
help="Signal models to be used")
parser.add_argument(
"-lm",
"--lmodel_file",
action="append",
type=Path,
required=False,
help="Language models to be used")
parser.add_argument("-o", "--out_dir", type=Path, default=None)
parser.add_argument(
"-p",
"--parameter_path",
type=Path,
default=None,
help="Parameter file to be used for replay. If none, the session parameter file will be used.")

args = vars(parser.parse_args())

# assert len(set(args['data_folders'])) == len(args.data_folders), "Duplicated data folders"

if args['out_dir'] is None:
args['out_dir'] = Path(__file__).resolve().parent

# Load parameters
sim_parameters = load_json_parameters("bcipy/simulator/sim_parameters.json", value_cast=True)
sim_task = sim_parameters.get("sim_task")
args['sim_task'] = sim_task

simulator = SimulationFactory.create(**args)
simulator.run()
Loading

0 comments on commit eaf4bf7

Please sign in to comment.