diff --git a/src/braingeneers/data/datasets_electrophysiology.py b/src/braingeneers/data/datasets_electrophysiology.py index e5ffabf..c1502e0 100644 --- a/src/braingeneers/data/datasets_electrophysiology.py +++ b/src/braingeneers/data/datasets_electrophysiology.py @@ -5,6 +5,7 @@ import json import warnings import copy +import diskcache import matplotlib.pyplot as plt import numpy as np @@ -60,15 +61,10 @@ def list_uuids(): def save_metadata(metadata: dict): """ - Saves a metadata file back to S3. This is not multi-writer safe, you can use a lock as shown in the example: - - from braingeneers.iot.messaging import MessageBroker() - import braingeneers.data.datasets_electrophysiology as de - - with MessageBroker().get_lock('a-unique-lock-name-for-your-process'): - metadata = de.load_metadata(uuid) - metadata = do_something_to(metadata) - de.save_metadata(metadata) + Saves a metadata file back to S3. This is not multi-writer safe, you can use: + braingeneers.utils.common_utils.checkout + braingeneers.utils.common_utils.checkin + to lock the file while you are writing to it. :param metadata: the metadata dictionary as obtained from load_metadata(uuid) """ @@ -83,6 +79,38 @@ def save_metadata(metadata: dict): f.write(json.dumps(metadata, indent=2)) +def cached_load_data(cache_path: str, max_size_gb: int = 10, **kwargs): + """ + Wraps a call to load_data with a diskcache at path `cache_path`. + This is multiprocessing/thread safe. + All arguments after the cache_path are passed to load_data (see load_data docs) + You must specify the load_data argument names to avoid ambiguity with the cached_load_data parameters. + + When reading data from S3 (or even a compressed local file), this can provide a significant speedup by + storing the results of load_data in a local (uncompressed) cache. + + Example usage: + from braingeneers.data.datasets_electrophysiology import load_metadata, cached_load_data + + metadata = load_metadata('9999-00-00-e-test') + data = cached_load_data(cache_path='/tmp/cache-dir', metadata=metadata, experiment=0, offset=0, length=1000) + + Note: this can safely be used with `map2` from `braingeneers.utils.common_utils` to parallelize calls to load_data. + + :param cache_path: str, path to the cache directory. + :param max_size_gb: int, maximum size of the cache in GB (10 GB default). If the cache exceeds this size, the oldest items will be removed. + :param kwargs: keyword arguments to pass to load_data, see load_data documentation. + """ + cache = diskcache.Cache(cache_path, size_limit=10 ** 9 * max_size_gb) + key = json.dumps(kwargs) + if key in cache: + return cache[key] + else: + data = load_data(**kwargs) + cache[key] = data + return data + + def load_metadata(batch_uuid: str) -> dict: """ Loads the batch UUID metadata. diff --git a/src/braingeneers/data/datasets_electrophysiology_test.py b/src/braingeneers/data/datasets_electrophysiology_test.py index 0ce7f7a..807b576 100644 --- a/src/braingeneers/data/datasets_electrophysiology_test.py +++ b/src/braingeneers/data/datasets_electrophysiology_test.py @@ -1,15 +1,20 @@ import unittest +import tempfile +import shutil +import diskcache +import json +import threading import braingeneers import braingeneers.data.datasets_electrophysiology as ephys -import json from braingeneers import skip_unittest_if_offline -# import braingeneers.utils.smart_open_braingeneers as smart_open -import smart_open +import braingeneers.utils.smart_open_braingeneers as smart_open import boto3 import numpy as np - +from unittest.mock import patch +from braingeneers.data.datasets_electrophysiology import cached_load_data from unittest.mock import patch + class MaxwellReaderTests(unittest.TestCase): @skip_unittest_if_offline @@ -414,5 +419,86 @@ def test_online_load_data_hengenlab_float32(self): self.assertEqual(np.float32, data.dtype) +class TestCachedLoadData(unittest.TestCase): + + def setUp(self): + # Create a temporary directory for the cache + self.cache_dir = tempfile.mkdtemp(prefix='test_cache_') + + def tearDown(self): + # Remove the temporary directory after the test + shutil.rmtree(self.cache_dir) + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_caching_mechanism(self, mock_load_data): + """ + Test that data is properly cached and retrieved on subsequent calls with the same parameters. + """ + mock_load_data.return_value = 'mock_data' + metadata = {'uuid': 'test_uuid'} + + # First call should invoke load_data + first_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0) + mock_load_data.assert_called_once() + + # Second call should retrieve data from cache and not invoke load_data again + second_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0) + self.assertEqual(first_call_data, second_call_data) + mock_load_data.assert_called_once() # Still called only once + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_cache_eviction_when_full(self, mock_load_data): + """ + Test that the oldest items are evicted from the cache when it exceeds its size limit. + """ + mock_load_data.side_effect = lambda **kwargs: f"data_{kwargs['experiment']}" + max_size_gb = 0.000001 # Set a very small cache size to test eviction + + # Populate the cache with enough data to exceed its size limit + for i in range(10): + cached_load_data(self.cache_dir, max_size_gb=max_size_gb, metadata={'uuid': 'test_uuid'}, experiment=i) + + cache = diskcache.Cache(self.cache_dir) + self.assertLess(len(cache), 10) # Ensure some items were evicted + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_arguments_passed_to_load_data(self, mock_load_data): + """ + Test that all arguments after cache_path are correctly passed to the underlying load_data function. + """ + # Mock load_data to return a serializable object, e.g., a numpy array + mock_load_data.return_value = np.array([1, 2, 3]) + + kwargs = {'metadata': {'uuid': 'test_uuid'}, 'experiment': 0, 'offset': 0, 'length': 1000} + cached_load_data(self.cache_dir, **kwargs) + mock_load_data.assert_called_with(**kwargs) + + @patch('braingeneers.data.datasets_electrophysiology.load_data') + def test_multiprocessing_thread_safety(self, mock_load_data): + """ + Test that the caching mechanism is multiprocessing/thread-safe. + """ + # Mock load_data to return a serializable object, e.g., a numpy array + mock_load_data.return_value = np.array([1, 2, 3]) + + def thread_function(cache_path, metadata, experiment): + # This function uses the mocked load_data indirectly via cached_load_data + cached_load_data(cache_path, metadata=metadata, experiment=experiment) + + metadata = {'uuid': 'test_uuid'} + threads = [] + for i in range(10): + t = threading.Thread(target=thread_function, args=(self.cache_dir, metadata, i)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # If the cache is thread-safe, this operation should complete without error + # This assertion is basic and assumes the test's success implies thread safety + self.assertTrue(True) + + if __name__ == '__main__': unittest.main()