
Commit

Add checkout/checkin functionality to common_utils to lock access to S3 files (#68)

Created checkout/checkin to support a context manager for atomic access to S3 files. Unit tests created and run.
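
A minimal usage sketch of the new context manager, following the docstrings in the diff below; the S3 path and the JSON key added here are illustrative only:

import json
from braingeneers.utils.common_utils import checkout

# Read-modify-write a JSON file on S3 while holding the named lock.
# The file is uploaded when checkin() is called; the lock is released
# when the context manager exits.
with checkout('s3://braingeneersdev/test/metadata.json', isbinary=False) as locked_obj:
    metadata = json.loads(locked_obj.get_value())
    metadata['new_key'] = 'new_value'  # illustrative update
    locked_obj.checkin(json.dumps(metadata, indent=2))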
davidparks21 authored Mar 12, 2024
1 parent ef35293 commit d5ec52a
Showing 2 changed files with 135 additions and 122 deletions.
147 changes: 84 additions & 63 deletions src/braingeneers/utils/common_utils.py
@@ -7,14 +7,10 @@
import braingeneers
import braingeneers.utils.smart_open_braingeneers as smart_open
from typing import Callable, Iterable, Union, List, Tuple, Dict, Any
import functools
import inspect
import multiprocessing
import posixpath
import itertools
import pathlib
import json
import hashlib

_s3_client = None # S3 client for boto3, lazy initialization performed in _lazy_init_s3_client()
_message_broker = None # Lazy initialization of the message broker
@@ -204,83 +200,109 @@ def f(x, y):
    return list(result_iterator)


def checkout(s3_file: str, mode: str = 'r') -> io.IOBase:
class checkout:
"""
Check out a file from S3 for reading or writing, use checkin to release the file.
Any subsequent calls to checkout will block until the file is returned with checkin(s3_file).
A context manager for atomically checking out a file from S3 for reading or writing.
Example usage:
f = checkout('s3://braingeneersdev/test/test_file.bin', mode='rb')
new_bytes = do_something(f.read())
checkin('s3://braingeneersdev/test/test_file.bin', new_bytes)
Example usage to update metadata:
f = checkout('s3://braingeneersdev/test/metadata.json')
metadata_dict = json.loads(f.read())
# Read-then-update metadata.json (or any text based file on S3)
with checkout('s3://braingeneers/ephys/9999-0-0-e-test/metadata.json', isbinary=False) as locked_obj:
metadata_dict = json.loads(locked_obj.get_value())
metadata_dict['new_key'] = 'new_value'
metadata_updated_str = json.dumps(metadata_dict, indent=2)
checkin('s3://braingeneersdev/test/metadata.json', updated_metadata_str)
:param s3_file: The S3 file path to check out.
:param mode: The mode to open the file in, 'r' (text mode) or 'rb' (binary mode), analogous to system open(filename, mode)
"""
    # Avoid circular import
    from braingeneers.iot.messaging import MessageBroker

    assert mode in ('r', 'rb'), 'Use "r" (text) or "rb" (binary) mode only. File changes are applied at checkout(...)'

    global _message_broker, _named_locks
    if _message_broker is None:
        print('creating message broker')
        _message_broker = MessageBroker()
    mb = _message_broker

    lock_str = f'common-utils-checkout-{s3_file}'
    named_lock = mb.get_lock(lock_str)
    named_lock.acquire()
    _named_locks[s3_file] = named_lock
    f = smart_open.open(s3_file, mode)
    return f


def checkin(s3_file: str, file: Union[str, bytes, io.IOBase]):
    """
    Releases a file that was checked out with checkout.
    :param s3_file: The S3 file path, must match checkout.
    :param file: The string, bytes, or file object to write back to S3.
            locked_obj.checkin(metadata_updated_str)
        # Read-then-update data.npy (or any binary file on S3)
        with checkout('s3://braingeneersdev/test/data.npy', isbinary=True) as locked_obj:
            file_obj = locked_obj.get_file()
            ndarray = np.load(file_obj)
            ndarray[3, 3] = 42
            locked_obj.checkin(ndarray.tobytes())
        # Edit a file in place, note checkin is not needed, the file is updated when the context manager exits
        with checkout('s3://braingeneersdev/test/test_file.bin', isbinary=True) as locked_obj:
            with zipfile.ZipFile(locked_obj.get_file(), 'a') as z:
                z.writestr('new_file.txt', 'new file contents')
    locked_obj functions:
        get_value() # returns a string or bytes object (depending on isbinary)
        get_file() # returns a file-like object akin to open()
        checkin() # updates the file, accepts string, bytes, or file like objects
    """
    assert isinstance(file, (str, bytes, io.IOBase)), 'file must be a string, bytes, or file object.'

    with smart_open.open(s3_file, 'wb') as f:
        if isinstance(file, str):
            f.write(file.encode())
        elif isinstance(file, bytes):
            f.write(file)
        else:
            file.seek(0)
            data = file.read()
            f.write(data if isinstance(data, bytes) else data.encode())

    global _named_locks
    named_lock = _named_locks[s3_file]
    named_lock.release()
    class LockedObject:
        def __init__(self, s3_file_object: io.IOBase, s3_path_str: str, isbinary: bool):
            self.s3_path_str = s3_path_str
            self.s3_file_object = s3_file_object  # underlying file object
            self.isbinary = isbinary  # binary or text mode
            self.modified = False  # Track if the file has been modified

        def get_value(self):
            # Read file object from outer class s3_file_object
            self.s3_file_object.seek(0)
            return self.s3_file_object.read()

        def get_file(self):
            # Mark file as potentially modified when accessed
            self.modified = True
            # Return file object from outer class s3_file_object
            self.s3_file_object.seek(0)
            return self.s3_file_object

        def checkin(self, update_file: Union[str, bytes, io.IOBase]):
            # Validate input
            if not isinstance(update_file, (str, bytes, io.IOBase)):
                raise TypeError('File must be a string, bytes, or file object.')
            if isinstance(update_file, str) or isinstance(update_file, io.StringIO):
                if self.isbinary:
                    raise ValueError(
                        'Cannot check in a string or text file when checkout is specified for binary mode.')
            if isinstance(update_file, bytes) or isinstance(update_file, io.BytesIO):
                if not self.isbinary:
                    raise ValueError('Cannot check in bytes or a binary file when checkout is specified for text mode.')

            mode = 'w' if not self.isbinary else 'wb'
            with smart_open.open(self.s3_path_str, mode=mode) as f:
                f.write(update_file if not isinstance(update_file, io.IOBase) else update_file.read())

    def __init__(self, s3_path_str: str, isbinary: bool = False):
        # TODO: avoid circular import
        from braingeneers.iot.messaging import MessageBroker

        self.s3_path_str = s3_path_str
        self.isbinary = isbinary
        self.mb = MessageBroker()
        self.named_lock = None  # message broker lock
        self.locked_obj = None  # user facing locked object

    def __enter__(self):
        lock_str = f'common-utils-checkout-{self.s3_path_str}'
        named_lock = self.mb.get_lock(lock_str)
        named_lock.acquire()
        self.named_lock = named_lock
        f = smart_open.open(self.s3_path_str, 'rb' if self.isbinary else 'r')
        self.locked_obj = checkout.LockedObject(f, self.s3_path_str, self.isbinary)
        return self.locked_obj

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.locked_obj.modified:
            # If the file was modified, automatically check in the changes
            self.locked_obj.checkin(self.locked_obj.get_file())
        self.named_lock.release()


def force_release_checkout(s3_file: str):
    """
    Force release the lock on a file that was checked out with checkout.
    """
    # Avoid circular import
    # TODO: avoid circular import
    from braingeneers.iot.messaging import MessageBroker

    global _message_broker
    if _message_broker is None:
        _message_broker = MessageBroker()
    mb = _message_broker

    lock_str = f'common-utils-checkout-{s3_file}'
    mb.delete_lock(lock_str)
    _message_broker.delete_lock(f'common-utils-checkout-{s3_file}')


def pretty_print(data, n=10, indent=0):
@@ -342,4 +364,3 @@ def pretty_print(data, n=10, indent=0):
        if len(data) > n:
            print(f"{indent_space} ... (+{len(data) - n} more items)")
        print(f"{indent_space}]", end='')

110 changes: 51 additions & 59 deletions src/braingeneers/utils/common_utils_test.py
@@ -1,15 +1,13 @@
import io
import unittest
from unittest.mock import patch, MagicMock
from common_utils import checkout, checkin, force_release_checkout, map2
from braingeneers.iot import messaging
import common_utils
from common_utils import checkout, force_release_checkout
from braingeneers.iot import messaging
import os
import tempfile
import braingeneers.utils.smart_open_braingeneers as smart_open


def multiply(x, y):
    return x * y
from typing import Union


class TestFileListFunction(unittest.TestCase):
@@ -52,59 +50,53 @@ def test_local_no_files(self):
        self.assertEqual(result, [])


class TestCheckingCheckout(unittest.TestCase):
    def setUp(self) -> None:
        self.text_value = 'unittest1'
        self.filepath = 's3://braingeneersdev/unittest/test.txt'
        force_release_checkout(self.filepath)

        with smart_open.open(self.filepath, 'w') as f:
            f.write(self.text_value)

    def test_checkout_checkin(self):
        f = checkout(self.filepath)
        self.assertEqual(f.read(), self.text_value)
        checkin(self.filepath, f)


class TestMap2(unittest.TestCase):
    def test_basic_functionality(self):
        """Test map2 with a simple function, no fixed values, no parallelism."""

        def simple_add(x, y):
            return x + y

        args = [(1, 2), (2, 3), (3, 4)]
        expected = [3, 5, 7]
        result = map2(simple_add, args=args, parallelism=False)
        self.assertEqual(result, expected)

    def test_with_fixed_values(self):
        """Test map2 with fixed values."""

        def f(a, b, c):
            return f'{a} {b} {c}'

        args = [2, 20, 200]
        expected = ['1 2 3', '1 20 3', '1 200 3']
        result = map2(func=f, args=args, fixed_values=dict(a=1, c=3), parallelism=False)
        self.assertEqual(result, expected)

    def test_with_parallelism(self):
        """Test map2 with parallelism enabled (assuming the environment supports it)."""
        args = [(1, 2), (2, 3), (3, 4)]
        expected = [2, 6, 12]
        result = map2(multiply, args=args, parallelism=True)
        self.assertEqual(result, expected)

    def test_with_invalid_args(self):
        """Test map2 with invalid args to ensure it raises the correct exceptions."""

        def simple_subtract(x, y):
            return x - y

        with self.assertRaises(AssertionError):
            map2(simple_subtract, args=[1], parallelism="invalid")
class TestCheckout(unittest.TestCase):

    def setUp(self):
        # Set up mocks for smart_open and MessageBroker
        self.message_broker_patch = patch('braingeneers.iot.messaging.MessageBroker')

        # Start the patches
        self.mock_message_broker = self.message_broker_patch.start()

        # Mock the message broker's get_lock and delete_lock methods
        self.mock_message_broker.return_value.get_lock.return_value = MagicMock()
        self.mock_message_broker.return_value.delete_lock = MagicMock()

        self.mock_file = MagicMock(spec=io.StringIO)
        self.mock_file.read.return_value = 'Test data'  # return value for read()
        self.mock_file.__enter__.return_value = self.mock_file
        self.mock_file.__exit__.return_value = None
        self.smart_open_mock = MagicMock(spec=smart_open)
        self.smart_open_mock.open.return_value = self.mock_file

        common_utils.smart_open = self.smart_open_mock

    def tearDown(self):
        # Stop all patches
        self.message_broker_patch.stop()

    def test_checkout_context_manager_read(self):
        # Test the reading functionality
        with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj:
            data = locked_obj.get_value()
            self.assertEqual(data, 'Test data')

    def test_checkout_context_manager_write_text(self):
        # Test the writing functionality for text mode
        test_data = 'New test data'
        self.mock_file.write.reset_mock()  # Reset mock to ensure clean state for the test
        with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj:
            locked_obj.checkin(test_data)
        self.mock_file.write.assert_called_once_with(test_data)

    def test_checkout_context_manager_write_binary(self):
        # Test the writing functionality for binary mode
        test_data = b'New binary data'
        self.mock_file.write.reset_mock()  # Reset mock to ensure clean state for the test
        with checkout('s3://test-bucket/test-file.bin', isbinary=True) as locked_obj:
            locked_obj.checkin(test_data)
        self.mock_file.write.assert_called_once_with(test_data)


if __name__ == '__main__':
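
The setUp above swaps in mocks for both MessageBroker and the module-level smart_open handle, so these tests never touch S3 or the broker. As a sketch of the same mocking pattern, a hypothetical extra test method for TestCheckout (not part of this commit) could also assert the open mode and lock name that checkout() constructs internally:

    def test_checkout_opens_expected_path(self):
        # Hypothetical check: a text-mode checkout should open the path with 'r'
        # and acquire the named lock 'common-utils-checkout-<s3_path>'.
        with checkout('s3://test-bucket/test-file.txt', isbinary=False):
            pass
        self.smart_open_mock.open.assert_called_once_with('s3://test-bucket/test-file.txt', 'r')
        self.mock_message_broker.return_value.get_lock.assert_called_once_with(
            'common-utils-checkout-s3://test-bucket/test-file.txt')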
