Merge branch 'master' into safe_update_metadata
davidparks21 committed Mar 12, 2024
2 parents 5f9903d + fedc18f commit 6f7abfb
Showing 6 changed files with 29 additions and 244 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -47,4 +47,4 @@ jobs:
--durations=20
- name: Upload coverage report
- uses: codecov/codecov-action@v4.0.1
+ uses: codecov/codecov-action@v4.1.0
6 changes: 1 addition & 5 deletions pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
'matplotlib',
'nptyping',
'numpy',
- 'paho-mqtt',
+ 'paho-mqtt>=2',
'pandas',
'powerlaw',
'redis',
@@ -62,16 +62,12 @@ local_scheme = "no-local-version"
[project.optional-dependencies]
all = [
'braingeneers[ml]',
- 'braingeneers[hengenlab]',
'braingeneers[dev]',
]
ml = [
'torch',
'scikit-learn',
]
- hengenlab = [
-     'neuraltoolkit==0.3.1',  # channel mapping information
- ]
dev = [
"pytest >=6",
"pytest-cov >=3",
172 changes: 0 additions & 172 deletions src/braingeneers/data/datasets_electrophysiology.py
@@ -28,10 +28,6 @@
import re
from types import ModuleType
import bisect
- try:
-     import neuraltoolkit as ntk  # optional import
- except ImportError:
-     pass


VALID_LOAD_DATA_DTYPES = [np.int16, np.float16, np.float32, np.float64]
@@ -847,133 +843,6 @@ def _read_hengenlab_ecube_timestamp(filepath: str) -> int:
return int(np.frombuffer(f.read(8), dtype=np.uint64))


def generate_metadata_hengenlab(batch_uuid: str,
dataset_name: str,
experiment_name: Union[List[str], str] = 'experiment1',
fs: int = 25000,
n_threads: int = 32,
save: bool = False):
"""
Generates a metadata json and experiment1...experimentN section for a hengenlab dataset upload.
File locations in S3 for hengenlab neural data files:
s3://braingeneers/ephys/YYYY-MM-DD-e-${DATASET_NAME}/original/data/*.bin
Contiguous recording periods
:param batch_uuid: location on braingeneers storage (S3)
:param dataset_name: the dataset name as defined in `neuraltoolkit`, for example "CAF26". Metadata will be pulled from `neuraltoolkit`.
:param experiment_name: name of the experiment section(s) to generate in the metadata, for example "experiment1" (the default).
:param fs: sampling rate, default to 25,000
:param n_threads: number of threads to use for reading ecube timestamps (default: 32)
:param save: (default False) option to save the metadata.json back to S3
(or the current braingeneers.default_endpoint)
:return: metadata.json
"""
# hengenlab's (current) source of record for experiment metadata is stored in a repo which can't be imported
# due to unacceptable dependencies. Instead, the source code is being downloaded with the relevant static
# functions parsed out explicitly. This is a hacky approach, but this data shouldn't be stored
# in a repo and is expected to be replaced with a proper database in the future.
# All current solutions to this problem are bad, this is the least objectionable solution.
crit_utils_src = requests.get('https://raw.githubusercontent.com/hengenlab/sahara_work/master/crit_utils.py').text

src_get_birthday = re.search(r'(def get_birthday\(animal, returnall=False\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1)
src_get_regions = re.search(r'(def get_regions\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1)
src_get_sex = re.search(r'(def get_sex\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1)
src_get_genotype = re.search(r'(def get_genotype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1)
src_get_hstype = re.search(r'(def get_hstype\(animal\):.+?)\ndef ', crit_utils_src, flags=re.S).group(1)

module = ModuleType('tempmodule')
module.dt = datetime # the only import necessary to run these static functions
exec(compile(src_get_birthday, '', 'exec'), module.__dict__)
exec(compile(src_get_regions, '', 'exec'), module.__dict__)
exec(compile(src_get_sex, '', 'exec'), module.__dict__)
exec(compile(src_get_genotype, '', 'exec'), module.__dict__)
exec(compile(src_get_hstype, '', 'exec'), module.__dict__)

headstage_types = module.get_hstype(dataset_name.lower())

# list neural data files on S3
s3_path = f's3://braingeneers/ephys/{batch_uuid}/original/{experiment_name}/'
neural_data_files = common_utils.file_list(s3_path)
assert len(neural_data_files) > 0, f'No neural data files found at: {s3_path}'

args = [s3_path + ndf[0] for ndf in neural_data_files]

# get ecube times for each file
ecube_timestamps = common_utils.map2(
_read_hengenlab_ecube_timestamp,
args=args,
parallelism=n_threads,
use_multithreading=True,
)

# sort data files by ecube timestamps
neural_data_files = [(*ndf, et) for ndf, et in zip(neural_data_files, ecube_timestamps)]
neural_data_files.sort(key=lambda ndf: ndf[3])

# parse n_channels from file name
channels_match = re.search(r'.*Headstages_(\d+)_Channels.*', neural_data_files[0][0])
assert channels_match is not None, f'Unable to parse n_channels from filename: {neural_data_files[0][0]}'
n_channels = int(channels_match.group(1))

# parse timestamp from first file name
timestamp_match = re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_files[0][0])
assert timestamp_match is not None, f'Unable to parse timestamp from filename: {neural_data_files[0][0]}'
timestamp = datetime.strptime(timestamp_match.group(1), '%Y-%m-%d_%H-%M-%S')

channels_per_probe = n_channels // len(headstage_types)
channel_map = list(itertools.chain(*[
(ntk.find_channel_map(hstype, number_of_channels=channels_per_probe) + i * channels_per_probe).tolist()
for i, hstype in enumerate(headstage_types)
]))

metadata = dict(
uuid=batch_uuid,
timestamp=timestamp.isoformat(),
issue='',
channel_map=channel_map,
headstage_types=headstage_types,
notes=dict(
purpose_of_experiment='',
comments='',
biology=dict(
sample_type='mouse',
dataset_name=dataset_name,
birthday=module.get_birthday(dataset_name.lower()).isoformat(),
gender=module.get_sex(dataset_name.lower()),
genotype=module.get_genotype(dataset_name.lower()),
),
),
ephys_experiments=[dict(
name=experiment_name,
hardware='Hengenlab',
num_channels=n_channels,
sample_rate=fs,
voltage_scaling_factor=0.19073486328125,
timestamp=timestamp.isoformat(),
units='\u00b5V',
version='1.0.0',
blocks=[
{
'num_frames': (size - 8) // 2 // n_channels,
'path': f'original/{experiment_name}/{neural_data_file.split("/")[-1]}',
'timestamp': datetime.strptime(
re.search(r'.*_Channels_int16_(.+)\.bin', neural_data_file).group(1),
'%Y-%m-%d_%H-%M-%S',
).isoformat(),
'ecube_time': ecube_time,
}
for neural_data_file, last_modified_timestamp, size, ecube_time in neural_data_files
],
)],
)

if save is True:
save_path = f's3://braingeneers/ephys/{batch_uuid}/metadata.json'
with smart_open.open(save_path, 'w') as f:
f.write(json.dumps(metadata, indent=2))

return metadata


# --- AXION READER -----------------------------
def from_uint64(all_values):
"""
@@ -1303,47 +1172,6 @@ def _axion_get_data(file_name, file_data_start_position,
return final_raw_data_reshaped


# class IndexedList(list):
# """
# A variant of OrderedDict indexable by index (int) or name (str).
# This class forces ints to represent index by location, else index by name/object.
# Example usages:
# metadata['ephys_experiments']['experiment0'] # index by name (must use str type)
# metadata['ephys_experiments'][0] # index by location (must use int type)
# """
#
# def __init__(self, original_list: list, key: callable):
# self.keys_ordered = [key(v) for v in original_list]
# self.dict = {key(v): v for v in original_list}
# super().__init__()
#
# def __getitem__(self, key):
# print(key)
# if isinstance(key, int):
# return self.dict[self.keys_ordered[key]]
# elif isinstance(key, str):
# return self.dict[key]
# else:
# raise KeyError(f'Key must be type int (index by location) or str (index by name), got type: {type(key)}')
#
# def __iter__(self) -> Iterator:
# def g():
# for k in self.keys_ordered:
# yield self.dict[k]
#
# return g()
#
# def __hash__(self):
# return self.dict.__hash__()
#
# def __eq__(self, other):
# return isinstance(other, IndexedList) and self.dict.__eq__(other.dict)
#
# def __add__(self, value):
# self.keys_ordered.append(value)
# self.dict[value] = value


def get_mearec_h5_recordings_file(batch_uuid: str):
"""
Returns the filepath to the MEArec .h5/.hdf5 recordings file for the given UUID.
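A note on the removed Hengenlab reader above: each Hengenlab .bin file starts with a single 8-byte ecube timestamp (read by _read_hengenlab_ecube_timestamp) followed by interleaved int16 samples, which is why the deleted metadata generator computed block sizes as num_frames = (size - 8) // 2 // n_channels. A minimal sketch of that arithmetic in isolation, using the 192-channel, 25 kHz values from the deleted test below (the file size here is illustrative):

n_channels = 192         # parsed from a "Headstages_192_Channels_..." filename
fs = 25000               # default Hengenlab sampling rate, Hz
size = 2_880_000_008     # hypothetical file size in bytes: 8-byte header + samples

num_frames = (size - 8) // 2 // n_channels   # int16 samples are 2 bytes each
duration_s = num_frames / fs
print(num_frames, duration_s)                # 7500000 frames -> 300 s, one 5-minute block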
41 changes: 0 additions & 41 deletions src/braingeneers/data/datasets_electrophysiology_test.py
@@ -413,47 +413,6 @@ def test_online_load_data_hengenlab_float32(self):
self.assertEqual((192, 4), data.shape)
self.assertEqual(np.float32, data.dtype)

@skip_unittest_if_offline
def test_online_generate_metadata(self):
metadata = ephys.generate_metadata_hengenlab(
batch_uuid=self.batch_uuid,
dataset_name='CAF26',
save=False,
)

# top level items
self.assertEqual(metadata['uuid'], '2020-04-12-e-hengenlab-caf26')
self.assertEqual(metadata['timestamp'], '2020-08-07T14:00:15')
self.assertEqual(metadata['issue'], '')
self.assertEqual(metadata['headstage_types'], ['EAB50chmap_00', 'APT_PCB', 'APT_PCB'])

# notes
self.assertEqual(metadata['notes']['biology']['sample_type'], 'mouse')
self.assertEqual(metadata['notes']['biology']['dataset_name'], 'CAF26')
self.assertEqual(metadata['notes']['biology']['birthday'], '2020-02-20T07:30:00')
self.assertEqual(metadata['notes']['biology']['genotype'], 'wt')

# ephys_experiments
self.assertEqual(len(metadata['ephys_experiments']), 1)
self.assertTrue(isinstance(metadata['ephys_experiments'], list))

experiment = metadata['ephys_experiments'][0]
self.assertEqual(experiment['name'], 'experiment1')
self.assertEqual(experiment['hardware'], 'Hengenlab')
self.assertEqual(experiment['num_channels'], 192)
self.assertEqual(experiment['sample_rate'], 25000)
self.assertEqual(experiment['voltage_scaling_factor'], 0.19073486328125)
self.assertEqual(experiment['timestamp'], '2020-08-07T14:00:15')
self.assertEqual(experiment['units'], '\u00b5V')
self.assertEqual(experiment['version'], '1.0.0')
self.assertEqual(len(experiment['blocks']), 324)

block1 = metadata['ephys_experiments'][0]['blocks'][1]
self.assertEqual(block1['num_frames'], 7500000)
self.assertEqual(block1['path'], 'original/experiment1/Headstages_192_Channels_int16_2020-08-07_14-05-16.bin')
self.assertEqual(block1['timestamp'], '2020-08-07T14:05:16')
self.assertEqual(block1['ecube_time'], 301061600050)


if __name__ == '__main__':
unittest.main()
15 changes: 6 additions & 9 deletions src/braingeneers/iot/messaging.py
@@ -1,10 +1,5 @@
""" A simplified MQTT client for Braingeneers specific connections """

- import redis
- import tempfile
- import functools
- import json
- import inspect
import logging
import os
import re
@@ -15,13 +10,15 @@
import threading
import queue
import uuid
- from typing import Callable, Tuple, List, Dict, Union
import random
import json
import braingeneers.iot.shadows as sh
- from paho.mqtt import client as mqtt_client
- from deprecated import deprecated
import pickle

+ from typing import Callable, Tuple, List, Dict, Union
+ from deprecated import deprecated
+ from paho.mqtt import client as mqtt_client
+ from paho.mqtt.enums import CallbackAPIVersion
+ from tenacity import retry, wait_exponential, after_log
import braingeneers.utils.smart_open_braingeneers as smart_open

@@ -770,7 +767,7 @@ def on_log(client, userdata, level, buf):
self.logger.debug("MQTT log: %s", buf)

client_id = f'braingeneerspy-{random.randint(0, 1000)}'
- self._mqtt_connection = mqtt_client.Client(client_id)
+ self._mqtt_connection = mqtt_client.Client(CallbackAPIVersion.VERSION1, client_id)
self._mqtt_connection.username_pw_set(self._mqtt_profile_id, self._mqtt_profile_key)
self._mqtt_connection.on_connect = on_connect
self._mqtt_connection.on_log = on_log
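The one-line constructor change above is the paho-mqtt 2.x migration that motivates the paho-mqtt>=2 pin in pyproject.toml: paho-mqtt 2.0 made the callback API version a required first argument to Client(), and CallbackAPIVersion.VERSION1 keeps the 1.x-style callback signatures. A minimal sketch of the before/after (broker, credentials, and client id are hypothetical):

from paho.mqtt import client as mqtt_client
from paho.mqtt.enums import CallbackAPIVersion

# 1.x style; under paho-mqtt >= 2 this raises, because the callback API
# version is now a required first argument:
# client = mqtt_client.Client('example-client-id')

# 2.x style, as used in messaging.py above; VERSION1 preserves the 1.x
# callback signatures, e.g. on_connect(client, userdata, flags, rc):
client = mqtt_client.Client(CallbackAPIVersion.VERSION1, 'example-client-id')
client.username_pw_set('example-user', 'example-key')  # hypothetical credentials
client.connect('mqtt.example.org', 1883)               # hypothetical broker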
37 changes: 21 additions & 16 deletions src/braingeneers/utils/common_utils.py
@@ -6,7 +6,7 @@
import os
import braingeneers
import braingeneers.utils.smart_open_braingeneers as smart_open
- from typing import List, Tuple, Union, Callable, Iterable
+ from typing import Callable, Iterable, Union, List, Tuple, Dict, Any
import functools
import inspect
import multiprocessing
@@ -127,6 +127,14 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]:
return files_and_details


+ # Define the wrapper function as a top-level function
+ def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any:
+     """Internal wrapper function for map2 to handle fixed values and dynamic arguments."""
+     # Merge fixed_values with provided arguments, aligning provided args with required_params
+     call_args = {**fixed_values, **dict(zip(required_params, args))}
+     return func(**call_args)


def map2(func: Callable,
args: Iterable[Union[Tuple, object]] = None,
fixed_values: dict = None,
@@ -173,27 +181,25 @@ def f(x, y):
:return: a list of the return values of func
"""
assert isinstance(fixed_values, (dict, type(None)))
- assert isinstance(parallelism, int)
+ assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer"
parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism
- assert isinstance(parallelism, int)
+ assert isinstance(parallelism, int), "parallelism must be resolved to an integer"

+ fixed_values = fixed_values or {}
+ func_signature = inspect.signature(func)
+ required_params = [p.name for p in func_signature.parameters.values() if
+                    p.default == inspect.Parameter.empty and p.name not in fixed_values]

- func_partial = functools.partial(func, **(fixed_values or {}))
- n_required_params = sum([p.default == inspect.Parameter.empty for p in inspect.signature(func).parameters.values()])
- n_fixed_values = len(fixed_values or {})
args_list = list(args or [])
- args_tuples = args \
-     if len(args_list) > 0 \
-     and isinstance(args_list[0], tuple) \
-     and len(args_list[0]) >= n_required_params - n_fixed_values \
-     else [(a,) for a in args_list]
+ args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list]

if parallelism == 1:
- result_iterator = itertools.starmap(func_partial, args_tuples)
+ result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples)
else:
# noinspection PyPep8Naming
- ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading is True else multiprocessing.Pool
+ ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool
with ProcessOrThreadPool(parallelism) as pool:
- result_iterator = pool.starmap(func_partial, args_tuples)
+ result_iterator = pool.starmap(_map2_wrapper,
+                                [(fixed_values, required_params, func, args) for args in args_tuples])

return list(result_iterator)

Expand Down Expand Up @@ -291,7 +297,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.named_lock.release()


-
def force_release_checkout(s3_file: str):
"""
Force release the lock on a file that was checked out with checkout.
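For context on the map2 changes above: _map2_wrapper is defined at module top level, rather than built with functools.partial inside map2, presumably so the callable can be pickled by multiprocessing.Pool; it merges fixed_values with each positional argument tuple by matching the function's remaining required parameter names. A small usage sketch under those semantics (the function and values are illustrative):

from braingeneers.utils.common_utils import map2

def f(x, y):
    # toy function; y is supplied once via fixed_values, x varies per call
    return x + y

# Each element of args fills f's remaining required parameter x, while y=10
# is merged in by _map2_wrapper; parallelism=1 runs serially via map().
results = map2(f, args=[1, 2, 3], fixed_values=dict(y=10), parallelism=1)
print(results)  # [11, 12, 13]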
