diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index bff8709..221d30a 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -31,7 +31,7 @@ jobs: - name: Build sdist and wheel run: pipx run build - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: path: dist @@ -45,11 +45,11 @@ jobs: permissions: id-token: write steps: - - uses: actions/setup-python@v4.7.0 + - uses: actions/setup-python@v5 name: Install Python with: python-version: '3.10' - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: artifact path: dist @@ -58,7 +58,7 @@ jobs: ls -ltrh ls -ltrh dist - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.8.10 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: repository-url: https://test.pypi.org/legacy/ verbose: true @@ -92,10 +92,10 @@ jobs: if: github.event_name == 'release' && github.event.action == 'published' steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.8.10 + - uses: pypa/gh-action-pypi-publish@v1.8.14 if: startsWith(github.ref, 'refs/tags') diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7fb802c..f903be2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11"] # add this back later: , "3.12" runs-on: [ubuntu-latest, macos-latest, windows-latest] experimental: [false, false, true] @@ -33,13 +33,13 @@ jobs: with: fetch-depth: 0 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} allow-prereleases: true - name: Install package - run: python -m pip install .[test] + run: python -m pip install .[dev] - name: Test package run: >- @@ -47,4 +47,4 @@ jobs: --durations=20 - name: Upload coverage report - uses: codecov/codecov-action@v3.1.4 + uses: codecov/codecov-action@v4.1.0 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 551f036..fc29d96 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,4 +15,4 @@ python: - method: pip path: . extra_requirements: - - docs + - dev diff --git a/README.md b/README.md index 583c274..d6d494d 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![ssec](https://img.shields.io/badge/SSEC-Project-purple?logo=&style=plastic)](https://escience.washington.edu/wetai/) [![MIT License](https://badgen.net/badge/license/MIT/blue)](LICENSE) [![Documentation Status](https://readthedocs.org/projects/braingeneers/badge/?version=latest)](https://braingeneers.readthedocs.io/en/latest/?badge=latest) +[![DOI](https://zenodo.org/badge/166130153.svg)](https://zenodo.org/badge/latestdoi/166130153) ## Getting Started @@ -38,10 +39,22 @@ pip install --force-reinstall git+https://github.com/braingeneers/braingeneerspy You can install `braingeneerspy` with specific optional dependencies based on your needs. Use the following command examples: -- Install with IoT, analysis, and data access functions (skips machine learning and lab-specific dependencies): +- Install with machine-learning dependencies: ```bash -pip install "braingeneers[iot,analysis,data]" +pip install "braingeneers[ml]" +``` + +- Install with Hengen lab dependencies: + +```bash +pip install "braingeneers[hengenlab]" +``` + +- Install with developer dependencies (running tests and building sphinx docs): + +```bash +pip install "braingeneers[dev]" ``` - Install with all optional dependencies: @@ -52,7 +65,7 @@ pip install "braingeneers[all]" ## Committing Changes to the Repo -To make changes and publish them on GitHub, please refer to the [CONTRIBUTING.md](https://github.com/braingeneers/braingeneerspy/blob/development/.github/CONTRIBUTING.md) file for up-to-date guidelines. +To make changes and publish them on GitHub, please refer to the [CONTRIBUTING.md](https://github.com/braingeneers/braingeneerspy/blob/master/.github/CONTRIBUTING.md) file for up-to-date guidelines. ## Modules and Subpackages diff --git a/braingeneers/data/test_data/maxwell-metadata.expected.json b/braingeneers/data/test_data/maxwell-metadata.expected.json index dc20144..c39ae72 100644 --- a/braingeneers/data/test_data/maxwell-metadata.expected.json +++ b/braingeneers/data/test_data/maxwell-metadata.expected.json @@ -172,4 +172,4 @@ "data_format": "NeurodataWithoutBorders" } } -} \ No newline at end of file +} diff --git a/braingeneers/data/test_data/maxwell-metadata.old.json b/braingeneers/data/test_data/maxwell-metadata.old.json index aa77584..48afc10 100644 --- a/braingeneers/data/test_data/maxwell-metadata.old.json +++ b/braingeneers/data/test_data/maxwell-metadata.old.json @@ -171,4 +171,4 @@ ] } } -} \ No newline at end of file +} diff --git a/src/braingeneers/_version.py b/src/braingeneers/_version.py new file mode 100644 index 0000000..2d07c8e --- /dev/null +++ b/src/braingeneers/_version.py @@ -0,0 +1,2 @@ +__version__ = version = '0.0.0.dev0' +__version_tuple__ = version_tuple = (0, 0, 0, 'dev0') diff --git a/src/braingeneers/data/datasets_electrophysiology.py b/src/braingeneers/data/datasets_electrophysiology.py index dce84f6..c1502e0 100644 --- a/src/braingeneers/data/datasets_electrophysiology.py +++ b/src/braingeneers/data/datasets_electrophysiology.py @@ -1,9 +1,11 @@ from __future__ import annotations import os +import sys import json import warnings import copy +import diskcache import matplotlib.pyplot as plt import numpy as np @@ -27,10 +29,6 @@ import re from types import ModuleType import bisect -try: - import neuraltoolkit as ntk # optional import -except ImportError: - pass VALID_LOAD_DATA_DTYPES = [np.int16, np.float16, np.float32, np.float64] @@ -63,15 +61,10 @@ def list_uuids(): def save_metadata(metadata: dict): """ - Saves a metadata file back to S3. This is not multi-writer safe, you can use a lock as shown in the example: - - from braingeneers.iot.messaging import MessageBroker() - import braingeneers.data.datasets_electrophysiology as de - - with MessageBroker().get_lock('a-unique-lock-name-for-your-process'): - metadata = de.load_metadata(uuid) - metadata = do_something_to(metadata) - de.save_metadata(metadata) + Saves a metadata file back to S3. This is not multi-writer safe, you can use: + braingeneers.utils.common_utils.checkout + braingeneers.utils.common_utils.checkin + to lock the file while you are writing to it. :param metadata: the metadata dictionary as obtained from load_metadata(uuid) """ @@ -86,6 +79,38 @@ def save_metadata(metadata: dict): f.write(json.dumps(metadata, indent=2)) +def cached_load_data(cache_path: str, max_size_gb: int = 10, **kwargs): + """ + Wraps a call to load_data with a diskcache at path `cache_path`. + This is multiprocessing/thread safe. + All arguments after the cache_path are passed to load_data (see load_data docs) + You must specify the load_data argument names to avoid ambiguity with the cached_load_data parameters. + + When reading data from S3 (or even a compressed local file), this can provide a significant speedup by + storing the results of load_data in a local (uncompressed) cache. + + Example usage: + from braingeneers.data.datasets_electrophysiology import load_metadata, cached_load_data + + metadata = load_metadata('9999-00-00-e-test') + data = cached_load_data(cache_path='/tmp/cache-dir', metadata=metadata, experiment=0, offset=0, length=1000) + + Note: this can safely be used with `map2` from `braingeneers.utils.common_utils` to parallelize calls to load_data. + + :param cache_path: str, path to the cache directory. + :param max_size_gb: int, maximum size of the cache in GB (10 GB default). If the cache exceeds this size, the oldest items will be removed. + :param kwargs: keyword arguments to pass to load_data, see load_data documentation. + """ + cache = diskcache.Cache(cache_path, size_limit=10 ** 9 * max_size_gb) + key = json.dumps(kwargs) + if key in cache: + return cache[key] + else: + data = load_data(**kwargs) + cache[key] = data + return data + + def load_metadata(batch_uuid: str) -> dict: """ Loads the batch UUID metadata. @@ -287,11 +312,11 @@ def load_windows(metadata, exp, window_centers, window_sz, dtype=np.float16, # Check if window is out of bounds if window[0] < 0 or window[1] > dataset_length: - print("Window out of bounds, inserting zeros for window",window) + print("Window out of bounds, inserting zeros for window", window) try: data_temp = np.zeros((data_temp.shape[0],window_sz),dtype=dtype) except Exception as e: - print(e) + print(e, file=sys.stderr) data_temp = load_window(metadata, exp, window, dtype=dtype, channels=channels) else: data_temp = load_window(metadata, exp, window, dtype=dtype, channels=channels) @@ -659,14 +684,14 @@ def load_stims_maxwell(uuid: str, metadata_ephys_exp: dict = None, experiment_st return df except FileNotFoundError: - print(f'\tThere seems to be no stim log file for this experiment! :(') + print(f'\tThere seems to be no stim log file for this experiment! :(', file=sys.stderr) return None except OSError: - print(f'\tThere seems to be no stim log file (on s3) for this experiment! :(') + print(f'\tThere seems to be no stim log file (on s3) for this experiment! :(', file=sys.stderr) return None -def load_gpio_maxwell(dataset_path, fs=20000): +def load_gpio_maxwell(dataset_path, fs=20000.0): """ Loads the GPIO events for optogenetics stimulation. :param dataset_path: a local or a s3 path @@ -675,10 +700,12 @@ def load_gpio_maxwell(dataset_path, fs=20000): """ with smart_open.open(dataset_path, 'rb') as f: with h5py.File(f, 'r') as dataset: - assert 'bits' in dataset.keys(), 'No GPIO event in the dataset!' + if 'bits' not in dataset.keys(): + print('No GPIO event in the dataset!', file=sys.stderr) + return np.array([]) bits_dataset = list(dataset['bits']) - bits_dataframe = [bits_dataset[i][0] for i in range(len(bits_dataset))] - rec_startframe = dataset['raw'][-1, 0] << 16 | dataset['raw'][-2, 0] + bits_dataframe = [bits_dataset[i][0] for i in range(len(bits_dataset))] + rec_startframe = dataset['sig'][-1, 0] << 16 | dataset['sig'][-2, 0] if len(bits_dataframe) % 2 == 0: stim_pairs = (np.array(bits_dataframe) - rec_startframe).reshape(len(bits_dataframe) // 2, 2) return stim_pairs / fs @@ -844,133 +871,6 @@ def _read_hengenlab_ecube_timestamp(filepath: str) -> int: return int(np.frombuffer(f.read(8), dtype=np.uint64)) -def generate_metadata_hengenlab(batch_uuid: str, - dataset_name: str, - experiment_name: Union[List[str], str] = 'experiment1', - fs: int = 25000, - n_threads: int = 32, - save: bool = False): - """ - Generates a metadata json and experiment1...experimentN section for a hengenlab dataset upload. - File locations in S3 for hengenlab neural data files: - s3://braingeneers/ephys/YYYY-MM-DD-e-${DATASET_NAME}/original/data/*.bin - Contiguous recording periods - :param batch_uuid: location on braingeneers storage (S3) - :param dataset_name: the dataset_name as defined in `neuraltoolkit`. Metadata will be pulled from `neuraltoolkit`. - :param experiment_name: Dataset name as stored in `neuraltoolkit`. For example "CAF26" - :param fs: sampling rate, default to 25,000 - :param n_threads: number of threads to use for reading ecube timestamps (default: 32) - :param save: (default False) option to save the metadata.json back to S3 - (or the current braingeneers.default_endpoint) - :return: metadata.json - """ - # hengenlab's (current) source of record for experiment metadata is stored in a repo which can't be imported - # due to unacceptable dependencies. Instead, the source code is being downloaded with the relevant static - # functions parsed out explicitly. This is a hacky approach, but this data shouldn't be stored - # in a repo and is expected to be replaced with a proper database in the future. - # All current solutions to this problem are bad, this is the least objectionable solution. - crit_utils_src = requests.get('https://raw.githubusercontent.com/hengenlab/sahara_work/master/crit_utils.py').text - - src_get_birthday = re.search(r'(def get_birthday\(animal, returnall=False\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_regions = re.search(r'(def get_regions\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_sex = re.search(r'(def get_sex\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_genotype = re.search(r'(def get_genotype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_hstype = re.search(r'(def get_hstype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - - module = ModuleType('tempmodule') - module.dt = datetime # the only import necessary to run these static functions - exec(compile(src_get_birthday, '', 'exec'), module.__dict__) - exec(compile(src_get_regions, '', 'exec'), module.__dict__) - exec(compile(src_get_sex, '', 'exec'), module.__dict__) - exec(compile(src_get_genotype, '', 'exec'), module.__dict__) - exec(compile(src_get_hstype, '', 'exec'), module.__dict__) - - headstage_types = module.get_hstype(dataset_name.lower()) - - # list neural data files on S3 - s3_path = f's3://braingeneers/ephys/{batch_uuid}/original/{experiment_name}/' - neural_data_files = common_utils.file_list(s3_path) - assert len(neural_data_files) > 0, f'No neural data files found at: {s3_path}' - - args = [s3_path + ndf[0] for ndf in neural_data_files] - - # get ecube times for each file - ecube_timestamps = common_utils.map2( - _read_hengenlab_ecube_timestamp, - args=args, - parallelism=n_threads, - use_multithreading=True, - ) - - # sort data files by ecube timestamps - neural_data_files = [(*ndf, et) for ndf, et in zip(neural_data_files, ecube_timestamps)] - neural_data_files.sort(key=lambda ndf: ndf[3]) - - # parse n_channels from file name - channels_match = re.search(r'.*Headstages_(\d+)_Channels.*', neural_data_files[0][0]) - assert channels_match is not None, f'Unable to parse n_channels from filename: {neural_data_files[0][0]}' - n_channels = int(channels_match.group(1)) - - # parse timestamp from first file name - timestamp_match = re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_files[0][0]) - assert timestamp_match is not None, f'Unable to parse timestamp from filename: {neural_data_files[0][0]}' - timestamp = datetime.strptime(timestamp_match.group(1), '%Y-%m-%d_%H-%M-%S') - - channels_per_probe = n_channels // len(headstage_types) - channel_map = list(itertools.chain(*[ - (ntk.find_channel_map(hstype, number_of_channels=channels_per_probe) + i * channels_per_probe).tolist() - for i, hstype in enumerate(headstage_types) - ])) - - metadata = dict( - uuid=batch_uuid, - timestamp=timestamp.isoformat(), - issue='', - channel_map=channel_map, - headstage_types=headstage_types, - notes=dict( - purpose_of_experiment='', - comments='', - biology=dict( - sample_type='mouse', - dataset_name=dataset_name, - birthday=module.get_birthday(dataset_name.lower()).isoformat(), - gender=module.get_sex(dataset_name.lower()), - genotype=module.get_genotype(dataset_name.lower()), - ), - ), - ephys_experiments=[dict( - name=experiment_name, - hardware='Hengenlab', - num_channels=n_channels, - sample_rate=fs, - voltage_scaling_factor=0.19073486328125, - timestamp=timestamp.isoformat(), - units='\u00b5V', - version='1.0.0', - blocks=[ - { - 'num_frames': (size - 8) // 2 // n_channels, - 'path': f'original/{experiment_name}/{neural_data_file.split("/")[-1]}', - 'timestamp': datetime.strptime( - re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_file).group(1), - '%Y-%m-%d_%H-%M-%S', - ).isoformat(), - 'ecube_time': ecube_time, - } - for neural_data_file, last_modified_timestamp, size, ecube_time in neural_data_files - ], - )], - ) - - if save is True: - save_path = f's3://braingeneers/ephys/{batch_uuid}/metadata.json' - with smart_open.open(save_path, 'w') as f: - f.write(json.dumps(metadata, indent=2)) - - return metadata - - # --- AXION READER ----------------------------- def from_uint64(all_values): """ @@ -1228,7 +1128,7 @@ def _axion_generate_per_block_metadata(filename: str): fid.seek(start + int(obj.length.item()), 0) if fid.tell() != start + obj.length: - print('Unexpected Channel array length') + print('Unexpected Channel array length', file=sys.stderr) elif obj.type == 3: continue @@ -1300,47 +1200,6 @@ def _axion_get_data(file_name, file_data_start_position, return final_raw_data_reshaped -# class IndexedList(list): -# """ -# A variant of OrderedDict indexable by index (int) or name (str). -# This class forces ints to represent index by location, else index by name/object. -# Example usages: -# metadata['ephys_experiments']['experiment0'] # index by name (must use str type) -# metadata['ephys_experiments'][0] # index by location (must use int type) -# """ -# -# def __init__(self, original_list: list, key: callable): -# self.keys_ordered = [key(v) for v in original_list] -# self.dict = {key(v): v for v in original_list} -# super().__init__() -# -# def __getitem__(self, key): -# print(key) -# if isinstance(key, int): -# return self.dict[self.keys_ordered[key]] -# elif isinstance(key, str): -# return self.dict[key] -# else: -# raise KeyError(f'Key must be type int (index by location) or str (index by name), got type: {type(key)}') -# -# def __iter__(self) -> Iterator: -# def g(): -# for k in self.keys_ordered: -# yield self.dict[k] -# -# return g() -# -# def __hash__(self): -# return self.dict.__hash__() -# -# def __eq__(self, other): -# return isinstance(other, IndexedList) and self.dict.__eq__(other.dict) -# -# def __add__(self, value): -# self.keys_ordered.append(value) -# self.dict[value] = value - - def get_mearec_h5_recordings_file(batch_uuid: str): """ Returns the filepath to the MEArec .h5/.hdf5 recordings file for the given UUID. diff --git a/src/braingeneers/data/datasets_electrophysiology_test.py b/src/braingeneers/data/datasets_electrophysiology_test.py index 525fff8..807b576 100644 --- a/src/braingeneers/data/datasets_electrophysiology_test.py +++ b/src/braingeneers/data/datasets_electrophysiology_test.py @@ -1,15 +1,20 @@ import unittest +import tempfile +import shutil +import diskcache +import json +import threading import braingeneers import braingeneers.data.datasets_electrophysiology as ephys -import json from braingeneers import skip_unittest_if_offline -# import braingeneers.utils.smart_open_braingeneers as smart_open -import smart_open +import braingeneers.utils.smart_open_braingeneers as smart_open import boto3 import numpy as np - +from unittest.mock import patch +from braingeneers.data.datasets_electrophysiology import cached_load_data from unittest.mock import patch + class MaxwellReaderTests(unittest.TestCase): @skip_unittest_if_offline @@ -124,6 +129,25 @@ def test_modify_maxwell_metadata(self): assert modified_metadata == expected_metadata + @skip_unittest_if_offline + def test_load_gpio_maxwell(self): + """ Read gpio event for Maxwell V1 file""" + data_1 = "s3://braingeneers/ephys/" \ + "2023-04-02-hc328_rec/original/data/" \ + "2023_04_02_hc328_0.raw.h5" + data_2 = "s3://braingeneers/ephys/" \ + "2023-04-04-e-hc328_hckcr1-2_040423_recs/original/data/" \ + "hc3.28_hckcr1_chip8787_plated4.4_rec4.4.raw.h5" + data_3 = "s3://braingeneers/ephys/" \ + "2023-04-04-e-hc328_hckcr1-2_040423_recs/original/data/" \ + "2023_04_04_hc328_hckcr1-2_3.raw.h5" + gpio_1 = ephys.load_gpio_maxwell(data_1) + gpio_2 = ephys.load_gpio_maxwell(data_2) + gpio_3 = ephys.load_gpio_maxwell(data_3) + self.assertEqual(gpio_1.shape, (1, 2)) + self.assertEqual(gpio_2.shape, (0,)) + self.assertEqual(gpio_3.shape, (29,)) + class MEArecReaderTests(unittest.TestCase): """The fake reader test.""" @@ -394,46 +418,86 @@ def test_online_load_data_hengenlab_float32(self): self.assertEqual((192, 4), data.shape) self.assertEqual(np.float32, data.dtype) - @skip_unittest_if_offline - def test_online_generate_metadata(self): - metadata = ephys.generate_metadata_hengenlab( - batch_uuid=self.batch_uuid, - dataset_name='CAF26', - save=False, - ) - # top level items - self.assertEqual(metadata['uuid'], '2020-04-12-e-hengenlab-caf26') - self.assertEqual(metadata['timestamp'], '2020-08-07T14:00:15') - self.assertEqual(metadata['issue'], '') - self.assertEqual(metadata['headstage_types'], ['EAB50chmap_00', 'APT_PCB', 'APT_PCB']) - - # notes - self.assertEqual(metadata['notes']['biology']['sample_type'], 'mouse') - self.assertEqual(metadata['notes']['biology']['dataset_name'], 'CAF26') - self.assertEqual(metadata['notes']['biology']['birthday'], '2020-02-20T07:30:00') - self.assertEqual(metadata['notes']['biology']['genotype'], 'wt') - - # ephys_experiments - self.assertEqual(len(metadata['ephys_experiments']), 1) - self.assertTrue(isinstance(metadata['ephys_experiments'], list)) - - experiment = metadata['ephys_experiments'][0] - self.assertEqual(experiment['name'], 'experiment1') - self.assertEqual(experiment['hardware'], 'Hengenlab') - self.assertEqual(experiment['num_channels'], 192) - self.assertEqual(experiment['sample_rate'], 25000) - self.assertEqual(experiment['voltage_scaling_factor'], 0.19073486328125) - self.assertEqual(experiment['timestamp'], '2020-08-07T14:00:15') - self.assertEqual(experiment['units'], '\u00b5V') - self.assertEqual(experiment['version'], '1.0.0') - self.assertEqual(len(experiment['blocks']), 324) - - block1 = metadata['ephys_experiments'][0]['blocks'][1] - self.assertEqual(block1['num_frames'], 7500000) - self.assertEqual(block1['path'], 'original/experiment1/Headstages_192_Channels_int16_2020-08-07_14-05-16.bin') - self.assertEqual(block1['timestamp'], '2020-08-07T14:05:16') - self.assertEqual(block1['ecube_time'], 301061600050) +class TestCachedLoadData(unittest.TestCase): + + def setUp(self): + # Create a temporary directory for the cache + self.cache_dir = tempfile.mkdtemp(prefix='test_cache_') + + def tearDown(self): + # Remove the temporary directory after the test + shutil.rmtree(self.cache_dir) + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_caching_mechanism(self, mock_load_data): + """ + Test that data is properly cached and retrieved on subsequent calls with the same parameters. + """ + mock_load_data.return_value = 'mock_data' + metadata = {'uuid': 'test_uuid'} + + # First call should invoke load_data + first_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0) + mock_load_data.assert_called_once() + + # Second call should retrieve data from cache and not invoke load_data again + second_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0) + self.assertEqual(first_call_data, second_call_data) + mock_load_data.assert_called_once() # Still called only once + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_cache_eviction_when_full(self, mock_load_data): + """ + Test that the oldest items are evicted from the cache when it exceeds its size limit. + """ + mock_load_data.side_effect = lambda **kwargs: f"data_{kwargs['experiment']}" + max_size_gb = 0.000001 # Set a very small cache size to test eviction + + # Populate the cache with enough data to exceed its size limit + for i in range(10): + cached_load_data(self.cache_dir, max_size_gb=max_size_gb, metadata={'uuid': 'test_uuid'}, experiment=i) + + cache = diskcache.Cache(self.cache_dir) + self.assertLess(len(cache), 10) # Ensure some items were evicted + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_arguments_passed_to_load_data(self, mock_load_data): + """ + Test that all arguments after cache_path are correctly passed to the underlying load_data function. + """ + # Mock load_data to return a serializable object, e.g., a numpy array + mock_load_data.return_value = np.array([1, 2, 3]) + + kwargs = {'metadata': {'uuid': 'test_uuid'}, 'experiment': 0, 'offset': 0, 'length': 1000} + cached_load_data(self.cache_dir, **kwargs) + mock_load_data.assert_called_with(**kwargs) + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_multiprocessing_thread_safety(self, mock_load_data): + """ + Test that the caching mechanism is multiprocessing/thread-safe. + """ + # Mock load_data to return a serializable object, e.g., a numpy array + mock_load_data.return_value = np.array([1, 2, 3]) + + def thread_function(cache_path, metadata, experiment): + # This function uses the mocked load_data indirectly via cached_load_data + cached_load_data(cache_path, metadata=metadata, experiment=experiment) + + metadata = {'uuid': 'test_uuid'} + threads = [] + for i in range(10): + t = threading.Thread(target=thread_function, args=(self.cache_dir, metadata, i)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # If the cache is thread-safe, this operation should complete without error + # This assertion is basic and assumes the test's success implies thread safety + self.assertTrue(True) if __name__ == '__main__': diff --git a/src/braingeneers/iot/__init__.py b/src/braingeneers/iot/__init__.py index 54494f2..1c6f8b9 100644 --- a/src/braingeneers/iot/__init__.py +++ b/src/braingeneers/iot/__init__.py @@ -1,3 +1,4 @@ import braingeneers from braingeneers.iot.messaging import * +from braingeneers.iot.device import * from braingeneers.iot.simple import * diff --git a/src/braingeneers/iot/messaging.py b/src/braingeneers/iot/messaging.py index 7742c80..db19cfc 100644 --- a/src/braingeneers/iot/messaging.py +++ b/src/braingeneers/iot/messaging.py @@ -1,27 +1,22 @@ """ A simplified MQTT client for Braingeneers specific connections """ - import redis -import tempfile -import functools -import json -import inspect import logging import os import re -import time import io import configparser import threading import queue import uuid -from typing import Callable, Tuple, List, Dict, Union import random import json import braingeneers.iot.shadows as sh -from paho.mqtt import client as mqtt_client -from deprecated import deprecated import pickle -from tenacity import retry, wait_exponential, after_log + +from typing import Callable, Tuple, List, Dict, Union +from deprecated import deprecated +from paho.mqtt import client as mqtt_client +from paho.mqtt.enums import CallbackAPIVersion AWS_REGION = 'us-west-2' @@ -767,7 +762,7 @@ def on_log(client, userdata, level, buf): self.logger.debug("MQTT log: %s", buf) client_id = f'braingeneerspy-{random.randint(0, 1000)}' - self._mqtt_connection = mqtt_client.Client(client_id) + self._mqtt_connection = mqtt_client.Client(CallbackAPIVersion.VERSION1, client_id) self._mqtt_connection.username_pw_set(self._mqtt_profile_id, self._mqtt_profile_key) self._mqtt_connection.on_connect = on_connect self._mqtt_connection.on_log = on_log @@ -821,4 +816,3 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _mqtt_topic_regex(topic: str) -> str: """ Converts a topic string with wildcards to a regex string """ return "^" + topic.replace("+", "[^/]+").replace("#", ".*").replace("$", "\\$") + "$" - diff --git a/src/braingeneers/iot/messaging_test.py b/src/braingeneers/iot/messaging_test.py index 19dd379..205acb0 100644 --- a/src/braingeneers/iot/messaging_test.py +++ b/src/braingeneers/iot/messaging_test.py @@ -303,5 +303,6 @@ def test_acquire_release(self): lock.release() + if __name__ == '__main__': unittest.main() diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index 14a9f45..c0a4031 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -1,18 +1,20 @@ """ Common utility functions """ +import io import urllib import boto3 from botocore.exceptions import ClientError import os import braingeneers -from typing import List, Tuple, Union, Callable, Iterable -import functools +import braingeneers.utils.smart_open_braingeneers as smart_open +from typing import Callable, Iterable, Union, List, Tuple, Dict, Any import inspect import multiprocessing import posixpath -import itertools - +import pathlib _s3_client = None # S3 client for boto3, lazy initialization performed in _lazy_init_s3_client() +_message_broker = None # Lazy initialization of the message broker +_named_locks = {} # Named locks for checkout and checkin def _lazy_init_s3_client(): @@ -99,31 +101,36 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]: """ Returns a list of files, last modified time, and size on local or S3 in descending order of last modified time - :param filepath: Local or S3 file path to list, example: "local/dir/" or "s3://braingeneers/ephys/ + :param filepath: Local or S3 file path to list, example: "local/dir/" or "s3://bucket/prefix/" :return: A list of tuples of [('fileA', 'last_modified_A', size), ('fileB', 'last_modified_B', size), ...] """ + files_and_details = [] + if filepath.startswith('s3://'): s3_client = _lazy_init_s3_client() o = urllib.parse.urlparse(filepath) response = s3_client.list_objects(Bucket=o.netloc, Prefix=o.path[1:]) - if 'Contents' not in response: - if raise_on_missing: - raise FileNotFoundError(filepath) - else: - return [(o.path[1:].split('/')[-1], 'Missing')] - - files_and_details = [ - (f['Key'].split('/')[-1], str(f['LastModified']), int(f['Size'])) - for f in sorted(response['Contents'], key=lambda x: x['LastModified'], reverse=True) - ] - else: + if 'Contents' in response: + files_and_details = [ + (f['Key'].split('/')[-1], str(f['LastModified']), int(f['Size'])) + for f in sorted(response['Contents'], key=lambda x: x['LastModified'], reverse=True) + ] + elif os.path.exists(filepath): files = sorted(pathlib.Path(filepath).iterdir(), key=os.path.getmtime, reverse=True) files_and_details = [(f.name, str(f.stat().st_mtime), f.stat().st_size) for f in files] return files_and_details +# Define the wrapper function as a top-level function +def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any: + """Internal wrapper function for map2 to handle fixed values and dynamic arguments.""" + # Merge fixed_values with provided arguments, aligning provided args with required_params + call_args = {**fixed_values, **dict(zip(required_params, args))} + return func(**call_args) + + def map2(func: Callable, args: Iterable[Union[Tuple, object]] = None, fixed_values: dict = None, @@ -170,26 +177,190 @@ def f(x, y): :return: a list of the return values of func """ assert isinstance(fixed_values, (dict, type(None))) - assert isinstance(parallelism, int) + assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer" parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism - assert isinstance(parallelism, int) + assert isinstance(parallelism, int), "parallelism must be resolved to an integer" + + fixed_values = fixed_values or {} + func_signature = inspect.signature(func) + required_params = [p.name for p in func_signature.parameters.values() if + p.default == inspect.Parameter.empty and p.name not in fixed_values] - func_partial = functools.partial(func, **(fixed_values or {})) - n_required_params = sum([p.default == inspect.Parameter.empty for p in inspect.signature(func).parameters.values()]) - n_fixed_values = len(fixed_values or {}) args_list = list(args or []) - args_tuples = args \ - if len(args_list) > 0 \ - and isinstance(args_list[0], tuple) \ - and len(args_list[0]) >= n_required_params - n_fixed_values \ - else [(a,) for a in args_list] + args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list] if parallelism == 1: - result_iterator = itertools.starmap(func_partial, args_tuples) + result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples) else: - # noinspection PyPep8Naming - ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading is True else multiprocessing.Pool + ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool with ProcessOrThreadPool(parallelism) as pool: - result_iterator = pool.starmap(func_partial, args_tuples) + result_iterator = pool.starmap(_map2_wrapper, + [(fixed_values, required_params, func, args) for args in args_tuples]) return list(result_iterator) + + +class checkout: + """ + A context manager for atomically checking out a file from S3 for reading or writing. + + Example usage: + + # Read-then-update metadata.json (or any text based file on S3) + with checkout('s3://braingeneers/ephys/9999-0-0-e-test/metadata.json', isbinary=False) as locked_obj: + metadata_dict = json.loads(locked_obj.get_value()) + metadata_dict['new_key'] = 'new_value' + metadata_updated_str = json.dumps(metadata_dict, indent=2) + locked_obj.checkin(metadata_updated_str) + + # Read-then-update data.npy (or any binary file on S3) + with checkout('s3://braingeneersdev/test/data.npy', isbinary=True) as locked_obj: + file_obj = locked_obj.get_file() + ndarray = np.load(file_obj) + ndarray[3, 3] = 42 + locked_obj.checkin(ndarray.tobytes()) + + # Edit a file in place, note checkin is not needed, the file is updated when the context manager exits + with checkout('s3://braingeneersdev/test/test_file.bin', isbinary=True) as locked_obj: + with zipfile.ZipFile(locked_obj.get_file(), 'a') as z: + z.writestr('new_file.txt', 'new file contents') + + locked_obj functions: + get_value() # returns a string or bytes object (depending on isbinary) + get_file() # returns a file-like object akin to open() + checkin() # updates the file, accepts string, bytes, or file like objects + """ + class LockedObject: + def __init__(self, s3_file_object: io.IOBase, s3_path_str: str, isbinary: bool): + self.s3_path_str = s3_path_str + self.s3_file_object = s3_file_object # underlying file object + self.isbinary = isbinary # binary or text mode + self.modified = False # Track if the file has been modified + + def get_value(self): + # Read file object from outer class s3_file_object + self.s3_file_object.seek(0) + return self.s3_file_object.read() + + def get_file(self): + # Mark file as potentially modified when accessed + self.modified = True + # Return file object from outer class s3_file_object + self.s3_file_object.seek(0) + return self.s3_file_object + + def checkin(self, update_file: Union[str, bytes, io.IOBase]): + # Validate input + if not isinstance(update_file, (str, bytes, io.IOBase)): + raise TypeError('File must be a string, bytes, or file object.') + if isinstance(update_file, str) or isinstance(update_file, io.StringIO): + if self.isbinary: + raise ValueError( + 'Cannot check in a string or text file when checkout is specified for binary mode.') + if isinstance(update_file, bytes) or isinstance(update_file, io.BytesIO): + if not self.isbinary: + raise ValueError('Cannot check in bytes or a binary file when checkout is specified for text mode.') + + mode = 'w' if not self.isbinary else 'wb' + with smart_open.open(self.s3_path_str, mode=mode) as f: + f.write(update_file if not isinstance(update_file, io.IOBase) else update_file.read()) + + def __init__(self, s3_path_str: str, isbinary: bool = False): + # TODO: avoid circular import + from braingeneers.iot.messaging import MessageBroker + + self.s3_path_str = s3_path_str + self.isbinary = isbinary + self.mb = MessageBroker() + self.named_lock = None # message broker lock + self.locked_obj = None # user facing locked object + + def __enter__(self): + lock_str = f'common-utils-checkout-{self.s3_path_str}' + named_lock = self.mb.get_lock(lock_str) + named_lock.acquire() + self.named_lock = named_lock + f = smart_open.open(self.s3_path_str, 'rb' if self.isbinary else 'r') + self.locked_obj = checkout.LockedObject(f, self.s3_path_str, self.isbinary) + return self.locked_obj + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.locked_obj.modified: + # If the file was modified, automatically check in the changes + self.locked_obj.checkin(self.locked_obj.get_file()) + self.named_lock.release() + + +def force_release_checkout(s3_file: str): + """ + Force release the lock on a file that was checked out with checkout. + """ + # TODO: avoid circular import + from braingeneers.iot.messaging import MessageBroker + + global _message_broker + if _message_broker is None: + _message_broker = MessageBroker() + + _message_broker.delete_lock(f'common-utils-checkout-{s3_file}') + + +def pretty_print(data, n=10, indent=0): + """ + Custom pretty print function that uniformly truncates any collection (list or dictionary) + longer than `n` values, showing the first `n` values and a summary of omitted items. + Ensures mapping sections and similar are displayed compactly. + + Example usage (to display metadata.json): + + from braingeneers.utils.common_utils import pretty_print + from braingeneers.data import datasets_electrophysiology as de + + metadata = de.load_metadata('2023-04-17-e-connectoid16235_CCH') + pretty_print(metadata) + + Parameters: + - data: The data to pretty print, either a list or a dictionary. + - n: Maximum number of elements or items to display before truncation. + - indent: Don't use this. Current indentation level for formatting, used during recursion. + """ + indent_space = ' ' * indent + if isinstance(data, dict): + keys = list(data.keys()) + if len(keys) > n: + truncated_keys = keys[:n] + omitted_keys = len(keys) - n + else: + truncated_keys = keys + omitted_keys = None + + print('{') + for key in truncated_keys: + value = data[key] + print(f"{indent_space} '{key}': ", end='') + if isinstance(value, dict): + pretty_print(value, n, indent + 4) + print() + elif isinstance(value, list) and all(isinstance(x, (list, tuple)) and len(x) == 4 for x in value): + # Compact display for lists of tuples/lists of length 4. + print('[', end='') + if len(value) > n: + for item in value[:n]: + print(f"{item}, ", end='') + print(f"... (+{len(value) - n} more items)", end='') + else: + print(', '.join(map(str, value)), end='') + print('],') + else: + print(f"{value},") + if omitted_keys: + print(f"{indent_space} ... (+{omitted_keys} more items)") + print(f"{indent_space}}}", end='') + elif isinstance(data, list): + print('[') + for item in data[:n]: + pretty_print(item, n, indent + 4) + print(',') + if len(data) > n: + print(f"{indent_space} ... (+{len(data) - n} more items)") + print(f"{indent_space}]", end='') diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py new file mode 100644 index 0000000..be1cc90 --- /dev/null +++ b/src/braingeneers/utils/common_utils_test.py @@ -0,0 +1,103 @@ +import io +import unittest +from unittest.mock import patch, MagicMock +import common_utils +from common_utils import checkout, force_release_checkout +from braingeneers.iot import messaging +import os +import tempfile +import braingeneers.utils.smart_open_braingeneers as smart_open +from typing import Union + + +class TestFileListFunction(unittest.TestCase): + + @patch('common_utils._lazy_init_s3_client') # Updated to common_utils + def test_s3_files_exist(self, mock_s3_client): + # Mock S3 client response + mock_response = { + 'Contents': [ + {'Key': 'file1.txt', 'LastModified': '2023-01-01', 'Size': 123}, + {'Key': 'file2.txt', 'LastModified': '2023-01-02', 'Size': 456} + ] + } + mock_s3_client.return_value.list_objects.return_value = mock_response + + result = common_utils.file_list('s3://test-bucket/') # Updated to common_utils + expected = [('file2.txt', '2023-01-02', 456), ('file1.txt', '2023-01-01', 123)] + self.assertEqual(result, expected) + + @patch('common_utils._lazy_init_s3_client') # Updated to common_utils + def test_s3_no_files(self, mock_s3_client): + # Mock S3 client response for no files + mock_s3_client.return_value.list_objects.return_value = {} + result = common_utils.file_list('s3://test-bucket/') # Updated to common_utils + self.assertEqual(result, []) + + def test_local_files_exist(self): + with tempfile.TemporaryDirectory() as temp_dir: + for f in ['tempfile1.txt', 'tempfile2.txt']: + with open(os.path.join(temp_dir, f), 'w') as w: + w.write('nothing') + + result = common_utils.file_list(temp_dir) # Updated to common_utils + # The result should contain two files with their details + self.assertEqual(len(result), 2) + + def test_local_no_files(self): + with tempfile.TemporaryDirectory() as temp_dir: + result = common_utils.file_list(temp_dir) # Updated to common_utils + self.assertEqual(result, []) + + +class TestCheckout(unittest.TestCase): + + def setUp(self): + # Setup mock for smart_open and MessageBroker + self.message_broker_patch = patch('braingeneers.iot.messaging.MessageBroker') + + # Start the patches + self.mock_message_broker = self.message_broker_patch.start() + + # Mock the message broker's get_lock and delete_lock methods + self.mock_message_broker.return_value.get_lock.return_value = MagicMock() + self.mock_message_broker.return_value.delete_lock = MagicMock() + + self.mock_file = MagicMock(spec=io.StringIO) + self.mock_file.read.return_value = 'Test data' # Ensure this is correctly setting the return value for read + self.mock_file.__enter__.return_value = self.mock_file + self.mock_file.__exit__.return_value = None + self.smart_open_mock = MagicMock(spec=smart_open) + self.smart_open_mock.open.return_value = self.mock_file + + common_utils.smart_open = self.smart_open_mock + + def tearDown(self): + # Stop all patches + self.message_broker_patch.stop() + + def test_checkout_context_manager_read(self): + # Test the reading functionality + with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj: + data = locked_obj.get_value() + self.assertEqual(data, 'Test data') + + def test_checkout_context_manager_write_text(self): + # Test the writing functionality for text mode + test_data = 'New test data' + self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test + with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj: + locked_obj.checkin(test_data) + self.mock_file.write.assert_called_once_with(test_data) + + def test_checkout_context_manager_write_binary(self): + # Test the writing functionality for binary mode + test_data = b'New binary data' + self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test + with checkout('s3://test-bucket/test-file.bin', isbinary=True) as locked_obj: + locked_obj.checkin(test_data) + self.mock_file.write.assert_called_once_with(test_data) + + +if __name__ == '__main__': + unittest.main()