Skip to content

Commit

Permalink
Update logger requirement for classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tyiuhc committed Jul 3, 2024
1 parent 5130cdd commit 5ed0a98
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/amplitude_experiment/cohort/cohort_download_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:

class DirectCohortDownloadApi(CohortDownloadApi):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int,
server_url: str, logger: logging.Logger = None):
server_url: str, logger: logging.Logger):
super().__init__()
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_request_delay_millis = cohort_request_delay_millis
self.server_url = server_url
self.logger = logger or logging.getLogger("Amplitude")
self.logger = logger
self.__setup_connection_pool()

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
Expand Down
4 changes: 2 additions & 2 deletions src/amplitude_experiment/cohort/cohort_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class CohortLoader:
def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage,
logger: logging.Logger = None):
logger: logging.Logger):
self.cohort_download_api = cohort_download_api
self.cohort_storage = cohort_storage
self.jobs: Dict[str, Future] = {}
Expand All @@ -19,7 +19,7 @@ def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: Cohor
max_workers=32,
thread_name_prefix='CohortLoaderExecutor'
)
self.logger = logger or logging.getLogger("Amplitude")
self.logger = logger

def load_cohort(self, cohort_id: str) -> Future:
with self.lock_jobs:
Expand Down
21 changes: 14 additions & 7 deletions src/amplitude_experiment/deployment/deployment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(
flag_config_api: FlagConfigApi,
flag_config_storage: FlagConfigStorage,
cohort_storage: CohortStorage,
logger: logging.Logger,
cohort_loader: Optional[CohortLoader] = None,
logger: logging.Logger = None
):
self.config = config
self.flag_config_api = flag_config_api
Expand Down Expand Up @@ -55,25 +55,30 @@ def __update_flag_configs(self):
except Exception as e:
self.logger.error(f'Failed to fetch flag configs: {e}')
raise e

flag_keys = {flag['key'] for flag in flag_configs}
self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys)

if not self.cohort_loader:
for flag_config in flag_configs:
self.logger.debug(f"Putting non-cohort flag {flag_config['key']}")
self.flag_config_storage.put_flag_config(flag_config)
return

new_cohort_ids = set()
for flag_config in flag_configs:
new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config))

existing_cohort_ids = self.cohort_storage.get_cohort_ids()
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids
cohort_download_error = None
cohort_download_errors = []

# download all new cohorts
for cohort_id in cohort_ids_to_download:
try:
self.cohort_loader.load_cohort(cohort_id).result()
except Exception as e:
cohort_download_error = e
cohort_download_errors.append((cohort_id, str(e)))
self.logger.error(f"Download cohort {cohort_id} failed: {e}")

# get updated set of cohort ids
Expand All @@ -85,7 +90,6 @@ def __update_flag_configs(self):
if not cohort_ids or not self.cohort_loader:
self.flag_config_storage.put_flag_config(flag_config)
self.logger.debug(f"Putting non-cohort flag {flag_config['key']}")
# only store flag config if all required cohorts exist
elif cohort_ids.issubset(updated_cohort_ids):
self.flag_config_storage.put_flag_config(flag_config)
self.logger.debug(f"Putting flag {flag_config['key']}")
Expand All @@ -97,9 +101,12 @@ def __update_flag_configs(self):
# delete unused cohorts
self._delete_unused_cohorts()
self.logger.debug(f"Refreshed {len(flag_configs) - failed_flag_count} flag configs.")
# if not all required cohorts exist, throw an error
if cohort_download_error:
raise cohort_download_error

# if there are any download errors, raise an aggregated exception
if cohort_download_errors:
error_count = len(cohort_download_errors)
error_messages = "\n".join([f"Cohort {cohort_id}: {error}" for cohort_id, error in cohort_download_errors])
raise Exception(f"{error_count} cohort(s) failed to download:\n{error_messages}")

def __update_cohorts(self):
self.cohort_loader.update_stored_cohorts().result()
Expand Down
2 changes: 1 addition & 1 deletion src/amplitude_experiment/local/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None):
flag_config_api = FlagConfigApiV2(api_key, self.config.server_url,
self.config.flag_config_poller_request_timeout_millis)
self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage,
self.cohort_storage, cohort_loader, self.logger)
self.cohort_storage, self.logger, cohort_loader)

def start(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/cohort/cohort_download_api_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
from src.amplitude_experiment.cohort.cohort import Cohort
from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi
Expand All @@ -17,7 +19,7 @@ def response(code: int, body: dict = None):
class CohortDownloadApiTest(unittest.TestCase):

def setUp(self):
self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com")
self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com", mock.create_autospec(logging.Logger))

def test_cohort_download_success(self):
cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'})
Expand Down
4 changes: 3 additions & 1 deletion tests/cohort/cohort_loader_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import unittest
from unittest import mock
from unittest.mock import MagicMock

from src.amplitude_experiment.cohort.cohort import Cohort
Expand All @@ -9,7 +11,7 @@ class CohortLoaderTest(unittest.TestCase):
def setUp(self):
self.api = MagicMock()
self.storage = InMemoryCohortStorage()
self.loader = CohortLoader(self.api, self.storage)
self.loader = CohortLoader(self.api, self.storage, mock.create_autospec(logging.Logger))

def test_load_success(self):
self.api.get_cohort.side_effect = [
Expand Down
10 changes: 5 additions & 5 deletions tests/deployment/deployment_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def test_start_throws_if_first_flag_config_load_fails(self):
flag_config_storage = mock.Mock()
cohort_storage = mock.Mock()
cohort_storage.get_cohort_ids.return_value = set()
cohort_loader = CohortLoader(cohort_download_api, cohort_storage)
logger = mock.create_autospec(logging.Logger)
cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger)
runner = DeploymentRunner(
LocalEvaluationConfig(),
flag_api,
flag_config_storage,
cohort_storage,
logger,
cohort_loader,
logger # Pass the logger mock here
)
flag_api.get_flag_configs.side_effect = RuntimeError("test")
with self.assertRaises(RuntimeError):
Expand All @@ -58,19 +58,19 @@ def test_start_throws_if_first_cohort_load_fails(self):
flag_config_storage = mock.Mock()
cohort_storage = mock.Mock()
cohort_storage.get_cohort_ids.return_value = set()
cohort_loader = CohortLoader(cohort_download_api, cohort_storage)
logger = mock.create_autospec(logging.Logger)
cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger)
runner = DeploymentRunner(
LocalEvaluationConfig(),
flag_api, flag_config_storage,
cohort_storage,
logger,
cohort_loader,
logger # Pass the logger mock here
)
with patch.object(runner, '_delete_unused_cohorts'):
flag_api.get_flag_configs.return_value = [self.flag]
cohort_download_api.get_cohort.side_effect = RuntimeError("test")
with self.assertRaises(RuntimeError):
with self.assertRaises(Exception):
runner.start()


Expand Down

0 comments on commit 5ed0a98

Please sign in to comment.