diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 349094d..5616917 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -18,14 +18,14 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, - server_url: str, logger: logging.Logger = None): + server_url: str, logger: logging.Logger): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size self.cohort_request_delay_millis = cohort_request_delay_millis self.server_url = server_url - self.logger = logger or logging.getLogger("Amplitude") + self.logger = logger self.__setup_connection_pool() def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 17b0181..d525085 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -10,7 +10,7 @@ class CohortLoader: def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, - logger: logging.Logger = None): + logger: logging.Logger): self.cohort_download_api = cohort_download_api self.cohort_storage = cohort_storage self.jobs: Dict[str, Future] = {} @@ -19,7 +19,7 @@ def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: Cohor max_workers=32, thread_name_prefix='CohortLoaderExecutor' ) - self.logger = logger or logging.getLogger("Amplitude") + self.logger = logger def load_cohort(self, cohort_id: str) -> Future: with self.lock_jobs: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 39ef62d..b273c0d 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -18,8 +18,8 @@ def __init__( flag_config_api: FlagConfigApi, flag_config_storage: FlagConfigStorage, cohort_storage: CohortStorage, + logger: logging.Logger, cohort_loader: Optional[CohortLoader] = None, - logger: logging.Logger = None ): self.config = config self.flag_config_api = flag_config_api @@ -55,25 +55,30 @@ def __update_flag_configs(self): except Exception as e: self.logger.error(f'Failed to fetch flag configs: {e}') raise e + flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) + if not self.cohort_loader: for flag_config in flag_configs: self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) return + new_cohort_ids = set() for flag_config in flag_configs: new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) + existing_cohort_ids = self.cohort_storage.get_cohort_ids() cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - cohort_download_error = None + cohort_download_errors = [] + # download all new cohorts for cohort_id in cohort_ids_to_download: try: self.cohort_loader.load_cohort(cohort_id).result() except Exception as e: - cohort_download_error = e + cohort_download_errors.append((cohort_id, str(e))) self.logger.error(f"Download cohort {cohort_id} failed: {e}") # get updated set of cohort ids @@ -85,7 +90,6 @@ def __update_flag_configs(self): if not cohort_ids or not self.cohort_loader: self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") - # only store flag config if all required cohorts exist elif cohort_ids.issubset(updated_cohort_ids): self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting flag {flag_config['key']}") @@ -97,9 +101,12 @@ def __update_flag_configs(self): # delete unused cohorts self._delete_unused_cohorts() self.logger.debug(f"Refreshed {len(flag_configs) - failed_flag_count} flag configs.") - # if not all required cohorts exist, throw an error - if cohort_download_error: - raise cohort_download_error + + # if there are any download errors, raise an aggregated exception + if cohort_download_errors: + error_count = len(cohort_download_errors) + error_messages = "\n".join([f"Cohort {cohort_id}: {error}" for cohort_id, error in cohort_download_errors]) + raise Exception(f"{error_count} cohort(s) failed to download:\n{error_messages}") def __update_cohorts(self): self.cohort_loader.update_stored_cohorts().result() diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index e2b45f7..657c415 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -71,7 +71,7 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, - self.cohort_storage, cohort_loader, self.logger) + self.cohort_storage, self.logger, cohort_loader) def start(self): """ diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 03c6600..3af9c8e 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -1,5 +1,7 @@ import json +import logging import unittest +from unittest import mock from unittest.mock import MagicMock, patch from src.amplitude_experiment.cohort.cohort import Cohort from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi @@ -17,7 +19,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com") + self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com", mock.create_autospec(logging.Logger)) def test_cohort_download_success(self): cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index 309a8ff..a739b9d 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -1,4 +1,6 @@ +import logging import unittest +from unittest import mock from unittest.mock import MagicMock from src.amplitude_experiment.cohort.cohort import Cohort @@ -9,7 +11,7 @@ class CohortLoaderTest(unittest.TestCase): def setUp(self): self.api = MagicMock() self.storage = InMemoryCohortStorage() - self.loader = CohortLoader(self.api, self.storage) + self.loader = CohortLoader(self.api, self.storage, mock.create_autospec(logging.Logger)) def test_load_success(self): self.api.get_cohort.side_effect = [ diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index c4c0c8f..c5052db 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -38,15 +38,15 @@ def test_start_throws_if_first_flag_config_load_fails(self): flag_config_storage = mock.Mock() cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() - cohort_loader = CohortLoader(cohort_download_api, cohort_storage) logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, + logger, cohort_loader, - logger # Pass the logger mock here ) flag_api.get_flag_configs.side_effect = RuntimeError("test") with self.assertRaises(RuntimeError): @@ -58,19 +58,19 @@ def test_start_throws_if_first_cohort_load_fails(self): flag_config_storage = mock.Mock() cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() - cohort_loader = CohortLoader(cohort_download_api, cohort_storage) logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, + logger, cohort_loader, - logger # Pass the logger mock here ) with patch.object(runner, '_delete_unused_cohorts'): flag_api.get_flag_configs.return_value = [self.flag] cohort_download_api.get_cohort.side_effect = RuntimeError("test") - with self.assertRaises(RuntimeError): + with self.assertRaises(Exception): runner.start()