Skip to content

Commit

Permalink
Normalize unit test formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
atspaeth committed May 24, 2024
1 parent 5a98f68 commit 295911a
Show file tree
Hide file tree
Showing 8 changed files with 526 additions and 354 deletions.
95 changes: 53 additions & 42 deletions tests/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,45 @@
import os
import tempfile
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import braingeneers.utils.smart_open_braingeneers as smart_open
from braingeneers.utils import common_utils
from braingeneers.utils.common_utils import checkout, map2


class TestFileListFunction(unittest.TestCase):

@patch('braingeneers.utils.common_utils._lazy_init_s3_client') # Updated to common_utils
@patch(
"braingeneers.utils.common_utils._lazy_init_s3_client"
) # Updated to common_utils
def test_s3_files_exist(self, mock_s3_client):
# Mock S3 client response
mock_response = {
'Contents': [
{'Key': 'file1.txt', 'LastModified': '2023-01-01', 'Size': 123},
{'Key': 'file2.txt', 'LastModified': '2023-01-02', 'Size': 456}
"Contents": [
{"Key": "file1.txt", "LastModified": "2023-01-01", "Size": 123},
{"Key": "file2.txt", "LastModified": "2023-01-02", "Size": 456},
]
}
mock_s3_client.return_value.list_objects.return_value = mock_response

result = common_utils.file_list('s3://test-bucket/') # Updated to common_utils
expected = [('file2.txt', '2023-01-02', 456), ('file1.txt', '2023-01-01', 123)]
result = common_utils.file_list("s3://test-bucket/") # Updated to common_utils
expected = [("file2.txt", "2023-01-02", 456), ("file1.txt", "2023-01-01", 123)]
self.assertEqual(result, expected)

@patch('braingeneers.utils.common_utils._lazy_init_s3_client') # Updated to common_utils
@patch(
"braingeneers.utils.common_utils._lazy_init_s3_client"
) # Updated to common_utils
def test_s3_no_files(self, mock_s3_client):
# Mock S3 client response for no files
mock_s3_client.return_value.list_objects.return_value = {}
result = common_utils.file_list('s3://test-bucket/') # Updated to common_utils
result = common_utils.file_list("s3://test-bucket/") # Updated to common_utils
self.assertEqual(result, [])

def test_local_files_exist(self):
with tempfile.TemporaryDirectory() as temp_dir:
for f in ['tempfile1.txt', 'tempfile2.txt']:
with open(os.path.join(temp_dir, f), 'w') as w:
w.write('nothing')
for f in ["tempfile1.txt", "tempfile2.txt"]:
with open(os.path.join(temp_dir, f), "w") as w:
w.write("nothing")

result = common_utils.file_list(temp_dir) # Updated to common_utils
# The result should contain two files with their details
Expand All @@ -50,10 +53,9 @@ def test_local_no_files(self):


class TestCheckout(unittest.TestCase):

def setUp(self):
# Setup mock for smart_open and MessageBroker
self.message_broker_patch = patch('braingeneers.iot.messaging.MessageBroker')
self.message_broker_patch = patch("braingeneers.iot.messaging.MessageBroker")

# Start the patches
self.mock_message_broker = self.message_broker_patch.start()
Expand All @@ -63,7 +65,9 @@ def setUp(self):
self.mock_message_broker.return_value.delete_lock = MagicMock()

self.mock_file = MagicMock(spec=io.StringIO)
self.mock_file.read.return_value = 'Test data' # Ensure this is correctly setting the return value for read
self.mock_file.read.return_value = (
"Test data" # Ensure this is correctly setting the return value for read
)
self.mock_file.__enter__.return_value = self.mock_file
self.mock_file.__exit__.return_value = None
self.smart_open_mock = MagicMock(spec=smart_open)
Expand All @@ -77,23 +81,23 @@ def tearDown(self):

def test_checkout_context_manager_read(self):
# Test the reading functionality
with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj:
with checkout("s3://test-bucket/test-file.txt", isbinary=False) as locked_obj:
data = locked_obj.get_value()
self.assertEqual(data, 'Test data')
self.assertEqual(data, "Test data")

def test_checkout_context_manager_write_text(self):
# Test the writing functionality for text mode
test_data = 'New test data'
test_data = "New test data"
self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test
with checkout('s3://test-bucket/test-file.txt', isbinary=False) as locked_obj:
with checkout("s3://test-bucket/test-file.txt", isbinary=False) as locked_obj:
locked_obj.checkin(test_data)
self.mock_file.write.assert_called_once_with(test_data)

def test_checkout_context_manager_write_binary(self):
# Test the writing functionality for binary mode
test_data = b'New binary data'
test_data = b"New binary data"
self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test
with checkout('s3://test-bucket/test-file.bin', isbinary=True) as locked_obj:
with checkout("s3://test-bucket/test-file.bin", isbinary=True) as locked_obj:
locked_obj.checkin(test_data)
self.mock_file.write.assert_called_once_with(test_data)

Expand All @@ -103,41 +107,48 @@ def test_with_pass_through_kwargs_handling(self):

def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs):
# Simulate loading data where 'experiment_name' and other parameters are expected to come through **kwargs
self.assertTrue(isinstance(kwargs, dict), 'kwargs should be a dict')
self.assertFalse('kwargs' in kwargs)
return 'some data'

experiments = [{'experiment': 'exp1'}, {'experiment': 'exp2'}] # List of experiment names to be passed as individual kwargs
self.assertTrue(isinstance(kwargs, dict), "kwargs should be a dict")
self.assertFalse("kwargs" in kwargs)
return "some data"

experiments = [
{"experiment": "exp1"},
{"experiment": "exp2"},
] # List of experiment names to be passed as individual kwargs
fixed_values = {
"cache_path": '/tmp/ephys_cache',
"cache_path": "/tmp/ephys_cache",
"max_size_gb": 50,
"metadata": {"some": "metadata"},
"channels": ["channel1"],
"length": -1,
}

# Execute the test under the assumption that map2 is supposed to handle 'experiment_name' in **kwargs correctly
map2(f_with_kwargs, kwargs=experiments, fixed_values=fixed_values, parallelism=False)
map2(
f_with_kwargs,
kwargs=experiments,
fixed_values=fixed_values,
parallelism=False,
)
self.assertTrue(True) # If the test reaches this point, it has passed


class TestMap2Function(unittest.TestCase):
def test_with_kwargs_function_parallelism_false(self):
# Define a test function that takes a positional argument and arbitrary kwargs
def test_func(a, **kwargs):
return a + kwargs.get('increment', 0)
return a + kwargs.get("increment", 0)

# Define the arguments and kwargs to pass to map2
args = [(1,), (2,), (3,)] # positional arguments
kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}] # kwargs for each call
kwargs = [
{"increment": 10},
{"increment": 20},
{"increment": 30},
] # kwargs for each call

# Call map2 with the test function, args, kwargs, and parallelism=False
result = map2(
func=test_func,
args=args,
kwargs=kwargs,
parallelism=False
)
result = map2(func=test_func, args=args, kwargs=kwargs, parallelism=False)

# Expected results after applying the function with the given args and kwargs
expected_results = [11, 22, 33]
Expand All @@ -148,20 +159,20 @@ def test_func(a, **kwargs):
def test_with_fixed_values_and_variable_kwargs_parallelism_false(self):
# Define a test function that takes fixed positional argument and arbitrary kwargs
def test_func(a, **kwargs):
return a + kwargs.get('increment', 0)
return a + kwargs.get("increment", 0)

# Since 'a' is now a fixed value, we no longer need to provide it in args
args = [] # No positional arguments are passed here

# Define the kwargs to pass to map2, each dict represents kwargs for one call
kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}]
kwargs = [{"increment": 10}, {"increment": 20}, {"increment": 30}]

# Call map2 with the test function, no args, variable kwargs, fixed_values containing 'a', and parallelism=False
result = map2(
func=test_func,
kwargs=kwargs,
fixed_values={'a': 1}, # 'a' is fixed for all calls
parallelism=False
fixed_values={"a": 1}, # 'a' is fixed for all calls
parallelism=False,
)

# Expected results after applying the function with the fixed 'a' and given kwargs
Expand All @@ -171,5 +182,5 @@ def test_func(a, **kwargs):
self.assertEqual(result, expected_results)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 295911a

Please sign in to comment.