diff --git a/python/stempy/io/sparse_array.py b/python/stempy/io/sparse_array.py index cec4a04b..e4f5c37b 100644 --- a/python/stempy/io/sparse_array.py +++ b/python/stempy/io/sparse_array.py @@ -192,27 +192,30 @@ def from_hdf5(cls, filepath, keep_flyback=True, **init_kwargs): scan_positions_group = f['electron_events/scan_positions'] scan_shape = [scan_positions_group.attrs[x] for x in ['Nx', 'Ny']] frame_shape = [frames.attrs[x] for x in ['Nx', 'Ny']] - + if keep_flyback: data = frames[()] # load the full data set scan_positions = scan_positions_group[()] else: - # Generate the original scan indices from the scan_shape - orig_indices = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) - # Remove the indices of the last column - crop_indices = np.delete(orig_indices, orig_indices[scan_shape[0]-1::scan_shape[0]]) - # Load only the data needed - data = frames[crop_indices] - # Reduce the column shape by 1 - scan_shape[0] = scan_shape[0] - 1 + num = frames.shape[0] // np.prod(scan_shape, dtype=int) # number of frames per probe position + data = np.empty(((scan_shape[0]-1) * scan_shape[1] * num), dtype=object) + new_num_cols = scan_shape[0]-1 # number of columns without flyback + for ii in range(scan_shape[1]): + start = ii*new_num_cols*num # start of cropped data + end = (ii+1)*new_num_cols*num + start2 = ii*new_num_cols*num + num*ii # start of uncropped data + end2 = (ii+1)*new_num_cols*num + num*ii + data[start:end] = frames[start2:end2] + scan_shape = (scan_shape[0]-1, scan_shape[1]) # update scan shape # Create the proper scan_positions without the flyback column - scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) + scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)], scan_shape) # Load any metadata metadata = {} if 'metadata' in f: load_h5_to_dict(f['metadata'], metadata) + # reverse the scan shape to match expected shape scan_shape = scan_shape[::-1] if version >= 3: diff --git a/tests/conftest.py b/tests/conftest.py index d8cc8f25..bcd86b6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,6 +109,11 @@ def cropped_multi_frames_v2(cropped_multi_frames_data_v2): def cropped_multi_frames_v3(cropped_multi_frames_data_v3): return SparseArray.from_hdf5(cropped_multi_frames_data_v3, dtype=np.uint16) +@pytest.fixture +def cropped_multi_frames_v3_noflyback(cropped_multi_frames_data_v3): + return SparseArray.from_hdf5(cropped_multi_frames_data_v3, + dtype=np.uint16, keep_flyback=False) + @pytest.fixture def simulate_sparse_array(): diff --git a/tests/test_sparse_array.py b/tests/test_sparse_array.py index e088884c..87fa36e0 100644 --- a/tests/test_sparse_array.py +++ b/tests/test_sparse_array.py @@ -710,11 +710,13 @@ def compare_with_sparse(full, sparse): assert np.array_equal(m_array[[False, True], 0][0], position_one) -def test_keep_flyback(electron_data_small): - flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=True) - assert flyback.scan_shape[1] == 50 - no_flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=False) - assert no_flyback.scan_shape[1] == 49 +def test_keep_flyback(cropped_multi_frames_v3, cropped_multi_frames_v3_noflyback): + # Test keeping the flyback + assert cropped_multi_frames_v3.scan_shape[1] == 20 + assert cropped_multi_frames_v3.num_frames_per_scan == 2 + # Test removing the flyback + assert cropped_multi_frames_v3_noflyback.scan_shape[1] == 19 + assert cropped_multi_frames_v3_noflyback.num_frames_per_scan == 2 # Test binning until this number TEST_BINNING_UNTIL = 33