Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resample to n_vols for sampling_rate == 'TR' #713

Merged
merged 14 commits into from
Apr 12, 2021
4 changes: 2 additions & 2 deletions bids/analysis/tests/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sparse_run_variable_with_missing_values():
'duration': [1.2, 1.6, 0.8, 2],
'amplitude': [1, 1, np.nan, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
return BIDSRunVariableCollection([var])
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_convolve_impulse():
'duration': [0, 0],
'amplitude': [1, 1]
})
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz')]
run_info = [RunInfo({'subject': '01'}, 20, 2, 'dummy.nii.gz', 10)]
var = SparseRunVariable(
name='var', data=data, run_info=run_info, source='events')
coll = BIDSRunVariableCollection([var])
Expand Down
14 changes: 1 addition & 13 deletions bids/variables/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,21 +324,9 @@ def _get_sampling_rate(self, sampling_rate):
if sampling_rate is None:
return self.sampling_rate

if isinstance(sampling_rate, (float, int)):
if isinstance(sampling_rate, (float, int)) or sampling_rate == 'TR':
return sampling_rate

if sampling_rate == 'TR':
trs = {var.run_info[0].tr for var in self.variables.values()}
if not trs:
raise ValueError("Repetition time unavailable; specify "
"sampling_rate in Hz explicitly or set to"
" 'highest'.")
elif len(trs) > 1:
raise ValueError("Non-unique Repetition times found "
"({!r}); specify sampling_rate explicitly"
.format(trs))
return 1. / trs.pop()

if sampling_rate.lower() == 'highest':
dense_vars = self.get_dense_variables()
# If no dense variables are available, fall back on instance SR
Expand Down
9 changes: 5 additions & 4 deletions bids/variables/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class RunNode(Node):
The task name for this run.
"""

def __init__(self, entities, image_file, duration, repetition_time):
def __init__(self, entities, image_file, duration, repetition_time, n_vols):
self.image_file = image_file
self.duration = duration
self.repetition_time = repetition_time
self.n_vols = n_vols
super(RunNode, self).__init__('run', entities)

def get_info(self):
Expand All @@ -68,18 +69,18 @@ def get_info(self):
# a RunInfo or any containing object.
entities = dict(self.entities)
return RunInfo(entities, self.duration,
self.repetition_time, self.image_file)
self.repetition_time, self.image_file, self.n_vols)


# Stores key information for each Run.
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image'])
RunInfo_ = namedtuple('RunInfo', ['entities', 'duration', 'tr', 'image', 'n_vols'])


# Wrap with class to provide docstring
class RunInfo(RunInfo_):
""" A namedtuple storing run-related information.
Properties include 'entities', 'duration', 'tr', and 'image'.
Properties include 'entities', 'duration', 'tr', and 'image', 'n_vols'.
"""
pass

Expand Down
4 changes: 3 additions & 1 deletion bids/variables/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
"as a fallback. Please check that the image files are "
"available, or manually specify the scan duration.")
raise ValueError(msg) from e
nvols = None

# We don't want to pass all the image file's entities onto get_node(),
# as there can be unhashable nested slice timing values, and this also
Expand Down Expand Up @@ -247,7 +248,8 @@ def _load_time_variables(layout, dataset=None, columns=None, scan_length=None,
}

run = dataset.create_node('run', entities, image_file=img_f,
duration=duration, repetition_time=tr)
duration=duration, repetition_time=trm,
effigies marked this conversation as resolved.
Show resolved Hide resolved
n_vols=nvols)
run_info = run.get_info()

# Process event files
Expand Down
2 changes: 1 addition & 1 deletion bids/variables/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_run_variable_collection_get_sampling_rate(run_coll):
coll = run_coll.clone()
assert coll._get_sampling_rate(None) == 10
assert coll._get_sampling_rate('TR') == 0.5
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None)
coll.variables['RT'].run_info[0] = RunInfo({}, 200, 10, None, 20)
with pytest.raises(ValueError) as exc:
coll._get_sampling_rate('TR')
assert str(exc.value).startswith('Non-unique')
Expand Down
4 changes: 3 additions & 1 deletion bids/variables/tests/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def test_get_or_create_node(layout1):

run = index.get_or_create_node('run', img.entities,
image_file=img.filename, duration=480,
repetition_time=2)
repetition_time=2,
n_vols=480/2)
assert run.__class__ == RunNode
assert run.duration == 480
assert run.n_vols = 480 / 2
effigies marked this conversation as resolved.
Show resolved Hide resolved


def test_get_nodes(layout1):
Expand Down
2 changes: 1 addition & 1 deletion bids/variables/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_DEV(name='test', sr=20, duration=480):
ent_names = ['task', 'run', 'session', 'subject']
entities = {e: uuid.uuid4().hex for e in ent_names}
image = uuid.uuid4().hex + '.nii.gz'
run_info = RunInfo(entities, duration, 2, image)
run_info = RunInfo(entities, duration, 2, image, duration / 2),
return DenseRunVariable(name='test', values=values, run_info=run_info,
source='dummy', sampling_rate=sr)

Expand Down
34 changes: 30 additions & 4 deletions bids/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,15 +496,39 @@ def split(self, grouper):
source=self.source,
sampling_rate=self.sampling_rate)
for i, name in enumerate(df.columns)]


def _get_sampling_rate(sampling_rate):
if sampling_rate == 'TR':
trs = {var.run_info[0].tr for var in self.variables.values()}
if not trs:
raise ValueError("Repetition time unavailable; specify "
"sampling_rate in Hz explicitly or set to"
" 'highest'.")
elif len(trs) > 1:
raise ValueError("Non-unique Repetition times found "
"({!r}); specify sampling_rate explicitly"
.format(trs))
return 1. / trs.pop()

else:
return sampling_rate


def _build_entity_index(self, run_info, sampling_rate):
"""Build the entity index from run information. """

index = []
interval = int(round(1000. / sampling_rate))
_timestamps = []
for run in run_info:
reps = int(math.ceil(run.duration * sampling_rate))
if sampling_rate == 'TR':
reps = run.n_vols
sr = run.tr
else:
sr = sampling_rate
reps = int(math.ceil(run.duration * sr))

interval = int(round(1000. / sr))
ent_vals = list(run.entities.values())
df = pd.DataFrame([ent_vals] * reps, columns=list(run.entities.keys()))
ts = pd.date_range(0, periods=len(df), freq='%sms' % interval)
Expand Down Expand Up @@ -532,8 +556,10 @@ def resample(self, sampling_rate, inplace=False, kind='linear'):
var = self.clone()
var.resample(sampling_rate, True, kind)
return var

sr = self._get_sampling_rate(sampling_rate)

if sampling_rate == self.sampling_rate:
if sr == self.sampling_rate:
return

n = len(self.index)
Expand All @@ -543,7 +569,7 @@ def resample(self, sampling_rate, inplace=False, kind='linear'):
x = np.arange(n)
num = len(self.index)

if sampling_rate < self.sampling_rate:
if sr < self.sampling_rate:
# Downsampling, so filter the signal
from scipy.signal import butter, filtfilt
# cutoff = new Nyqist / old Nyquist
Expand Down