From d5ec52a09e49ca170d12b6416236673b24506ff4 Mon Sep 17 00:00:00 2001 From: David Parks Date: Tue, 12 Mar 2024 11:53:04 -0700 Subject: [PATCH] Add checkout/checkin functionality to common_utils to lock access to S3 files (#68) Created checkout/in to support context manager for atomic access to S3 files. Unit tests created and tested. --- src/braingeneers/utils/common_utils.py | 147 +++++++++++--------- src/braingeneers/utils/common_utils_test.py | 110 +++++++-------- 2 files changed, 135 insertions(+), 122 deletions(-) diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index ddd9124..c0a4031 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -7,14 +7,10 @@ import braingeneers import braingeneers.utils.smart_open_braingeneers as smart_open from typing import Callable, Iterable, Union, List, Tuple, Dict, Any -import functools import inspect import multiprocessing import posixpath -import itertools import pathlib -import json -import hashlib _s3_client = None # S3 client for boto3, lazy initialization performed in _lazy_init_s3_client() _message_broker = None # Lazy initialization of the message broker @@ -204,83 +200,109 @@ def f(x, y): return list(result_iterator) -def checkout(s3_file: str, mode: str = 'r') -> io.IOBase: +class checkout: """ - Check out a file from S3 for reading or writing, use checkin to release the file. - Any subsequent calls to checkout will block until the file is returned with checkin(s3_file). + A context manager for atomically checking out a file from S3 for reading or writing. Example usage: - f = checkout('s3://braingeneersdev/test/test_file.bin', mode='rb') - new_bytes = do_something(f.read()) - checkin('s3://braingeneersdev/test/test_file.bin', new_bytes) - Example usage to update metadata: - f = checkout('s3://braingeneersdev/test/metadata.json') - metadata_dict = json.loads(f.read()) + # Read-then-update metadata.json (or any text based file on S3) + with checkout('s3://braingeneers/ephys/9999-0-0-e-test/metadata.json', isbinary=False) as locked_obj: + metadata_dict = json.loads(locked_obj.get_value()) metadata_dict['new_key'] = 'new_value' metadata_updated_str = json.dumps(metadata_dict, indent=2) - checkin('s3://braingeneersdev/test/metadata.json', updated_metadata_str) - - :param s3_file: The S3 file path to check out. - :param mode: The mode to open the file in, 'r' (text mode) or 'rb' (binary mode), analogous to system open(filename, mode) - """ - # Avoid circular import - from braingeneers.iot.messaging import MessageBroker - - assert mode in ('r', 'rb'), 'Use "r" (text) or "rb" (binary) mode only. File changes are applied at checkout(...)' - - global _message_broker, _named_locks - if _message_broker is None: - print('creating message broker') - _message_broker = MessageBroker() - mb = _message_broker - - lock_str = f'common-utils-checkout-{s3_file}' - named_lock = mb.get_lock(lock_str) - named_lock.acquire() - _named_locks[s3_file] = named_lock - f = smart_open.open(s3_file, mode) - return f - - -def checkin(s3_file: str, file: Union[str, bytes, io.IOBase]): - """ - Releases a file that was checked out with checkout. - - :param s3_file: The S3 file path, must match checkout. - :param file: The string, bytes, or file object to write back to S3. + locked_obj.checkin(metadata_updated_str) + + # Read-then-update data.npy (or any binary file on S3) + with checkout('s3://braingeneersdev/test/data.npy', isbinary=True) as locked_obj: + file_obj = locked_obj.get_file() + ndarray = np.load(file_obj) + ndarray[3, 3] = 42 + locked_obj.checkin(ndarray.tobytes()) + + # Edit a file in place, note checkin is not needed, the file is updated when the context manager exits + with checkout('s3://braingeneersdev/test/test_file.bin', isbinary=True) as locked_obj: + with zipfile.ZipFile(locked_obj.get_file(), 'a') as z: + z.writestr('new_file.txt', 'new file contents') + + locked_obj functions: + get_value() # returns a string or bytes object (depending on isbinary) + get_file() # returns a file-like object akin to open() + checkin() # updates the file, accepts string, bytes, or file like objects """ - assert isinstance(file, (str, bytes, io.IOBase)), 'file must be a string, bytes, or file object.' - - with smart_open.open(s3_file, 'wb') as f: - if isinstance(file, str): - f.write(file.encode()) - elif isinstance(file, bytes): - f.write(file) - else: - file.seek(0) - data = file.read() - f.write(data if isinstance(data, bytes) else data.encode()) - - global _named_locks - named_lock = _named_locks[s3_file] - named_lock.release() + class LockedObject: + def __init__(self, s3_file_object: io.IOBase, s3_path_str: str, isbinary: bool): + self.s3_path_str = s3_path_str + self.s3_file_object = s3_file_object # underlying file object + self.isbinary = isbinary # binary or text mode + self.modified = False # Track if the file has been modified + + def get_value(self): + # Read file object from outer class s3_file_object + self.s3_file_object.seek(0) + return self.s3_file_object.read() + + def get_file(self): + # Mark file as potentially modified when accessed + self.modified = True + # Return file object from outer class s3_file_object + self.s3_file_object.seek(0) + return self.s3_file_object + + def checkin(self, update_file: Union[str, bytes, io.IOBase]): + # Validate input + if not isinstance(update_file, (str, bytes, io.IOBase)): + raise TypeError('File must be a string, bytes, or file object.') + if isinstance(update_file, str) or isinstance(update_file, io.StringIO): + if self.isbinary: + raise ValueError( + 'Cannot check in a string or text file when checkout is specified for binary mode.') + if isinstance(update_file, bytes) or isinstance(update_file, io.BytesIO): + if not self.isbinary: + raise ValueError('Cannot check in bytes or a binary file when checkout is specified for text mode.') + + mode = 'w' if not self.isbinary else 'wb' + with smart_open.open(self.s3_path_str, mode=mode) as f: + f.write(update_file if not isinstance(update_file, io.IOBase) else update_file.read()) + + def __init__(self, s3_path_str: str, isbinary: bool = False): + # TODO: avoid circular import + from braingeneers.iot.messaging import MessageBroker + + self.s3_path_str = s3_path_str + self.isbinary = isbinary + self.mb = MessageBroker() + self.named_lock = None # message broker lock + self.locked_obj = None # user facing locked object + + def __enter__(self): + lock_str = f'common-utils-checkout-{self.s3_path_str}' + named_lock = self.mb.get_lock(lock_str) + named_lock.acquire() + self.named_lock = named_lock + f = smart_open.open(self.s3_path_str, 'rb' if self.isbinary else 'r') + self.locked_obj = checkout.LockedObject(f, self.s3_path_str, self.isbinary) + return self.locked_obj + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.locked_obj.modified: + # If the file was modified, automatically check in the changes + self.locked_obj.checkin(self.locked_obj.get_file()) + self.named_lock.release() def force_release_checkout(s3_file: str): """ Force release the lock on a file that was checked out with checkout. """ - # Avoid circular import + # TODO: avoid circular import from braingeneers.iot.messaging import MessageBroker global _message_broker if _message_broker is None: _message_broker = MessageBroker() - mb = _message_broker - lock_str = f'common-utils-checkout-{s3_file}' - mb.delete_lock(lock_str) + _message_broker.delete_lock(f'common-utils-checkout-{s3_file}') def pretty_print(data, n=10, indent=0): @@ -342,4 +364,3 @@ def pretty_print(data, n=10, indent=0): if len(data) > n: print(f"{indent_space} ... (+{len(data) - n} more items)") print(f"{indent_space}]", end='') - diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py index ae812b9..be1cc90 100644 --- a/src/braingeneers/utils/common_utils_test.py +++ b/src/braingeneers/utils/common_utils_test.py @@ -1,15 +1,13 @@ +import io import unittest from unittest.mock import patch, MagicMock -from common_utils import checkout, checkin, force_release_checkout, map2 -from braingeneers.iot import messaging import common_utils +from common_utils import checkout, force_release_checkout +from braingeneers.iot import messaging import os import tempfile import braingeneers.utils.smart_open_braingeneers as smart_open - - -def multiply(x, y): - return x * y +from typing import Union class TestFileListFunction(unittest.TestCase): @@ -52,59 +50,53 @@ def test_local_no_files(self): self.assertEqual(result, []) -class TestCheckingCheckout(unittest.TestCase): - def setUp(self) -> None: - self.text_value = 'unittest1' - self.filepath = 's3://braingeneersdev/unittest/test.txt' - force_release_checkout(self.filepath) - - with smart_open.open(self.filepath, 'w') as f: - f.write(self.text_value) - - def test_checkout_checkin(self): - f = checkout(self.filepath) - self.assertEqual(f.read(), self.text_value) - checkin(self.filepath, f) - - -class TestMap2(unittest.TestCase): - def test_basic_functionality(self): - """Test map2 with a simple function, no fixed values, no parallelism.""" - - def simple_add(x, y): - return x + y - - args = [(1, 2), (2, 3), (3, 4)] - expected = [3, 5, 7] - result = map2(simple_add, args=args, parallelism=False) - self.assertEqual(result, expected) - - def test_with_fixed_values(self): - """Test map2 with fixed values.""" - - def f(a, b, c): - return f'{a} {b} {c}' - - args = [2, 20, 200] - expected = ['1 2 3', '1 20 3', '1 200 3'] - result = map2(func=f, args=args, fixed_values=dict(a=1, c=3), parallelism=False) - self.assertEqual(result, expected) - - def test_with_parallelism(self): - """Test map2 with parallelism enabled (assuming the environment supports it).""" - args = [(1, 2), (2, 3), (3, 4)] - expected = [2, 6, 12] - result = map2(multiply, args=args, parallelism=True) - self.assertEqual(result, expected) - - def test_with_invalid_args(self): - """Test map2 with invalid args to ensure it raises the correct exceptions.""" - - def simple_subtract(x, y): - return x - y - - with self.assertRaises(AssertionError): - map2(simple_subtract, args=[1], parallelism="invalid") +class TestCheckout(unittest.TestCase): + + def setUp(self): + # Setup mock for smart_open and MessageBroker + self.message_broker_patch = patch('braingeneers.iot.messaging.MessageBroker') + + # Start the patches + self.mock_message_broker = self.message_broker_patch.start() + + # Mock the message broker's get_lock and delete_lock methods + self.mock_message_broker.return_value.get_lock.return_value = MagicMock() + self.mock_message_broker.return_value.delete_lock = MagicMock() + + self.mock_file = MagicMock(spec=io.StringIO) + self.mock_file.read.return_value = 'Test data' # Ensure this is correctly setting the return value for read + self.mock_file.__enter__.return_value = self.mock_file + self.mock_file.__exit__.return_value = None + self.smart_open_mock = MagicMock(spec=smart_open) + self.smart_open_mock.open.return_value = self.mock_file + + common_utils.smart_open = self.smart_open_mock + + def tearDown(self): + # Stop all patches + self.message_broker_patch.stop() + + def test_checkout_context_manager_read(self): + # Test the reading functionality + with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj: + data = locked_obj.get_value() + self.assertEqual(data, 'Test data') + + def test_checkout_context_manager_write_text(self): + # Test the writing functionality for text mode + test_data = 'New test data' + self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test + with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj: + locked_obj.checkin(test_data) + self.mock_file.write.assert_called_once_with(test_data) + + def test_checkout_context_manager_write_binary(self): + # Test the writing functionality for binary mode + test_data = b'New binary data' + self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test + with checkout('s3://test-bucket/test-file.bin', isbinary=True) as locked_obj: + locked_obj.checkin(test_data) + self.mock_file.write.assert_called_once_with(test_data) if __name__ == '__main__':