Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resample to n_vols for sampling_rate == 'TR' #713

Merged
merged 14 commits into from
Apr 12, 2021
4 changes: 2 additions & 2 deletions bids/analysis/tests/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sparse_run_variable_with_missing_values():
'duration': [1.2, 1.6, 0.8, 2],
'amplitude': [1, 1, np.nan, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
return BIDSRunVariableCollection([var])
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_convolve_impulse():
'duration': [0, 0],
'amplitude': [1, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
coll = BIDSRunVariableCollection([var])
Expand Down
10 changes: 6 additions & 4 deletions bids/variables/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
resample_dense=False, force_dense=False,
in_place=False, kind='linear'):

sampling_rate = self._get_sampling_rate(sampling_rate)
sr = self._get_sampling_rate(sampling_rate)

_dense, _sparse = [], []

Expand All @@ -375,11 +375,13 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
if force_dense:
for v in _sparse:
if is_numeric_dtype(v.values):
_variables[v.name] = v.to_dense(sampling_rate)
_variables[v.name] = v.to_dense(sr)

if resample_dense:
# Propagate 'TR' if exact match to TR is required
sr_arg = sampling_rate if sampling_rate == 'TR' else sr
for v in _dense:
_variables[v.name] = v.resample(sampling_rate, kind=kind)
_variables[v.name] = v.resample(sr_arg, kind=kind)

coll = self if in_place else self.clone()

Expand All @@ -388,7 +390,7 @@ def _densify_and_resample(self, sampling_rate=None, variables=None,
else:
coll.variables = _variables

coll.sampling_rate = sampling_rate
coll.sampling_rate = sr
return coll

def to_dense(self, sampling_rate=None, variables=None, in_place=False,
Expand Down
9 changes: 5 additions & 4 deletions bids/variables/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class RunNode(Node):
The task name for this run.
"""

def __init__(self, entities, image_file, duration, repetition_time):
def __init__(self, entities, image_file, duration, repetition_time, n_vols):
self.image_file = image_file
self.duration = duration
self.repetition_time = repetition_time
self.n_vols = n_vols
super(RunNode, self).__init__('run', entities)

def get_info(self):
Expand All @@ -68,18 +69,18 @@ def get_info(self):
# a RunInfo or any containing object.
entities = dict(self.entities)
return RunInfo(entities, self.duration,
self.repetition_time, self.image_file)
self.repetition_time, self.image_file, self.n_vols)


# Stores key information for each Run.
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image'])
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image', 'n_vols'])


# Wrap with class to provide docstring
class RunInfo(RunInfo_):
""" A namedtuple storing run-related information.
Properties include 'entities', 'duration', 'tr', and 'image'.
Properties include 'entities', 'duration', 'tr', and 'image', 'n_vols'.
"""
pass

Expand Down
4 changes: 3 additions & 1 deletion bids/variables/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
except Exception as e:
if scan_length is not None:
duration = scan_length
nvols = int(scan_length / tr)
adelavega marked this conversation as resolved.
Show resolved Hide resolved
else:
msg = ("Unable to extract scan duration from one or more "
"BOLD runs, and no scan_length argument was provided "
Expand Down Expand Up @@ -247,7 +248,8 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
}

run = dataset.create_node('run', entities, image_file=img_f,
duration=duration, repetition_time=tr)
duration=duration, repetition_time=tr,
n_vols=nvols)
run_info = run.get_info()

# Process event files
Expand Down
3 changes: 1 addition & 2 deletions bids/variables/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def test_run_variable_collection_dense_variable_accessors(run_coll):
def test_run_variable_collection_get_sampling_rate(run_coll):
coll = run_coll.clone()
assert coll._get_sampling_rate(None) == 10
assert coll._get_sampling_rate('TR') == 0.5
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None)
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None, 20)
with pytest.raises(ValueError) as exc:
coll._get_sampling_rate('TR')
assert str(exc.value).startswith('Non-unique')
Expand Down
7 changes: 5 additions & 2 deletions bids/variables/tests/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def layout2():
def test_run(layout1):
img = layout1.get(subject='01', task='mixedgamblestask', suffix='bold',
run=1, return_type='obj')[0]
run = RunNode(None, img.filename, 480, 2)
run = RunNode(None, img.filename, 480, 2, 480/2)
assert run.image_file == img.filename
assert run.duration == 480
assert run.repetition_time == 2
assert run.n_vols == 480 / 2


def test_get_or_create_node(layout1):
Expand All @@ -44,9 +45,11 @@ def test_get_or_create_node(layout1):

run = index.get_or_create_node('run', img.entities,
image_file=img.filename, duration=480,
repetition_time=2)
repetition_time=2,
n_vols=480/2)
assert run.__class__ == RunNode
assert run.duration == 480
assert run.n_vols == 480 / 2


def test_get_nodes(layout1):
Expand Down
2 changes: 1 addition & 1 deletion bids/variables/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_DEV(name='test', sr=20, duration=480):
ent_names = ['task', 'run', 'session', 'subject']
entities = {e: uuid.uuid4().hex for e in ent_names}
image = uuid.uuid4().hex + '.nii.gz'
run_info = RunInfo(entities, duration, 2, image)
run_info = RunInfo(entities, duration, 2, image, duration / 2),
return DenseRunVariable(name='test', values=values, run_info=run_info,
source='dummy', sampling_rate=sr)

Expand Down
26 changes: 18 additions & 8 deletions bids/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,15 +496,20 @@ def split(self, grouper):
source=self.source,
sampling_rate=self.sampling_rate)
for i, name in enumerate(df.columns)]

def _build_entity_index(self, run_info, sampling_rate):
def _build_entity_index(self, run_info, sampling_rate, match_vol=False):
"""Build the entity index from run information. """

index = []
interval = int(round(1000. / sampling_rate))
_timestamps = []
for run in run_info:
reps = int(math.ceil(run.duration * sampling_rate))
if match_vol:
# If TR, fix reps to n_vols to ensure match
reps = run.n_vols
else:
reps = int(math.ceil(run.duration * sampling_rate))

interval = int(round(1000. / sampling_rate))
ent_vals = list(run.entities.values())
df = pd.DataFrame([ent_vals] * reps, columns=list(run.entities.keys()))
ts = pd.date_range(0, periods=len(df), freq='%sms' % interval)
Expand Down Expand Up @@ -532,14 +537,19 @@ def resample(self, sampling_rate, inplace=False, kind='linear'):
var = self.clone()
var.resample(sampling_rate, True, kind)
return var


match_vol = False
if sampling_rate == 'TR':
match_vol = True
sampling_rate = 1. / self.run_info[0].tr

if sampling_rate == self.sampling_rate:
return

n = len(self.index)

self.index = self._build_entity_index(self.run_info, sampling_rate)

self.index = self._build_entity_index(self.run_info, sampling_rate, match_vol)
x = np.arange(n)
num = len(self.index)

Expand Down