Added datasets_electrophysiology.cached_load_data function #75

Merged · 1 commit · Mar 12, 2024
46 changes: 37 additions & 9 deletions src/braingeneers/data/datasets_electrophysiology.py
@@ -5,6 +5,7 @@
 import json
 import warnings
 import copy
+import diskcache

 import matplotlib.pyplot as plt
 import numpy as np
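The new dependency here, diskcache (the python-diskcache package), is a disk-backed, process-safe cache with a dict-like API; that behavior is what the cached_load_data function added below relies on. A minimal standalone sketch of the usage pattern (the path and values are illustrative only):

    import diskcache

    # Dict-style usage of diskcache.Cache -- the same pattern cached_load_data
    # uses below. Values are pickled to disk; size_limit caps the size in bytes.
    cache = diskcache.Cache('/tmp/example-cache', size_limit=10 ** 9)  # ~1 GB
    cache['key'] = [1, 2, 3]
    assert 'key' in cache
    assert cache['key'] == [1, 2, 3]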
@@ -60,15 +61,10 @@ def list_uuids():

 def save_metadata(metadata: dict):
     """
-    Saves a metadata file back to S3. This is not multi-writer safe, you can use a lock as shown in the example:
-
-    from braingeneers.iot.messaging import MessageBroker()
-    import braingeneers.data.datasets_electrophysiology as de
-
-    with MessageBroker().get_lock('a-unique-lock-name-for-your-process'):
-        metadata = de.load_metadata(uuid)
-        metadata = do_something_to(metadata)
-        de.save_metadata(metadata)
+    Saves a metadata file back to S3. This is not multi-writer safe; you can use:
+        braingeneers.utils.common_utils.checkout
+        braingeneers.utils.common_utils.checkin
+    to lock the file while you are writing to it.

     :param metadata: the metadata dictionary as obtained from load_metadata(uuid)
     """
@@ -83,6 +79,38 @@ def save_metadata(metadata: dict):
         f.write(json.dumps(metadata, indent=2))


+def cached_load_data(cache_path: str, max_size_gb: int = 10, **kwargs):
+    """
+    Wraps a call to load_data with a diskcache at path `cache_path`.
+    This is multiprocessing/thread safe.
+    All arguments after cache_path are passed to load_data (see the load_data docs).
+    You must specify the load_data argument names to avoid ambiguity with the cached_load_data parameters.
+
+    When reading data from S3 (or even a compressed local file), this can provide a significant speedup by
+    storing the results of load_data in a local (uncompressed) cache.
+
+    Example usage:
+        from braingeneers.data.datasets_electrophysiology import load_metadata, cached_load_data
+
+        metadata = load_metadata('9999-00-00-e-test')
+        data = cached_load_data(cache_path='/tmp/cache-dir', metadata=metadata, experiment=0, offset=0, length=1000)
+
+    Note: this can safely be used with `map2` from `braingeneers.utils.common_utils` to parallelize calls to load_data.
+
+    :param cache_path: str, path to the cache directory.
+    :param max_size_gb: int, maximum size of the cache in GB (10 GB default). If the cache exceeds this size, the oldest items will be removed.
+    :param kwargs: keyword arguments to pass to load_data, see the load_data documentation.
+    """
+    cache = diskcache.Cache(cache_path, size_limit=10 ** 9 * max_size_gb)
+    key = json.dumps(kwargs)
+    if key in cache:
+        return cache[key]
+    else:
+        data = load_data(**kwargs)
+        cache[key] = data
+        return data
+
+
 def load_metadata(batch_uuid: str) -> dict:
     """
     Loads the batch UUID metadata.
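One subtlety in the new cached_load_data: the cache key is json.dumps(kwargs), and Python preserves keyword-argument order, so two calls that pass the same arguments in a different order serialize to different keys and will miss the cache. A canonical key avoids this; a minimal sketch using only the standard library (not part of this PR, purely an illustration):

    import json

    # sort_keys=True makes json.dumps emit dictionary keys in sorted order,
    # so the cache key no longer depends on keyword-argument order.
    key = json.dumps({'experiment': 0, 'offset': 0, 'length': 1000}, sort_keys=True)
    same = json.dumps({'length': 1000, 'offset': 0, 'experiment': 0}, sort_keys=True)
    assert key == same  # identical keys regardless of call order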
94 changes: 90 additions & 4 deletions src/braingeneers/data/datasets_electrophysiology_test.py
@@ -1,15 +1,20 @@
 import unittest
+import tempfile
+import shutil
+import diskcache
+import json
+import threading
 import braingeneers
 import braingeneers.data.datasets_electrophysiology as ephys
-import json
 from braingeneers import skip_unittest_if_offline
-# import braingeneers.utils.smart_open_braingeneers as smart_open
-import smart_open
+import braingeneers.utils.smart_open_braingeneers as smart_open
 import boto3
 import numpy as np

-from unittest.mock import patch
+from braingeneers.data.datasets_electrophysiology import cached_load_data
+from unittest.mock import patch


 class MaxwellReaderTests(unittest.TestCase):

     @skip_unittest_if_offline
@@ -414,5 +419,86 @@ def test_online_load_data_hengenlab_float32(self):
         self.assertEqual(np.float32, data.dtype)

+
+class TestCachedLoadData(unittest.TestCase):
+
+    def setUp(self):
+        # Create a temporary directory for the cache
+        self.cache_dir = tempfile.mkdtemp(prefix='test_cache_')
+
+    def tearDown(self):
+        # Remove the temporary directory after the test
+        shutil.rmtree(self.cache_dir)
+
+    @patch('braingeneers.data.datasets_electrophysiology.load_data')
+    def test_caching_mechanism(self, mock_load_data):
+        """
+        Test that data is properly cached and retrieved on subsequent calls with the same parameters.
+        """
+        mock_load_data.return_value = 'mock_data'
+        metadata = {'uuid': 'test_uuid'}
+
+        # First call should invoke load_data
+        first_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0)
+        mock_load_data.assert_called_once()
+
+        # Second call should retrieve data from cache and not invoke load_data again
+        second_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0)
+        self.assertEqual(first_call_data, second_call_data)
+        mock_load_data.assert_called_once()  # Still called only once
+
+    @patch('braingeneers.data.datasets_electrophysiology.load_data')
+    def test_cache_eviction_when_full(self, mock_load_data):
+        """
+        Test that the oldest items are evicted from the cache when it exceeds its size limit.
+        """
+        mock_load_data.side_effect = lambda **kwargs: f"data_{kwargs['experiment']}"
+        max_size_gb = 0.000001  # Set a very small cache size to test eviction
+
+        # Populate the cache with enough data to exceed its size limit
+        for i in range(10):
+            cached_load_data(self.cache_dir, max_size_gb=max_size_gb, metadata={'uuid': 'test_uuid'}, experiment=i)
+
+        cache = diskcache.Cache(self.cache_dir)
+        self.assertLess(len(cache), 10)  # Ensure some items were evicted
+
+    @patch('braingeneers.data.datasets_electrophysiology.load_data')
+    def test_arguments_passed_to_load_data(self, mock_load_data):
+        """
+        Test that all arguments after cache_path are correctly passed to the underlying load_data function.
+        """
+        # Mock load_data to return a serializable object, e.g., a numpy array
+        mock_load_data.return_value = np.array([1, 2, 3])
+
+        kwargs = {'metadata': {'uuid': 'test_uuid'}, 'experiment': 0, 'offset': 0, 'length': 1000}
+        cached_load_data(self.cache_dir, **kwargs)
+        mock_load_data.assert_called_with(**kwargs)
+
+    @patch('braingeneers.data.datasets_electrophysiology.load_data')
+    def test_multiprocessing_thread_safety(self, mock_load_data):
+        """
+        Test that the caching mechanism is multiprocessing/thread-safe.
+        """
+        # Mock load_data to return a serializable object, e.g., a numpy array
+        mock_load_data.return_value = np.array([1, 2, 3])
+
+        def thread_function(cache_path, metadata, experiment):
+            # This function uses the mocked load_data indirectly via cached_load_data
+            cached_load_data(cache_path, metadata=metadata, experiment=experiment)
+
+        metadata = {'uuid': 'test_uuid'}
+        threads = []
+        for i in range(10):
+            t = threading.Thread(target=thread_function, args=(self.cache_dir, metadata, i))
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        # If the cache is thread-safe, this operation should complete without error
+        # This assertion is basic and assumes the test's success implies thread safety
+        self.assertTrue(True)
+

 if __name__ == '__main__':
     unittest.main()
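The cached_load_data docstring notes that it can be parallelized with map2 from braingeneers.utils.common_utils. Since map2's signature does not appear in this diff, here is a hedged sketch of the same pattern using only the standard library (the UUID and cache path are the example values from the docstring):

    # Parallel cached loads. The docstring states cached_load_data is
    # multiprocessing/thread safe, so concurrent calls that share one
    # cache_path should be safe; ThreadPoolExecutor stands in for map2 here.
    from concurrent.futures import ThreadPoolExecutor
    from braingeneers.data.datasets_electrophysiology import load_metadata, cached_load_data

    metadata = load_metadata('9999-00-00-e-test')

    def load_chunk(offset):
        return cached_load_data(cache_path='/tmp/cache-dir', metadata=metadata,
                                experiment=0, offset=offset, length=1000)

    with ThreadPoolExecutor(max_workers=4) as pool:
        chunks = list(pool.map(load_chunk, range(0, 10000, 1000)))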