From 760215acd2aa8ada4f1bda3683c0b77de7c89c64 Mon Sep 17 00:00:00 2001 From: Lon Blauvelt Date: Wed, 28 Feb 2024 13:47:18 -0800 Subject: [PATCH 1/5] Update MQTT client. (#71) --- pyproject.toml | 2 +- src/braingeneers/iot/messaging.py | 17 ++++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3bbaf21..9edf7b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ 'matplotlib', 'nptyping', 'numpy', - 'paho-mqtt', + 'paho-mqtt>=2', 'pandas', 'powerlaw', 'redis', diff --git a/src/braingeneers/iot/messaging.py b/src/braingeneers/iot/messaging.py index 7742c80..ce16e90 100644 --- a/src/braingeneers/iot/messaging.py +++ b/src/braingeneers/iot/messaging.py @@ -1,27 +1,22 @@ """ A simplified MQTT client for Braingeneers specific connections """ - import redis -import tempfile -import functools -import json -import inspect import logging import os import re -import time import io import configparser import threading import queue import uuid -from typing import Callable, Tuple, List, Dict, Union import random import json import braingeneers.iot.shadows as sh -from paho.mqtt import client as mqtt_client -from deprecated import deprecated import pickle -from tenacity import retry, wait_exponential, after_log + +from typing import Callable, Tuple, List, Dict, Union +from deprecated import deprecated +from paho.mqtt import client as mqtt_client +from paho.mqtt.enums import CallbackAPIVersion AWS_REGION = 'us-west-2' @@ -767,7 +762,7 @@ def on_log(client, userdata, level, buf): self.logger.debug("MQTT log: %s", buf) client_id = f'braingeneerspy-{random.randint(0, 1000)}' - self._mqtt_connection = mqtt_client.Client(client_id) + self._mqtt_connection = mqtt_client.Client(CallbackAPIVersion.VERSION1, client_id) self._mqtt_connection.username_pw_set(self._mqtt_profile_id, self._mqtt_profile_key) self._mqtt_connection.on_connect = on_connect self._mqtt_connection.on_log = on_log From 997ea9debfe8ede24d49a6f0c439d0f0f9302331 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:52:30 -0800 Subject: [PATCH 2/5] build(deps): bump codecov/codecov-action from 4.0.1 to 4.1.0 (#70) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 4.0.1 to 4.1.0. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v4.0.1...v4.1.0) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12a2944..f903be2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,4 +47,4 @@ jobs: --durations=20 - name: Upload coverage report - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v4.1.0 From 6ae06d4e5c5492e53b8174b50c6517bab1bfeadf Mon Sep 17 00:00:00 2001 From: Lon Blauvelt Date: Wed, 28 Feb 2024 13:54:37 -0800 Subject: [PATCH 3/5] Remove unused Class IndexedList. (#60) --- .../data/datasets_electrophysiology.py | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/src/braingeneers/data/datasets_electrophysiology.py b/src/braingeneers/data/datasets_electrophysiology.py index 9986b9b..04ccd34 100644 --- a/src/braingeneers/data/datasets_electrophysiology.py +++ b/src/braingeneers/data/datasets_electrophysiology.py @@ -1303,47 +1303,6 @@ def _axion_get_data(file_name, file_data_start_position, return final_raw_data_reshaped -# class IndexedList(list): -# """ -# A variant of OrderedDict indexable by index (int) or name (str). -# This class forces ints to represent index by location, else index by name/object. -# Example usages: -# metadata['ephys_experiments']['experiment0'] # index by name (must use str type) -# metadata['ephys_experiments'][0] # index by location (must use int type) -# """ -# -# def __init__(self, original_list: list, key: callable): -# self.keys_ordered = [key(v) for v in original_list] -# self.dict = {key(v): v for v in original_list} -# super().__init__() -# -# def __getitem__(self, key): -# print(key) -# if isinstance(key, int): -# return self.dict[self.keys_ordered[key]] -# elif isinstance(key, str): -# return self.dict[key] -# else: -# raise KeyError(f'Key must be type int (index by location) or str (index by name), got type: {type(key)}') -# -# def __iter__(self) -> Iterator: -# def g(): -# for k in self.keys_ordered: -# yield self.dict[k] -# -# return g() -# -# def __hash__(self): -# return self.dict.__hash__() -# -# def __eq__(self, other): -# return isinstance(other, IndexedList) and self.dict.__eq__(other.dict) -# -# def __add__(self, value): -# self.keys_ordered.append(value) -# self.dict[value] = value - - def get_mearec_h5_recordings_file(batch_uuid: str): """ Returns the filepath to the MEArec .h5/.hdf5 recordings file for the given UUID. From dc02a996b1d796cdfdee1244f3e247c2bd571f7f Mon Sep 17 00:00:00 2001 From: David Parks Date: Tue, 5 Mar 2024 14:39:01 -0800 Subject: [PATCH 4/5] Fixed a feature in map2 when the fixed_values include the first parameter of a function and a later parameter is the variable. Added gpt infused unit tests. (#73) --- src/braingeneers/utils/common_utils.py | 36 +++++++++------- src/braingeneers/utils/common_utils_test.py | 47 ++++++++++++++++++++- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index a04d553..cf0f22f 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -4,7 +4,7 @@ from botocore.exceptions import ClientError import os import braingeneers -from typing import List, Tuple, Union, Callable, Iterable +from typing import Callable, Iterable, Union, List, Tuple, Dict, Any import functools import inspect import multiprocessing @@ -122,6 +122,14 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]: return files_and_details +# Define the wrapper function as a top-level function +def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any: + """Internal wrapper function for map2 to handle fixed values and dynamic arguments.""" + # Merge fixed_values with provided arguments, aligning provided args with required_params + call_args = {**fixed_values, **dict(zip(required_params, args))} + return func(**call_args) + + def map2(func: Callable, args: Iterable[Union[Tuple, object]] = None, fixed_values: dict = None, @@ -168,27 +176,25 @@ def f(x, y): :return: a list of the return values of func """ assert isinstance(fixed_values, (dict, type(None))) - assert isinstance(parallelism, int) + assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer" parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism - assert isinstance(parallelism, int) + assert isinstance(parallelism, int), "parallelism must be resolved to an integer" + + fixed_values = fixed_values or {} + func_signature = inspect.signature(func) + required_params = [p.name for p in func_signature.parameters.values() if + p.default == inspect.Parameter.empty and p.name not in fixed_values] - func_partial = functools.partial(func, **(fixed_values or {})) - n_required_params = sum([p.default == inspect.Parameter.empty for p in inspect.signature(func).parameters.values()]) - n_fixed_values = len(fixed_values or {}) args_list = list(args or []) - args_tuples = args \ - if len(args_list) > 0 \ - and isinstance(args_list[0], tuple) \ - and len(args_list[0]) >= n_required_params - n_fixed_values \ - else [(a,) for a in args_list] + args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list] if parallelism == 1: - result_iterator = itertools.starmap(func_partial, args_tuples) + result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples) else: - # noinspection PyPep8Naming - ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading is True else multiprocessing.Pool + ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool with ProcessOrThreadPool(parallelism) as pool: - result_iterator = pool.starmap(func_partial, args_tuples) + result_iterator = pool.starmap(_map2_wrapper, + [(fixed_values, required_params, func, args) for args in args_tuples]) return list(result_iterator) diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py index bc33ca4..0ce534b 100644 --- a/src/braingeneers/utils/common_utils_test.py +++ b/src/braingeneers/utils/common_utils_test.py @@ -1,10 +1,15 @@ import unittest from unittest.mock import patch, MagicMock -import common_utils # Updated import statement +import common_utils +from common_utils import map2 import os import tempfile +def multiply(x, y): + return x * y + + class TestFileListFunction(unittest.TestCase): @patch('common_utils._lazy_init_s3_client') # Updated to common_utils @@ -45,5 +50,45 @@ def test_local_no_files(self): self.assertEqual(result, []) +class TestMap2(unittest.TestCase): + def test_basic_functionality(self): + """Test map2 with a simple function, no fixed values, no parallelism.""" + + def simple_add(x, y): + return x + y + + args = [(1, 2), (2, 3), (3, 4)] + expected = [3, 5, 7] + result = map2(simple_add, args=args, parallelism=False) + self.assertEqual(result, expected) + + def test_with_fixed_values(self): + """Test map2 with fixed values.""" + + def f(a, b, c): + return f'{a} {b} {c}' + + args = [2, 20, 200] + expected = ['1 2 3', '1 20 3', '1 200 3'] + result = map2(func=f, args=args, fixed_values=dict(a=1, c=3), parallelism=False) + self.assertEqual(result, expected) + + def test_with_parallelism(self): + """Test map2 with parallelism enabled (assuming the environment supports it).""" + args = [(1, 2), (2, 3), (3, 4)] + expected = [2, 6, 12] + result = map2(multiply, args=args, parallelism=True) + self.assertEqual(result, expected) + + def test_with_invalid_args(self): + """Test map2 with invalid args to ensure it raises the correct exceptions.""" + + def simple_subtract(x, y): + return x - y + + with self.assertRaises(AssertionError): + map2(simple_subtract, args=[1], parallelism="invalid") + + if __name__ == '__main__': unittest.main() From fedc18fb0798b9e5014e2c746959adefe91e3152 Mon Sep 17 00:00:00 2001 From: David Parks Date: Tue, 5 Mar 2024 14:45:14 -0800 Subject: [PATCH 5/5] Removed hengenlab metadata reader. This functionality will be moved to a workflow. The neuraltoolkit dependency was awkward and causing other bugs. (#72) * added AtomicGetSetEphysMetadata * Corrected bug and added force_release function * Added documentation commented on by Kate * Moved AtomicGetSetEphysMetadata to common_utils * Switched to using common_utils.checkin|checkout methods * minor update * Changing smart_open import * Added assert and docs to use "r" or "rb" mode. * clean up comments * Removed hengenlab metadata reader. This functionality will be moved to a workflow. The neuraltoolkit dependency was awkward and causing other bugs. * Updated unittests * missed a hengenlab reference --------- Co-authored-by: Ash Co-authored-by: Lon Blauvelt --- pyproject.toml | 4 - .../data/datasets_electrophysiology.py | 131 ------------------ .../data/datasets_electrophysiology_test.py | 41 ------ src/braingeneers/iot/messaging.py | 1 - src/braingeneers/iot/messaging_test.py | 1 + src/braingeneers/utils/common_utils.py | 94 ++++++++++++- src/braingeneers/utils/common_utils_test.py | 19 ++- 7 files changed, 108 insertions(+), 183 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9edf7b9..7eaa5ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,16 +62,12 @@ local_scheme = "no-local-version" [project.optional-dependencies] all = [ 'braingeneers[ml]', - 'braingeneers[hengenlab]', 'braingeneers[dev]', ] ml = [ 'torch', 'scikit-learn', ] -hengenlab = [ - 'neuraltoolkit==0.3.1', # channel mapping information -] dev = [ "pytest >=6", "pytest-cov >=3", diff --git a/src/braingeneers/data/datasets_electrophysiology.py b/src/braingeneers/data/datasets_electrophysiology.py index 04ccd34..e5ffabf 100644 --- a/src/braingeneers/data/datasets_electrophysiology.py +++ b/src/braingeneers/data/datasets_electrophysiology.py @@ -28,10 +28,6 @@ import re from types import ModuleType import bisect -try: - import neuraltoolkit as ntk # optional import -except ImportError: - pass VALID_LOAD_DATA_DTYPES = [np.int16, np.float16, np.float32, np.float64] @@ -847,133 +843,6 @@ def _read_hengenlab_ecube_timestamp(filepath: str) -> int: return int(np.frombuffer(f.read(8), dtype=np.uint64)) -def generate_metadata_hengenlab(batch_uuid: str, - dataset_name: str, - experiment_name: Union[List[str], str] = 'experiment1', - fs: int = 25000, - n_threads: int = 32, - save: bool = False): - """ - Generates a metadata json and experiment1...experimentN section for a hengenlab dataset upload. - File locations in S3 for hengenlab neural data files: - s3://braingeneers/ephys/YYYY-MM-DD-e-${DATASET_NAME}/original/data/*.bin - Contiguous recording periods - :param batch_uuid: location on braingeneers storage (S3) - :param dataset_name: the dataset_name as defined in `neuraltoolkit`. Metadata will be pulled from `neuraltoolkit`. - :param experiment_name: Dataset name as stored in `neuraltoolkit`. For example "CAF26" - :param fs: sampling rate, default to 25,000 - :param n_threads: number of threads to use for reading ecube timestamps (default: 32) - :param save: (default False) option to save the metadata.json back to S3 - (or the current braingeneers.default_endpoint) - :return: metadata.json - """ - # hengenlab's (current) source of record for experiment metadata is stored in a repo which can't be imported - # due to unacceptable dependencies. Instead, the source code is being downloaded with the relevant static - # functions parsed out explicitly. This is a hacky approach, but this data shouldn't be stored - # in a repo and is expected to be replaced with a proper database in the future. - # All current solutions to this problem are bad, this is the least objectionable solution. - crit_utils_src = requests.get('https://raw.githubusercontent.com/hengenlab/sahara_work/master/crit_utils.py').text - - src_get_birthday = re.search(r'(def get_birthday\(animal, returnall=False\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_regions = re.search(r'(def get_regions\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_sex = re.search(r'(def get_sex\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_genotype = re.search(r'(def get_genotype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - src_get_hstype = re.search(r'(def get_hstype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1) - - module = ModuleType('tempmodule') - module.dt = datetime # the only import necessary to run these static functions - exec(compile(src_get_birthday, '', 'exec'), module.__dict__) - exec(compile(src_get_regions, '', 'exec'), module.__dict__) - exec(compile(src_get_sex, '', 'exec'), module.__dict__) - exec(compile(src_get_genotype, '', 'exec'), module.__dict__) - exec(compile(src_get_hstype, '', 'exec'), module.__dict__) - - headstage_types = module.get_hstype(dataset_name.lower()) - - # list neural data files on S3 - s3_path = f's3://braingeneers/ephys/{batch_uuid}/original/{experiment_name}/' - neural_data_files = common_utils.file_list(s3_path) - assert len(neural_data_files) > 0, f'No neural data files found at: {s3_path}' - - args = [s3_path + ndf[0] for ndf in neural_data_files] - - # get ecube times for each file - ecube_timestamps = common_utils.map2( - _read_hengenlab_ecube_timestamp, - args=args, - parallelism=n_threads, - use_multithreading=True, - ) - - # sort data files by ecube timestamps - neural_data_files = [(*ndf, et) for ndf, et in zip(neural_data_files, ecube_timestamps)] - neural_data_files.sort(key=lambda ndf: ndf[3]) - - # parse n_channels from file name - channels_match = re.search(r'.*Headstages_(\d+)_Channels.*', neural_data_files[0][0]) - assert channels_match is not None, f'Unable to parse n_channels from filename: {neural_data_files[0][0]}' - n_channels = int(channels_match.group(1)) - - # parse timestamp from first file name - timestamp_match = re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_files[0][0]) - assert timestamp_match is not None, f'Unable to parse timestamp from filename: {neural_data_files[0][0]}' - timestamp = datetime.strptime(timestamp_match.group(1), '%Y-%m-%d_%H-%M-%S') - - channels_per_probe = n_channels // len(headstage_types) - channel_map = list(itertools.chain(*[ - (ntk.find_channel_map(hstype, number_of_channels=channels_per_probe) + i * channels_per_probe).tolist() - for i, hstype in enumerate(headstage_types) - ])) - - metadata = dict( - uuid=batch_uuid, - timestamp=timestamp.isoformat(), - issue='', - channel_map=channel_map, - headstage_types=headstage_types, - notes=dict( - purpose_of_experiment='', - comments='', - biology=dict( - sample_type='mouse', - dataset_name=dataset_name, - birthday=module.get_birthday(dataset_name.lower()).isoformat(), - gender=module.get_sex(dataset_name.lower()), - genotype=module.get_genotype(dataset_name.lower()), - ), - ), - ephys_experiments=[dict( - name=experiment_name, - hardware='Hengenlab', - num_channels=n_channels, - sample_rate=fs, - voltage_scaling_factor=0.19073486328125, - timestamp=timestamp.isoformat(), - units='\u00b5V', - version='1.0.0', - blocks=[ - { - 'num_frames': (size - 8) // 2 // n_channels, - 'path': f'original/{experiment_name}/{neural_data_file.split("/")[-1]}', - 'timestamp': datetime.strptime( - re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_file).group(1), - '%Y-%m-%d_%H-%M-%S', - ).isoformat(), - 'ecube_time': ecube_time, - } - for neural_data_file, last_modified_timestamp, size, ecube_time in neural_data_files - ], - )], - ) - - if save is True: - save_path = f's3://braingeneers/ephys/{batch_uuid}/metadata.json' - with smart_open.open(save_path, 'w') as f: - f.write(json.dumps(metadata, indent=2)) - - return metadata - - # --- AXION READER ----------------------------- def from_uint64(all_values): """ diff --git a/src/braingeneers/data/datasets_electrophysiology_test.py b/src/braingeneers/data/datasets_electrophysiology_test.py index f3c2122..0ce7f7a 100644 --- a/src/braingeneers/data/datasets_electrophysiology_test.py +++ b/src/braingeneers/data/datasets_electrophysiology_test.py @@ -413,47 +413,6 @@ def test_online_load_data_hengenlab_float32(self): self.assertEqual((192, 4), data.shape) self.assertEqual(np.float32, data.dtype) - @skip_unittest_if_offline - def test_online_generate_metadata(self): - metadata = ephys.generate_metadata_hengenlab( - batch_uuid=self.batch_uuid, - dataset_name='CAF26', - save=False, - ) - - # top level items - self.assertEqual(metadata['uuid'], '2020-04-12-e-hengenlab-caf26') - self.assertEqual(metadata['timestamp'], '2020-08-07T14:00:15') - self.assertEqual(metadata['issue'], '') - self.assertEqual(metadata['headstage_types'], ['EAB50chmap_00', 'APT_PCB', 'APT_PCB']) - - # notes - self.assertEqual(metadata['notes']['biology']['sample_type'], 'mouse') - self.assertEqual(metadata['notes']['biology']['dataset_name'], 'CAF26') - self.assertEqual(metadata['notes']['biology']['birthday'], '2020-02-20T07:30:00') - self.assertEqual(metadata['notes']['biology']['genotype'], 'wt') - - # ephys_experiments - self.assertEqual(len(metadata['ephys_experiments']), 1) - self.assertTrue(isinstance(metadata['ephys_experiments'], list)) - - experiment = metadata['ephys_experiments'][0] - self.assertEqual(experiment['name'], 'experiment1') - self.assertEqual(experiment['hardware'], 'Hengenlab') - self.assertEqual(experiment['num_channels'], 192) - self.assertEqual(experiment['sample_rate'], 25000) - self.assertEqual(experiment['voltage_scaling_factor'], 0.19073486328125) - self.assertEqual(experiment['timestamp'], '2020-08-07T14:00:15') - self.assertEqual(experiment['units'], '\u00b5V') - self.assertEqual(experiment['version'], '1.0.0') - self.assertEqual(len(experiment['blocks']), 324) - - block1 = metadata['ephys_experiments'][0]['blocks'][1] - self.assertEqual(block1['num_frames'], 7500000) - self.assertEqual(block1['path'], 'original/experiment1/Headstages_192_Channels_int16_2020-08-07_14-05-16.bin') - self.assertEqual(block1['timestamp'], '2020-08-07T14:05:16') - self.assertEqual(block1['ecube_time'], 301061600050) - if __name__ == '__main__': unittest.main() diff --git a/src/braingeneers/iot/messaging.py b/src/braingeneers/iot/messaging.py index ce16e90..db19cfc 100644 --- a/src/braingeneers/iot/messaging.py +++ b/src/braingeneers/iot/messaging.py @@ -816,4 +816,3 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _mqtt_topic_regex(topic: str) -> str: """ Converts a topic string with wildcards to a regex string """ return "^" + topic.replace("+", "[^/]+").replace("#", ".*").replace("$", "\\$") + "$" - diff --git a/src/braingeneers/iot/messaging_test.py b/src/braingeneers/iot/messaging_test.py index 19dd379..205acb0 100644 --- a/src/braingeneers/iot/messaging_test.py +++ b/src/braingeneers/iot/messaging_test.py @@ -303,5 +303,6 @@ def test_acquire_release(self): lock.release() + if __name__ == '__main__': unittest.main() diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index cf0f22f..ddd9124 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -1,9 +1,11 @@ """ Common utility functions """ +import io import urllib import boto3 from botocore.exceptions import ClientError import os import braingeneers +import braingeneers.utils.smart_open_braingeneers as smart_open from typing import Callable, Iterable, Union, List, Tuple, Dict, Any import functools import inspect @@ -11,9 +13,12 @@ import posixpath import itertools import pathlib - +import json +import hashlib _s3_client = None # S3 client for boto3, lazy initialization performed in _lazy_init_s3_client() +_message_broker = None # Lazy initialization of the message broker +_named_locks = {} # Named locks for checkout and checkin def _lazy_init_s3_client(): @@ -199,6 +204,85 @@ def f(x, y): return list(result_iterator) +def checkout(s3_file: str, mode: str = 'r') -> io.IOBase: + """ + Check out a file from S3 for reading or writing, use checkin to release the file. + Any subsequent calls to checkout will block until the file is returned with checkin(s3_file). + + Example usage: + f = checkout('s3://braingeneersdev/test/test_file.bin', mode='rb') + new_bytes = do_something(f.read()) + checkin('s3://braingeneersdev/test/test_file.bin', new_bytes) + + Example usage to update metadata: + f = checkout('s3://braingeneersdev/test/metadata.json') + metadata_dict = json.loads(f.read()) + metadata_dict['new_key'] = 'new_value' + metadata_updated_str = json.dumps(metadata_dict, indent=2) + checkin('s3://braingeneersdev/test/metadata.json', updated_metadata_str) + + :param s3_file: The S3 file path to check out. + :param mode: The mode to open the file in, 'r' (text mode) or 'rb' (binary mode), analogous to system open(filename, mode) + """ + # Avoid circular import + from braingeneers.iot.messaging import MessageBroker + + assert mode in ('r', 'rb'), 'Use "r" (text) or "rb" (binary) mode only. File changes are applied at checkout(...)' + + global _message_broker, _named_locks + if _message_broker is None: + print('creating message broker') + _message_broker = MessageBroker() + mb = _message_broker + + lock_str = f'common-utils-checkout-{s3_file}' + named_lock = mb.get_lock(lock_str) + named_lock.acquire() + _named_locks[s3_file] = named_lock + f = smart_open.open(s3_file, mode) + return f + + +def checkin(s3_file: str, file: Union[str, bytes, io.IOBase]): + """ + Releases a file that was checked out with checkout. + + :param s3_file: The S3 file path, must match checkout. + :param file: The string, bytes, or file object to write back to S3. + """ + assert isinstance(file, (str, bytes, io.IOBase)), 'file must be a string, bytes, or file object.' + + with smart_open.open(s3_file, 'wb') as f: + if isinstance(file, str): + f.write(file.encode()) + elif isinstance(file, bytes): + f.write(file) + else: + file.seek(0) + data = file.read() + f.write(data if isinstance(data, bytes) else data.encode()) + + global _named_locks + named_lock = _named_locks[s3_file] + named_lock.release() + + +def force_release_checkout(s3_file: str): + """ + Force release the lock on a file that was checked out with checkout. + """ + # Avoid circular import + from braingeneers.iot.messaging import MessageBroker + + global _message_broker + if _message_broker is None: + _message_broker = MessageBroker() + mb = _message_broker + + lock_str = f'common-utils-checkout-{s3_file}' + mb.delete_lock(lock_str) + + def pretty_print(data, n=10, indent=0): """ Custom pretty print function that uniformly truncates any collection (list or dictionary) @@ -206,13 +290,13 @@ def pretty_print(data, n=10, indent=0): Ensures mapping sections and similar are displayed compactly. Example usage (to display metadata.json): - + from braingeneers.utils.common_utils import pretty_print from braingeneers.data import datasets_electrophysiology as de - + metadata = de.load_metadata('2023-04-17-e-connectoid16235_CCH') pretty_print(metadata) - + Parameters: - data: The data to pretty print, either a list or a dictionary. - n: Maximum number of elements or items to display before truncation. @@ -227,7 +311,7 @@ def pretty_print(data, n=10, indent=0): else: truncated_keys = keys omitted_keys = None - + print('{') for key in truncated_keys: value = data[key] diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py index 0ce534b..ae812b9 100644 --- a/src/braingeneers/utils/common_utils_test.py +++ b/src/braingeneers/utils/common_utils_test.py @@ -1,9 +1,11 @@ import unittest from unittest.mock import patch, MagicMock +from common_utils import checkout, checkin, force_release_checkout, map2 +from braingeneers.iot import messaging import common_utils -from common_utils import map2 import os import tempfile +import braingeneers.utils.smart_open_braingeneers as smart_open def multiply(x, y): @@ -50,6 +52,21 @@ def test_local_no_files(self): self.assertEqual(result, []) +class TestCheckingCheckout(unittest.TestCase): + def setUp(self) -> None: + self.text_value = 'unittest1' + self.filepath = 's3://braingeneersdev/unittest/test.txt' + force_release_checkout(self.filepath) + + with smart_open.open(self.filepath, 'w') as f: + f.write(self.text_value) + + def test_checkout_checkin(self): + f = checkout(self.filepath) + self.assertEqual(f.read(), self.text_value) + checkin(self.filepath, f) + + class TestMap2(unittest.TestCase): def test_basic_functionality(self): """Test map2 with a simple function, no fixed values, no parallelism."""