Skip to content

Commit

Permalink
Merge pull request #713 from bids-standard/round
Browse files Browse the repository at this point in the history
Resample to n_vols for sampling_rate == 'TR'
  • Loading branch information
adelavega authored Apr 12, 2021
2 parents 995ae8e + c0551ea commit 3ac0f92
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 24 deletions.
4 changes: 2 additions & 2 deletions bids/analysis/tests/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sparse_run_variable_with_missing_values():
'duration': [1.2, 1.6, 0.8, 2],
'amplitude': [1, 1, np.nan, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
return BIDSRunVariableCollection([var])
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_convolve_impulse():
'duration': [0, 0],
'amplitude': [1, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
coll = BIDSRunVariableCollection([var])
Expand Down
10 changes: 6 additions & 4 deletions bids/variables/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
resample_dense=False, force_dense=False,
in_place=False, kind='linear'):

sampling_rate = self._get_sampling_rate(sampling_rate)
sr = self._get_sampling_rate(sampling_rate)

_dense, _sparse = [], []

Expand All @@ -375,11 +375,13 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
if force_dense:
for v in _sparse:
if is_numeric_dtype(v.values):
_variables[v.name] = v.to_dense(sampling_rate)
_variables[v.name] = v.to_dense(sr)

if resample_dense:
# Propagate 'TR' if exact match to TR is required
sr_arg = sampling_rate if sampling_rate == 'TR' else sr
for v in _dense:
_variables[v.name] = v.resample(sampling_rate, kind=kind)
_variables[v.name] = v.resample(sr_arg, kind=kind)

coll = self if in_place else self.clone()

Expand All @@ -388,7 +390,7 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
else:
coll.variables = _variables

coll.sampling_rate = sampling_rate
coll.sampling_rate = sr
return coll

def to_dense(self, sampling_rate=None, variables=None, in_place=False,
Expand Down
9 changes: 5 additions & 4 deletions bids/variables/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class RunNode(Node):
The task name for this run.
"""

def __init__(self, entities, image_file, duration, repetition_time):
def __init__(self, entities, image_file, duration, repetition_time, n_vols):
self.image_file = image_file
self.duration = duration
self.repetition_time = repetition_time
self.n_vols = n_vols
super(RunNode, self).__init__('run', entities)

def get_info(self):
Expand All @@ -68,18 +69,18 @@ def get_info(self):
# a RunInfo or any containing object.
entities = dict(self.entities)
return RunInfo(entities, self.duration,
self.repetition_time, self.image_file)
self.repetition_time, self.image_file, self.n_vols)


# Stores key information for each Run.
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image'])
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image', 'n_vols'])


# Wrap with class to provide docstring
class RunInfo(RunInfo_):
""" A namedtuple storing run-related information.
Properties include 'entities', 'duration', 'tr', and 'image'.
Properties include 'entities', 'duration', 'tr', and 'image', 'n_vols'.
"""
pass

Expand Down
4 changes: 3 additions & 1 deletion bids/variables/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
except Exception as e:
if scan_length is not None:
duration = scan_length
nvols = int(np.rint(scan_length / tr))
else:
msg = ("Unable to extract scan duration from one or more "
"BOLD runs, and no scan_length argument was provided "
Expand Down Expand Up @@ -247,7 +248,8 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
}

run = dataset.create_node('run', entities, image_file=img_f,
duration=duration, repetition_time=tr)
duration=duration, repetition_time=tr,
n_vols=nvols)
run_info = run.get_info()

# Process event files
Expand Down
40 changes: 38 additions & 2 deletions bids/variables/tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from os.path import join, dirname, abspath

import pytest
import numpy as np

from bids.layout import BIDSLayout
from bids.tests import get_test_data_path
Expand All @@ -16,6 +17,14 @@ def run_coll():
return layout.get_collections('run', types=['events'], merge=True,
scan_length=480, subject=['01', '02', '04'])

@pytest.fixture(scope="module")
def run_coll_bad_length():
path = join(get_test_data_path(), 'ds005')
layout = BIDSLayout(path)
# Limit to a few subjects to reduce test running time
return layout.get_collections('run', types=['events'], merge=True,
scan_length=480.1, subject=['01', '02', '04'])


@pytest.fixture(scope="module")
def run_coll_list():
Expand Down Expand Up @@ -54,8 +63,7 @@ def test_run_variable_collection_dense_variable_accessors(run_coll):
def test_run_variable_collection_get_sampling_rate(run_coll):
coll = run_coll.clone()
assert coll._get_sampling_rate(None) == 10
assert coll._get_sampling_rate('TR') == 0.5
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None)
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None, 20)
with pytest.raises(ValueError) as exc:
coll._get_sampling_rate('TR')
assert str(exc.value).startswith('Non-unique')
Expand Down Expand Up @@ -204,6 +212,34 @@ def test_run_variable_collection_to_df_all_dense_vars(run_coll):
n_rows = int(rows_per_var * 12 / 10)
assert df.shape == (n_rows, 18)

def test_run_variable_collection_bad_length_to_df_all_dense_vars(run_coll_bad_length):

timing_cols = {'onset', 'duration'}
entity_cols = {'subject', 'run', 'task', 'suffix', 'datatype'}
cond_names = {'PTval', 'RT', 'gain', 'loss', 'parametric gain', 'respcat',
'respnum', 'trial_type'}
md_names = {'TaskName', 'RepetitionTime', 'extension', 'SliceTiming'}
condition = {'condition'}
ampl = {'amplitude'}

unif_coll = run_coll_bad_length.to_dense(sampling_rate=10)

df = unif_coll.to_df()
rows_per_var = np.round(3 * 3 * 480.1 * 10) # subjects x runs x time x sampling rate

# Test resampling without setting sample rate (default sr == 10)
df = unif_coll.to_df(format='long')
assert df.shape == (rows_per_var * 7, 13)
cols = timing_cols | entity_cols | condition | ampl | md_names
assert set(df.columns) == cols

# Test resampling to TR
df = unif_coll.to_df(sampling_rate='TR')
n_rows = int(480 * 3 * 3 / 2) # (Note number of volumes is 480, not 480.1)
assert df.shape == (n_rows, 18)
cols = (timing_cols | entity_cols | cond_names | md_names) - {'trial_type'}
assert set(df.columns) == cols


def test_run_variable_collection_to_df_mixed_vars(run_coll):
coll = run_coll.clone()
Expand Down
7 changes: 5 additions & 2 deletions bids/variables/tests/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def layout2():
def test_run(layout1):
img = layout1.get(subject='01', task='mixedgamblestask', suffix='bold',
run=1, return_type='obj')[0]
run = RunNode(None, img.filename, 480, 2)
run = RunNode(None, img.filename, 480, 2, 480/2)
assert run.image_file == img.filename
assert run.duration == 480
assert run.repetition_time == 2
assert run.n_vols == 480 / 2


def test_get_or_create_node(layout1):
Expand All @@ -44,9 +45,11 @@ def test_get_or_create_node(layout1):

run = index.get_or_create_node('run', img.entities,
image_file=img.filename, duration=480,
repetition_time=2)
repetition_time=2,
n_vols=480/2)
assert run.__class__ == RunNode
assert run.duration == 480
assert run.n_vols == 480 / 2


def test_get_nodes(layout1):
Expand Down
2 changes: 1 addition & 1 deletion bids/variables/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_DEV(name='test', sr=20, duration=480):
ent_names = ['task', 'run', 'session', 'subject']
entities = {e: uuid.uuid4().hex for e in ent_names}
image = uuid.uuid4().hex + '.nii.gz'
run_info = RunInfo(entities, duration, 2, image)
run_info = RunInfo(entities, duration, 2, image, duration / 2),
return DenseRunVariable(name='test', values=values, run_info=run_info,
source='dummy', sampling_rate=sr)

Expand Down
26 changes: 18 additions & 8 deletions bids/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,15 +496,20 @@ def split(self, grouper):
source=self.source,
sampling_rate=self.sampling_rate)
for i, name in enumerate(df.columns)]

def _build_entity_index(self, run_info, sampling_rate):
def _build_entity_index(self, run_info, sampling_rate, match_vol=False):
"""Build the entity index from run information. """

index = []
interval = int(round(1000. / sampling_rate))
_timestamps = []
for run in run_info:
reps = int(math.ceil(run.duration * sampling_rate))
if match_vol:
# If TR, fix reps to n_vols to ensure match
reps = run.n_vols
else:
reps = int(math.ceil(run.duration * sampling_rate))

interval = int(round(1000. / sampling_rate))
ent_vals = list(run.entities.values())
df = pd.DataFrame([ent_vals] * reps, columns=list(run.entities.keys()))
ts = pd.date_range(0, periods=len(df), freq='%sms' % interval)
Expand Down Expand Up @@ -532,14 +537,19 @@ def resample(self, sampling_rate, inplace=False, kind='linear'):
var = self.clone()
var.resample(sampling_rate, True, kind)
return var


match_vol = False
if sampling_rate == 'TR':
match_vol = True
sampling_rate = 1. / self.run_info[0].tr

if sampling_rate == self.sampling_rate:
return

n = len(self.index)

self.index = self._build_entity_index(self.run_info, sampling_rate)

self.index = self._build_entity_index(self.run_info, sampling_rate, match_vol)
x = np.arange(n)
num = len(self.index)

Expand Down

0 comments on commit 3ac0f92

Please sign in to comment.