Skip to content

Commit

Permalink
Added datasets_electrophysiology.cached_load_data function which wrap…
Browse files Browse the repository at this point in the history
…s load_data and provides a local disk cache.
  • Loading branch information
davidparks21 committed Mar 6, 2024
1 parent fedc18f commit fbf73f8
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 13 deletions.
46 changes: 37 additions & 9 deletions src/braingeneers/data/datasets_electrophysiology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import warnings
import copy
import diskcache

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -60,15 +61,10 @@ def list_uuids():

def save_metadata(metadata: dict):
"""
Saves a metadata file back to S3. This is not multi-writer safe, you can use a lock as shown in the example:
from braingeneers.iot.messaging import MessageBroker()
import braingeneers.data.datasets_electrophysiology as de
with MessageBroker().get_lock('a-unique-lock-name-for-your-process'):
metadata = de.load_metadata(uuid)
metadata = do_something_to(metadata)
de.save_metadata(metadata)
Saves a metadata file back to S3. This is not multi-writer safe, you can use:
braingeneers.utils.common_utils.checkout
braingeneers.utils.common_utils.checkin
to lock the file while you are writing to it.
:param metadata: the metadata dictionary as obtained from load_metadata(uuid)
"""
Expand All @@ -83,6 +79,38 @@ def save_metadata(metadata: dict):
f.write(json.dumps(metadata, indent=2))


def cached_load_data(cache_path: str, max_size_gb: int = 10, **kwargs):
"""
Wraps a call to load_data with a diskcache at path `cache_path`.
This is multiprocessing/thread safe.
All arguments after the cache_path are passed to load_data (see load_data docs)
You must specify the load_data argument names to avoid ambiguity with the cached_load_data parameters.
When reading data from S3 (or even a compressed local file), this can provide a significant speedup by
storing the results of load_data in a local (uncompressed) cache.
Example usage:
from braingeneers.data.datasets_electrophysiology import load_metadata, cached_load_data
metadata = load_metadata('9999-00-00-e-test')
data = cached_load_data(cache_path='/tmp/cache-dir', metadata=metadata, experiment=0, offset=0, length=1000)
Note: this can safely be used with `map2` from `braingeneers.utils.common_utils` to parallelize calls to load_data.
:param cache_path: str, path to the cache directory.
:param max_size_gb: int, maximum size of the cache in GB (10 GB default). If the cache exceeds this size, the oldest items will be removed.
:param kwargs: keyword arguments to pass to load_data, see load_data documentation.
"""
cache = diskcache.Cache(cache_path, size_limit=10 ** 9 * max_size_gb)
key = json.dumps(kwargs)
if key in cache:
return cache[key]
else:
data = load_data(**kwargs)
cache[key] = data
return data


def load_metadata(batch_uuid: str) -> dict:
"""
Loads the batch UUID metadata.
Expand Down
94 changes: 90 additions & 4 deletions src/braingeneers/data/datasets_electrophysiology_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import unittest
import tempfile
import shutil
import diskcache
import json
import threading
import braingeneers
import braingeneers.data.datasets_electrophysiology as ephys
import json
from braingeneers import skip_unittest_if_offline
# import braingeneers.utils.smart_open_braingeneers as smart_open
import smart_open
import braingeneers.utils.smart_open_braingeneers as smart_open
import boto3
import numpy as np

from unittest.mock import patch
from braingeneers.data.datasets_electrophysiology import cached_load_data
from unittest.mock import patch


class MaxwellReaderTests(unittest.TestCase):

@skip_unittest_if_offline
Expand Down Expand Up @@ -414,5 +419,86 @@ def test_online_load_data_hengenlab_float32(self):
self.assertEqual(np.float32, data.dtype)


class TestCachedLoadData(unittest.TestCase):

def setUp(self):
# Create a temporary directory for the cache
self.cache_dir = tempfile.mkdtemp(prefix='test_cache_')

def tearDown(self):
# Remove the temporary directory after the test
shutil.rmtree(self.cache_dir)

@patch('braingeneers.data.datasets_electrophysiology.load_data')
def test_caching_mechanism(self, mock_load_data):
"""
Test that data is properly cached and retrieved on subsequent calls with the same parameters.
"""
mock_load_data.return_value = 'mock_data'
metadata = {'uuid': 'test_uuid'}

# First call should invoke load_data
first_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0)
mock_load_data.assert_called_once()

# Second call should retrieve data from cache and not invoke load_data again
second_call_data = cached_load_data(self.cache_dir, metadata=metadata, experiment=0)
self.assertEqual(first_call_data, second_call_data)
mock_load_data.assert_called_once() # Still called only once

@patch('braingeneers.data.datasets_electrophysiology.load_data')
def test_cache_eviction_when_full(self, mock_load_data):
"""
Test that the oldest items are evicted from the cache when it exceeds its size limit.
"""
mock_load_data.side_effect = lambda **kwargs: f"data_{kwargs['experiment']}"
max_size_gb = 0.000001 # Set a very small cache size to test eviction

# Populate the cache with enough data to exceed its size limit
for i in range(10):
cached_load_data(self.cache_dir, max_size_gb=max_size_gb, metadata={'uuid': 'test_uuid'}, experiment=i)

cache = diskcache.Cache(self.cache_dir)
self.assertLess(len(cache), 10) # Ensure some items were evicted

@patch('braingeneers.data.datasets_electrophysiology.load_data')
def test_arguments_passed_to_load_data(self, mock_load_data):
"""
Test that all arguments after cache_path are correctly passed to the underlying load_data function.
"""
# Mock load_data to return a serializable object, e.g., a numpy array
mock_load_data.return_value = np.array([1, 2, 3])

kwargs = {'metadata': {'uuid': 'test_uuid'}, 'experiment': 0, 'offset': 0, 'length': 1000}
cached_load_data(self.cache_dir, **kwargs)
mock_load_data.assert_called_with(**kwargs)

@patch('braingeneers.data.datasets_electrophysiology.load_data')
def test_multiprocessing_thread_safety(self, mock_load_data):
"""
Test that the caching mechanism is multiprocessing/thread-safe.
"""
# Mock load_data to return a serializable object, e.g., a numpy array
mock_load_data.return_value = np.array([1, 2, 3])

def thread_function(cache_path, metadata, experiment):
# This function uses the mocked load_data indirectly via cached_load_data
cached_load_data(cache_path, metadata=metadata, experiment=experiment)

metadata = {'uuid': 'test_uuid'}
threads = []
for i in range(10):
t = threading.Thread(target=thread_function, args=(self.cache_dir, metadata, i))
threads.append(t)
t.start()

for t in threads:
t.join()

# If the cache is thread-safe, this operation should complete without error
# This assertion is basic and assumes the test's success implies thread safety
self.assertTrue(True)


if __name__ == '__main__':
unittest.main()

0 comments on commit fbf73f8

Please sign in to comment.