From def99246d53120dc00c9291dbf1c46245786799d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Feb 2024 11:13:23 -0800 Subject: [PATCH 01/44] feat: Enable use of project API key for default deployments --- src/amplitude_experiment/factory.py | 20 ++++++++++++-------- src/amplitude_experiment/local/config.py | 6 +++++- src/amplitude_experiment/remote/config.py | 6 +++++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/amplitude_experiment/factory.py b/src/amplitude_experiment/factory.py index 841c55c..b118325 100644 --- a/src/amplitude_experiment/factory.py +++ b/src/amplitude_experiment/factory.py @@ -15,15 +15,17 @@ def initialize_remote(api_key: str, config: RemoteEvaluationConfig = None) -> Re """ Initializes a remote evaluation client. Parameters: - api_key (str): The Amplitude API Key + api_key (str): The Amplitude Project API Key used in the client. If a deployment key is provided in the + config, it will be used instead config (RemoteEvaluationConfig): Optional Config Returns: A remote evaluation client. """ - if remote_evaluation_instances.get(api_key) is None: - remote_evaluation_instances[api_key] = RemoteEvaluationClient(api_key, config) - return remote_evaluation_instances[api_key] + used_key = config.deployment_key if config and config.deployment_key else api_key + if remote_evaluation_instances.get(used_key) is None: + remote_evaluation_instances[used_key] = RemoteEvaluationClient(used_key, config) + return remote_evaluation_instances[used_key] @staticmethod def initialize_local(api_key: str, config: LocalEvaluationConfig = None) -> LocalEvaluationClient: @@ -32,12 +34,14 @@ def initialize_local(api_key: str, config: LocalEvaluationConfig = None) -> Loca user without requiring a remote call to the amplitude evaluation server. In order to best leverage local evaluation, all flags, and experiments being evaluated server side should be configured as local. Parameters: - api_key (str): The Amplitude API Key + api_key (str): The Amplitude Project API Key used in the client. If a deployment key is provided in the + config, it will be used instead config (RemoteEvaluationConfig): Optional Config Returns: A local evaluation client. """ - if local_evaluation_instances.get(api_key) is None: - local_evaluation_instances[api_key] = LocalEvaluationClient(api_key, config) - return local_evaluation_instances[api_key] + used_key = config.deployment_key if config and config.deployment_key else api_key + if local_evaluation_instances.get(used_key) is None: + local_evaluation_instances[used_key] = LocalEvaluationClient(used_key, config) + return local_evaluation_instances[used_key] diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index 027467d..6714815 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -10,7 +10,8 @@ def __init__(self, debug: bool = False, server_url: str = DEFAULT_SERVER_URL, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, - assignment_config: AssignmentConfig = None): + assignment_config: AssignmentConfig = None, + deployment_key: str = None): """ Initialize a config Parameters: @@ -21,6 +22,8 @@ def __init__(self, debug: bool = False, to perform local evaluation. flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, used when fetching variants. + deployment_key (str): The Experiment deployment key. If provided, it is used + instead of the project API key Returns: The config object @@ -30,3 +33,4 @@ def __init__(self, debug: bool = False, self.flag_config_polling_interval_millis = flag_config_polling_interval_millis self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis self.assignment_config = assignment_config + self.deployment_key = deployment_key diff --git a/src/amplitude_experiment/remote/config.py b/src/amplitude_experiment/remote/config.py index 7e84bf5..b0a217c 100644 --- a/src/amplitude_experiment/remote/config.py +++ b/src/amplitude_experiment/remote/config.py @@ -10,7 +10,8 @@ def __init__(self, debug=False, fetch_retry_backoff_min_millis=500, fetch_retry_backoff_max_millis=10000, fetch_retry_backoff_scalar=1.5, - fetch_retry_timeout_millis=10000): + fetch_retry_timeout_millis=10000, + deployment_key=None): """ Initialize a config Parameters: @@ -25,6 +26,8 @@ def __init__(self, debug=False, greater than the max, the max is used for all subsequent retries. fetch_retry_backoff_scalar (float): Scales the minimum backoff exponentially. fetch_retry_timeout_millis (int): The request timeout for retrying fetch requests. + deployment_key (str): The Experiment deployment key. If provided, it is used + instead of the project API key Returns: The config object @@ -37,3 +40,4 @@ def __init__(self, debug=False, self.fetch_retry_backoff_max_millis = fetch_retry_backoff_max_millis self.fetch_retry_backoff_scalar = fetch_retry_backoff_scalar self.fetch_retry_timeout_millis = fetch_retry_timeout_millis + self.deployment_key = deployment_key From 405dcb6f0762c027f5524891cd7601cceaea6e86 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 6 Jun 2024 15:17:36 -0700 Subject: [PATCH 02/44] initial commit --- .../cohort/cohort_description.py | 12 ++ .../cohort/cohort_download_api.py | 173 +++++++++++++++++ .../cohort/cohort_loader.py | 77 ++++++++ .../cohort/cohort_storage.py | 64 +++++++ .../cohort/cohort_sync_config.py | 6 + .../deployment/deployment_runner.py | 94 +++++++++ src/amplitude_experiment/exception.py | 12 ++ .../flag/flag_config_api.py | 46 +++++ .../flag/flag_config_storage.py | 31 +++ src/amplitude_experiment/local/config.py | 5 +- src/amplitude_experiment/util/flag_config.py | 54 ++++++ tests/cohort/cohort_download_api_test.py | 180 ++++++++++++++++++ tests/cohort/cohort_loader_test.py | 94 +++++++++ tests/deployment/deployment_runner_test.py | 70 +++++++ tests/util/flag_config_test.py | 125 ++++++++++++ 15 files changed, 1042 insertions(+), 1 deletion(-) create mode 100644 src/amplitude_experiment/cohort/cohort_description.py create mode 100644 src/amplitude_experiment/cohort/cohort_download_api.py create mode 100644 src/amplitude_experiment/cohort/cohort_loader.py create mode 100644 src/amplitude_experiment/cohort/cohort_storage.py create mode 100644 src/amplitude_experiment/cohort/cohort_sync_config.py create mode 100644 src/amplitude_experiment/deployment/deployment_runner.py create mode 100644 src/amplitude_experiment/flag/flag_config_api.py create mode 100644 src/amplitude_experiment/flag/flag_config_storage.py create mode 100644 src/amplitude_experiment/util/flag_config.py create mode 100644 tests/cohort/cohort_download_api_test.py create mode 100644 tests/cohort/cohort_loader_test.py create mode 100644 tests/deployment/deployment_runner_test.py create mode 100644 tests/util/flag_config_test.py diff --git a/src/amplitude_experiment/cohort/cohort_description.py b/src/amplitude_experiment/cohort/cohort_description.py new file mode 100644 index 0000000..d4d882b --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_description.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass, field +from typing import ClassVar + +USER_GROUP_TYPE: ClassVar[str] = "User" + + +@dataclass +class CohortDescription: + id: str + last_computed: int + size: int + group_type: str = field(default=USER_GROUP_TYPE) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py new file mode 100644 index 0000000..ff0894a --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -0,0 +1,173 @@ +import time +import logging +import base64 +import json +import csv +from io import StringIO +from typing import Set + +from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE +from src.amplitude_experiment.connection_pool import HTTPConnectionPool +from src.amplitude_experiment.exception import CachedCohortDownloadException, HTTPErrorResponseException + +CDN_COHORT_SYNC_URL = 'https://cohort.lab.amplitude.com' + + +class CohortDownloadApi: + def __init__(self): + self.cdn_server_url = CDN_COHORT_SYNC_URL + + def get_cohort_description(self, cohort_id: str) -> CohortDescription: + raise NotImplementedError + + def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: + raise NotImplementedError + + +class DirectCohortDownloadApiV5(CohortDownloadApi): + def __init__(self, api_key: str, secret_key: str): + super().__init__() + self.api_key = api_key + self.secret_key = secret_key + self.__setup_connection_pool() + self.request_status_delay = 2 # seconds, adjust as necessary + + def get_cohort_description(self, cohort_id: str) -> CohortDescription: + response = self.get_cohort_info(cohort_id) + cohort_info = json.loads(response.read().decode("utf8")) + return CohortDescription( + id=cohort_info['cohort_id'], + last_computed=cohort_info['last_computed'], + size=cohort_info['size'], + group_type=cohort_info['group_type'], + ) + + def get_cohort_info(self, cohort_id: str): + conn = self._connection_pool.acquire() + try: + return conn.request('GET', f'api/3/cohorts/info/{cohort_id}', + headers={'Authorization': f'Basic {self._get_basic_auth()}'}) + finally: + self._connection_pool.release(conn) + + def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: + try: + logging.debug(f"getCohortMembers({cohort_description.id}): start - {cohort_description}") + initial_response = self._get_cohort_async_request(cohort_description) + request_id = json.loads(initial_response.read().decode('utf-8'))['request_id'] + logging.debug(f"getCohortMembers({cohort_description.id}): requestId={request_id}") + + errors = 0 + while True: + try: + status_response = self._get_cohort_async_request_status(request_id) + logging.debug(f"getCohortMembers({cohort_description.id}): status={status_response.status}") + if status_response.status == 200: + break + elif status_response.status != 202: + raise HTTPErrorResponseException(status_response.status, + f"Unexpected response code: {status_response.status}") + except Exception as e: + if not isinstance(e, HTTPErrorResponseException) or e.status_code != 429: + errors += 1 + logging.debug(f"getCohortMembers({cohort_description.id}): request-status error {errors} - {e}") + if errors >= 3: + raise e + time.sleep(self.request_status_delay) + + location = self._get_cohort_async_request_location(request_id) + members = self._get_cohort_async_request_members(cohort_description.id, cohort_description.group_type, + location) + logging.debug(f"getCohortMembers({cohort_description.id}): end - resultSize={len(members)}") + return members + except Exception as e1: + try: + cached_members = self._get_cached_cohort_members(cohort_description.id, cohort_description.group_type) + logging.debug( + f"getCohortMembers({cohort_description.id}): end cached fallback - resultSize={len(cached_members)}") + raise CachedCohortDownloadException(cached_members, e1) + except Exception as e2: + raise e2 + + def _get_cohort_async_request(self, cohort_description: CohortDescription): + conn = self._connection_pool.acquire() + try: + return conn.request('GET', f'api/5/cohorts/request/{cohort_description.id}', + headers={'Authorization': f'Basic {self._get_basic_auth()}'}, + queries={'lastComputed': str(cohort_description.last_computed)}) + finally: + self._connection_pool.release(conn) + + def _get_cohort_async_request_status(self, request_id: str): + conn = self._connection_pool.acquire() + try: + return conn.request('GET', f'api/5/cohorts/request-status/{request_id}', + headers={'Authorization': f'Basic {self._get_basic_auth()}'}) + finally: + self._connection_pool.release(conn) + + def _get_cohort_async_request_location(self, request_id: str): + conn = self._connection_pool.acquire() + try: + response = conn.request('GET', f'api/5/cohorts/request-status/{request_id}/file', + headers={'Authorization': f'Basic {self._get_basic_auth()}'}) + location_header = response.headers.get('location') + if not location_header: + raise ValueError('Cohort response location must not be null') + return location_header + finally: + self._connection_pool.release(conn) + + def _get_cohort_async_request_members(self, cohort_id: str, group_type: str, location: str) -> Set[str]: + headers = { + 'X-Amp-Authorization': f'Basic {self._get_basic_auth()}', + 'X-Cohort-ID': cohort_id, + } + conn = self._connection_pool.acquire() + try: + response = conn.request('GET', location, headers) + return self._parse_csv_response(response.read(), group_type) + finally: + self._connection_pool.release(conn) + + def get_cached_cohort_members(self, cohort_id: str, group_type: str) -> Set[str]: + headers = { + 'X-Amp-Authorization': f'Basic {self._get_basic_auth()}', + 'X-Cohort-ID': cohort_id, + } + conn = self._connection_pool.acquire() + try: + response = conn.request('GET', 'cohorts', headers) + input_stream = response.read() + if not input_stream: + raise ValueError('Cohort response body must not be null') + return self._parse_csv_response(input_stream, group_type) + finally: + self._connection_pool.release(conn) + + @staticmethod + def _parse_csv_response(input_stream: bytes, group_type: str) -> Set[str]: + csv_file = StringIO(input_stream.decode('utf-8')) + csv_data = list(csv.DictReader(csv_file)) + if group_type == USER_GROUP_TYPE: + return {row['user_id'] for row in csv_data if row['user_id']} + else: + values = set() + for row in csv_data: + try: + value = row.get('\tgroup_value', row.get('group_value')) + if value: + values.add(value.lstrip("\t")) + except ValueError: + pass + return values + + def _get_basic_auth(self) -> str: + credentials = f'{self.api_key}:{self.secret_key}' + return base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + + def __setup_connection_pool(self): + scheme, _, host = self.cdn_server_url.split('/', 3) + timeout = 10 + self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, read_timeout=timeout, + scheme=scheme) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py new file mode 100644 index 0000000..10ad7a1 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -0,0 +1,77 @@ +from typing import Dict, Set, Optional +from concurrent.futures import ThreadPoolExecutor, Future +import threading + +from src.amplitude_experiment.cohort.cohort_description import CohortDescription +from src.amplitude_experiment.cohort.cohort_download_api import CohortDownloadApi, DirectCohortDownloadApiV5 +from src.amplitude_experiment.cohort.cohort_storage import CohortStorage + + +class CohortLoader: + def __init__(self, max_cohort_size: int, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, + direct_cohort_download_api: Optional[DirectCohortDownloadApiV5] = None): + self.max_cohort_size = max_cohort_size + self.cohort_download_api = cohort_download_api + self.cohort_storage = cohort_storage + self.direct_cohort_download_api = direct_cohort_download_api + + self.jobs: Dict[str, Future] = {} + self.cached_jobs: Dict[str, Future] = {} + + self.lock_jobs = threading.Lock() + self.lock_cached_jobs = threading.Lock() + + self.executor = ThreadPoolExecutor( + max_workers=32, + thread_name_prefix='CohortLoaderExecutor' + ) + + def load_cohort(self, cohort_id: str) -> Future: + with self.lock_jobs: + if cohort_id not in self.jobs: + def task(): + print(f"Loading cohort {cohort_id}") + cohort_description = self.get_cohort_description(cohort_id) + if self.should_download_cohort(cohort_description): + cohort_members = self.download_cohort(cohort_description) + self.cohort_storage.put_cohort(cohort_description, cohort_members) + + future = self.executor.submit(task) + future.add_done_callback(lambda _: self.jobs.pop(cohort_id, None)) + self.jobs[cohort_id] = future + return self.jobs[cohort_id] + + def load_cached_cohort(self, cohort_id: str) -> Future: + with self.lock_cached_jobs: + if cohort_id not in self.cached_jobs: + def task(): + print(f"Loading cohort from cache {cohort_id}") + cohort_description = self.get_cohort_description(cohort_id) + cohort_description.last_computed = 0 + if self.should_download_cohort(cohort_description): + cohort_members = self.download_cached_cohort(cohort_description) + self.cohort_storage.put_cohort(cohort_description, cohort_members) + + future = self.executor.submit(task) + self.cached_jobs[cohort_id] = future + future.add_done_callback(lambda _: self.cached_jobs.pop(cohort_id, None)) + return future + else: + return self.cached_jobs[cohort_id] + + def get_cohort_description(self, cohort_id: str) -> CohortDescription: + return self.cohort_download_api.get_cohort_description(cohort_id) + + def should_download_cohort(self, cohort_description: CohortDescription) -> bool: + storage_description = self.cohort_storage.get_cohort_description(cohort_description.id) + return (cohort_description.size <= self.max_cohort_size and + cohort_description.last_computed > (storage_description.last_computed if storage_description else -1)) + + def download_cohort(self, cohort_description: CohortDescription) -> Set[str]: + return self.cohort_download_api.get_cohort_members(cohort_description) + + def download_cached_cohort(self, cohort_description: CohortDescription) -> Set[str]: + return (self.direct_cohort_download_api.get_cached_cohort_members(cohort_description.id, + cohort_description.group_type) + if self.direct_cohort_download_api else + self.cohort_download_api.get_cohort_members(cohort_description)) diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py new file mode 100644 index 0000000..479aba8 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -0,0 +1,64 @@ +from typing import Dict, Set, Optional +from threading import RLock + +from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE + + +class CohortStorage: + def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: + raise NotImplementedError() + + def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: + raise NotImplementedError() + + def get_cohort_description(self, cohort_id: str) -> Optional[CohortDescription]: + raise NotImplementedError() + + def get_cohort_descriptions(self) -> Dict[str, CohortDescription]: + raise NotImplementedError() + + def put_cohort(self, cohort_description: CohortDescription, members: Set[str]): + raise NotImplementedError() + + def delete_cohort(self, group_type: str, cohort_id: str): + raise NotImplementedError() + + +class InMemoryCohortStorage(CohortStorage): + def __init__(self): + self.lock = RLock() + self.cohort_store: Dict[str, Dict[str, Set[str]]] = {} + self.description_store: Dict[str, CohortDescription] = {} + + def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: + return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids) + + def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: + result = set() + with self.lock: + group_type_cohorts = self.cohort_store.get(group_type, {}) + for cohort_id, members in group_type_cohorts.items(): + if cohort_id in cohort_ids and group_name in members: + result.add(cohort_id) + return result + + def get_cohort_description(self, cohort_id: str) -> Optional[CohortDescription]: + with self.lock: + return self.description_store.get(cohort_id) + + def get_cohort_descriptions(self) -> Dict[str, CohortDescription]: + with self.lock: + return self.description_store.copy() + + def put_cohort(self, cohort_description: CohortDescription, members: Set[str]): + with self.lock: + self.cohort_store.setdefault(cohort_description.group_type, {})[cohort_description.id] = members + self.description_store[cohort_description.id] = cohort_description + + def delete_cohort(self, group_type: str, cohort_id: str): + with self.lock: + group_cohorts = self.cohort_store.get(group_type, {}) + if cohort_id in group_cohorts: + del group_cohorts[cohort_id] + if cohort_id in self.description_store: + del self.description_store[cohort_id] diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py new file mode 100644 index 0000000..b609bed --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -0,0 +1,6 @@ +class CohortSyncConfig: + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000): + self.api_key = api_key + self.secret_key = secret_key + self.max_cohort_size = max_cohort_size + diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py new file mode 100644 index 0000000..bd85c62 --- /dev/null +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -0,0 +1,94 @@ +import logging +from typing import Optional +import threading +import time + +from src.amplitude_experiment import LocalEvaluationConfig +from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_storage import CohortStorage +from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi +from src.amplitude_experiment.flag.flag_config_storage import FlagConfigStorage +from src.amplitude_experiment.local.poller import Poller +from src.amplitude_experiment.util.flag_config import get_all_cohort_ids + + +class DeploymentRunner: + def __init__( + self, + config: LocalEvaluationConfig, + flag_config_api: FlagConfigApi, + flag_config_storage: FlagConfigStorage, + cohort_storage: CohortStorage, + cohort_loader: Optional[CohortLoader] = None, + ): + self.config = config + self.flag_config_api = flag_config_api + self.flag_config_storage = flag_config_storage + self.cohort_storage = cohort_storage + self.cohort_loader = cohort_loader + self.lock = threading.Lock() + self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_refresh) + self.logger = logging.getLogger("Amplitude") + + def start(self): + with self.lock: + self.refresh(initial=True) + self.poller.start() + + def stop(self): + self.poller.stop() + + def __periodic_refresh(self): + while True: + try: + self.refresh(initial=False) + except Exception as e: + self.logger.error("Refresh flag configs failed.", e) + time.sleep(self.config.flag_config_polling_interval_millis / 1000) + + def refresh(self, initial: bool): + self.logger.debug("Refreshing flag configs.") + flag_configs = self.flag_config_api.get_flag_configs() + + flag_keys = {flag['key'] for flag in flag_configs} + self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) + + if initial: + cached_futures = {} + for flag_config in flag_configs: + cohort_ids = get_all_cohort_ids(flag_config) + if not self.cohort_loader or not cohort_ids: + self.flag_config_storage.put_flag_config(flag_config) + continue + for cohort_id in cohort_ids: + future = self.cohort_loader.load_cached_cohort(cohort_id) + future.add_done_callback(lambda _: self.flag_config_storage.put_flag_config(flag_config)) + cached_futures[cohort_id] = future + try: + for future in cached_futures.values(): + future.result() + except Exception as e: + self.logger.warning("Failed to download a cohort from the cache", e) + + futures = {} + for flag_config in flag_configs: + cohort_ids = get_all_cohort_ids(flag_config) + if not self.cohort_loader or not cohort_ids: + self.flag_config_storage.put_flag_config(flag_config) + continue + for cohort_id in cohort_ids: + future = self.cohort_loader.load_cohort(cohort_id) + future.add_done_callback(lambda _: self.flag_config_storage.put_flag_config(flag_config)) + futures[cohort_id] = future + if initial: + for future in futures.values(): + future.result() + + flag_cohort_ids = {flag['key'] for flag in self.flag_config_storage.get_flag_configs().values()} + deleted_cohort_ids = set(self.cohort_storage.get_cohort_descriptions().keys()) - flag_cohort_ids + for deleted_cohort_id in deleted_cohort_ids: + deleted_cohort_description = self.cohort_storage.get_cohort_description(deleted_cohort_id) + if deleted_cohort_description: + self.cohort_storage.delete_cohort(deleted_cohort_description.group_type, deleted_cohort_id) + + self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 58dd305..defdca3 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -2,3 +2,15 @@ class FetchException(Exception): def __init__(self, status_code, message): super().__init__(message) self.status_code = status_code + + +class CachedCohortDownloadException(Exception): + def __init__(self, cached_members, message): + super().__init__(message) + self.cached_members = cached_members + + +class HTTPErrorResponseException(Exception): + def __init__(self, status_code, message): + super().__init__(message) + self.status_code = status_code diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py new file mode 100644 index 0000000..10b84bb --- /dev/null +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -0,0 +1,46 @@ +import json +from typing import List + +from ..version import __version__ + +from src.amplitude_experiment.connection_pool import HTTPConnectionPool + + +class FlagConfigApi: + def get_flag_configs(self) -> List: + pass + + +class FlagConfigApiV2(FlagConfigApi): + def __init__(self, deployment_key: str, server_url: str, flag_config_poller_request_timeout_millis: int): + self.deployment_key = deployment_key + self.server_url = server_url + self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis + + def get_flag_configs(self) -> List: + return self._get_flag_configs() + + def _get_flag_configs(self) -> List: + conn = self._connection_pool.acquire() + headers = { + 'Authorization': f"Api-Key {self.deployment_key}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" + } + body = None + try: + response = conn.request('GET', '/sdk/v2/flags?v=0', body, headers) + response_body = response.read().decode("utf8") + if response.status != 200: + raise Exception( + f"[Experiment] Get flagConfigs - received error response: ${response.status}: ${response_body}") + flags = json.loads(response_body) + return flags + finally: + self._connection_pool.release(conn) + + def __setup_connection_pool(self): + scheme, _, host = self.server_url.split('/', 3) + timeout = self.flag_config_poller_request_timeout_millis / 1000 + self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, + read_timeout=timeout, scheme=scheme) diff --git a/src/amplitude_experiment/flag/flag_config_storage.py b/src/amplitude_experiment/flag/flag_config_storage.py new file mode 100644 index 0000000..c949f7a --- /dev/null +++ b/src/amplitude_experiment/flag/flag_config_storage.py @@ -0,0 +1,31 @@ +from typing import Dict, Callable +from threading import Lock + + +class FlagConfigStorage: + def get_flag_configs(self) -> Dict: + pass + + def put_flag_config(self, flag_config: Dict): + pass + + def remove_if(self, condition: Callable[[Dict], bool]): + pass + + +class InMemoryFlagConfigStorage(FlagConfigStorage): + def __init__(self): + self.flag_configs = {} + self.flag_configs_lock = Lock() + + def get_flag_configs(self) -> Dict[str, Dict]: + with self.flag_configs_lock: + return self.flag_configs.copy() + + def put_flag_config(self, flag_config: Dict): + with self.flag_configs_lock: + self.flag_configs[flag_config['key']] = flag_config + + def remove_if(self, condition: Callable[[Dict], bool]): + with self.flag_configs_lock: + self.flag_configs = {key: value for key, value in self.flag_configs.items() if not condition(value)} diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index 6714815..1183a9c 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -1,4 +1,5 @@ from ..assignment import AssignmentConfig +from ..cohort.cohort_sync_config import CohortSyncConfig class LocalEvaluationConfig: @@ -11,7 +12,8 @@ def __init__(self, debug: bool = False, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, assignment_config: AssignmentConfig = None, - deployment_key: str = None): + deployment_key: str = None, + cohort_sync_config: CohortSyncConfig = None): """ Initialize a config Parameters: @@ -34,3 +36,4 @@ def __init__(self, debug: bool = False, self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis self.assignment_config = assignment_config self.deployment_key = deployment_key + self.cohort_sync_config = cohort_sync_config diff --git a/src/amplitude_experiment/util/flag_config.py b/src/amplitude_experiment/util/flag_config.py new file mode 100644 index 0000000..b18791f --- /dev/null +++ b/src/amplitude_experiment/util/flag_config.py @@ -0,0 +1,54 @@ +from typing import List, Dict, Set, Any + +from src.amplitude_experiment.cohort.cohort_description import USER_GROUP_TYPE + + +def is_cohort_filter(condition: Dict[str, Any]) -> bool: + return ( + condition['op'] in {"set contains any", "set does not contain any"} + and condition['selector'] + and condition['selector'][-1] == "cohort_ids" + ) + + +def get_grouped_cohort_condition_ids(segment: Dict[str, Any]) -> Dict[str, Set[str]]: + cohort_ids = {} + conditions = segment.get('conditions', []) + for outer in conditions: + for condition in outer: + if is_cohort_filter(condition): + if len(condition['selector']) > 2: + context_subtype = condition['selector'][1] + if context_subtype == "user": + group_type = USER_GROUP_TYPE + elif "groups" in condition['selector']: + group_type = condition['selector'][2] + else: + continue + cohort_ids.setdefault(group_type, set()).update(condition['values']) + return cohort_ids + + +def get_grouped_cohort_ids(flag: Dict[str, Any]) -> Dict[str, Set[str]]: + cohort_ids = {} + segments = flag.get('segments', []) + for segment in segments: + for key, values in get_grouped_cohort_condition_ids(segment).items(): + cohort_ids.setdefault(key, set()).update(values) + return cohort_ids + + +def get_all_cohort_ids(flag: Dict[str, Any]) -> Set[str]: + return {cohort_id for values in get_grouped_cohort_ids(flag).values() for cohort_id in values} + + +def get_grouped_cohort_ids_from_flags(flags: List[Dict[str, Any]]) -> Dict[str, Set[str]]: + cohort_ids = {} + for flag in flags: + for key, values in get_grouped_cohort_ids(flag).items(): + cohort_ids.setdefault(key, set()).update(values) + return cohort_ids + + +def get_all_cohort_ids_from_flags(flags: List[Dict[str, Any]]) -> Set[str]: + return {cohort_id for values in get_grouped_cohort_ids_from_flags(flags).values() for cohort_id in values} diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py new file mode 100644 index 0000000..9ef0ed2 --- /dev/null +++ b/tests/cohort/cohort_download_api_test.py @@ -0,0 +1,180 @@ +import json +import unittest +from unittest.mock import MagicMock +from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE +from src.amplitude_experiment.exception import CachedCohortDownloadException +from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApiV5 +from urllib.parse import urlparse + + +def response(code: int): + mock_response = MagicMock() + mock_response.status = code + mock_response.headers = {'location': 'https://example.com/cohorts/Cohort_asdf?asdf=asdf#asdf'} + return mock_response + + +class CohortDownloadApiTest(unittest.TestCase): + location = 'https://example.com/cohorts/Cohort_asdf?asdf=asdf#asdf' + + def test_cohort_download_success(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'user'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + api._get_cohort_async_request_status.assert_called_once_with('4321') + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + + def test_cohort_download_many_202s_success(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_202_response = response(202) + async_request_status_200_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock( + side_effect=[async_request_status_202_response] * 9 + [async_request_status_200_response]) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'user'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + self.assertEqual(api._get_cohort_async_request_status.call_count, 10) + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + + def test_cohort_request_status_with_two_failures_succeeds(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_503_response = response(503) + async_request_status_200_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock( + side_effect=[async_request_status_503_response, async_request_status_503_response, + async_request_status_200_response]) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'user'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + self.assertEqual(api._get_cohort_async_request_status.call_count, 3) + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + + def test_cohort_request_status_throws_after_3_failures_cache_fallback_succeeds(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_response = response(503) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'user'}) + api._get_cached_cohort_members = MagicMock(return_value={'user2'}) + + with self.assertRaises(CachedCohortDownloadException) as e: + api.get_cohort_members(cohort) + + self.assertEqual({'user2'}, e.exception.cached_members) + api._get_cohort_async_request.assert_called_once_with(cohort) + self.assertEqual(api._get_cohort_async_request_status.call_count, 3) + api._get_cohort_async_request_location.assert_not_called() + api._get_cohort_async_request_members.assert_not_called() + api._get_cached_cohort_members.assert_called_once_with('1234', USER_GROUP_TYPE) + + def test_cohort_request_status_429s_keep_retrying(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_429_response = response(429) + async_request_status_200_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock( + side_effect=[async_request_status_429_response] * 9 + [async_request_status_200_response]) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'user'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + self.assertEqual(api._get_cohort_async_request_status.call_count, 10) + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + + def test_cohort_async_request_download_failure_falls_back_on_cached_request(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(side_effect=Exception('fail')) + api._get_cached_cohort_members = MagicMock(return_value={'user'}) + + with self.assertRaises(CachedCohortDownloadException) as e: + api.get_cohort_members(cohort) + + self.assertEqual({'user'}, e.exception.cached_members) + api._get_cached_cohort_members.assert_called_once_with('1234', USER_GROUP_TYPE) + + def test_group_cohort_download_success(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'group'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'group'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + api._get_cohort_async_request_status.assert_called_once_with('4321') + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', 'org name', urlparse(self.location)) + + def test_group_cohort_request_status_429s_keep_retrying(self): + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") + async_request_response = MagicMock() + async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() + + async_request_status_429_response = response(429) + async_request_status_200_response = response(200) + api = DirectCohortDownloadApiV5('api', 'secret') + api._get_cohort_async_request = MagicMock(return_value=async_request_response) + api._get_cohort_async_request_status = MagicMock( + side_effect=[async_request_status_429_response] * 9 + [async_request_status_200_response]) + api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) + api._get_cohort_async_request_members = MagicMock(return_value={'group'}) + + members = api.get_cohort_members(cohort) + self.assertEqual({'group'}, members) + api._get_cohort_async_request.assert_called_once_with(cohort) + self.assertEqual(api._get_cohort_async_request_status.call_count, 10) + api._get_cohort_async_request_location.assert_called_once_with('4321') + api._get_cohort_async_request_members.assert_called_once_with('1234', 'org name', urlparse(self.location)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py new file mode 100644 index 0000000..7d0212c --- /dev/null +++ b/tests/cohort/cohort_loader_test.py @@ -0,0 +1,94 @@ +import unittest +from unittest.mock import MagicMock + +from src.amplitude_experiment.cohort.cohort_description import CohortDescription +from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_storage import InMemoryCohortStorage + + +class CohortLoaderTest(unittest.TestCase): + def setUp(self): + self.config = MagicMock() + self.api = MagicMock() + self.storage = InMemoryCohortStorage() + self.loader = CohortLoader(15000, self.api, self.storage) + + def test_load_success(self): + self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b")] + self.api.get_cohort_members.side_effect = [{"1"}, {"1", "2"}] + + # Submitting tasks asynchronously + future_a = self.loader.load_cohort("a") + future_b = self.loader.load_cohort("b") + + # Asserting after tasks complete + future_a.result() + future_b.result() + + storage_description_a = self.storage.get_cohort_description("a") + storage_description_b = self.storage.get_cohort_description("b") + self.assertEqual(cohort_description("a"), storage_description_a) + self.assertEqual(cohort_description("b"), storage_description_b) + + storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) + storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) + self.assertEqual({"a", "b"}, storage_user1_cohorts) + self.assertEqual({"b"}, storage_user2_cohorts) + + def test_load_cohorts_greater_than_max_cohort_size_are_filtered(self): + self.api.get_cohort_description.side_effect = [cohort_description("a", size=float("inf")), + cohort_description("b", size=1)] + self.api.get_cohort_members.side_effect = [{"1", "2"}] + + self.loader.load_cohort("a").result() + self.loader.load_cohort("b").result() + + storage_description_a = self.storage.get_cohort_description("a") + storage_description_b = self.storage.get_cohort_description("b") + self.assertIsNone(storage_description_a) + self.assertEqual(cohort_description("b", size=1), storage_description_b) + + storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) + storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) + self.assertEqual({"b"}, storage_user1_cohorts) + self.assertEqual({"b"}, storage_user2_cohorts) + + def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): + self.storage.put_cohort(cohort_description("a", last_computed=0), set()) + self.storage.put_cohort(cohort_description("b", last_computed=0), set()) + self.api.get_cohort_description.side_effect = [cohort_description("a", last_computed=0), + cohort_description("b", last_computed=1)] + self.api.get_cohort_members.side_effect = [{"1", "2"}] + + self.loader.load_cohort("a").result() + self.loader.load_cohort("b").result() + + storage_description_a = self.storage.get_cohort_description("a") + storage_description_b = self.storage.get_cohort_description("b") + self.assertEqual(cohort_description("a", last_computed=0), storage_description_a) + self.assertEqual(cohort_description("b", last_computed=1), storage_description_b) + + storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) + storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) + self.assertEqual({"b"}, storage_user1_cohorts) + self.assertEqual({"b"}, storage_user2_cohorts) + + def test_load_download_failure_throws(self): + self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b"), + cohort_description("c")] + self.api.get_cohort_members.side_effect = [{"1"}, Exception("Connection timed out"), {"1"}] + + self.loader.load_cohort("a").result() + with self.assertRaises(Exception): + self.loader.load_cohort("b").result() + self.loader.load_cohort("c").result() + + self.assertEqual({"a", "c"}, self.storage.get_cohorts_for_user("1", {"a", "b", "c"})) + + +def cohort_description(cohort_id, last_computed=0, size=0): + return CohortDescription(id=cohort_id, last_computed=last_computed, size=size) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py new file mode 100644 index 0000000..35c9e7d --- /dev/null +++ b/tests/deployment/deployment_runner_test.py @@ -0,0 +1,70 @@ +import unittest +from unittest import mock + +from src.amplitude_experiment import LocalEvaluationConfig +from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi +from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner + +COHORT_ID = '1234' + + +class DeploymentRunnerTest(unittest.TestCase): + + def setUp(self): + self.flag = { + "key": "flag", + "variants": {}, + "segments": [ + { + "conditions": [ + [ + { + "selector": ["context", "user", "cohort_ids"], + "op": "set contains any", + "values": [COHORT_ID], + } + ] + ], + } + ] + } + + def test_start_throws_if_first_flag_config_load_fails(self): + flag_api = mock.create_autospec(FlagConfigApi) + cohort_download_api = mock.Mock() + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_loader = CohortLoader(100, cohort_download_api, cohort_storage) + runner = DeploymentRunner( + LocalEvaluationConfig(), + flag_api, + flag_config_storage, + cohort_storage, + cohort_loader + ) + flag_api.get_flag_configs.side_effect = RuntimeError("test") + with self.assertRaises(RuntimeError): + runner.start() + + def test_start_throws_if_first_cohort_load_fails(self): + flag_api = mock.create_autospec(FlagConfigApi) + cohort_download_api = mock.Mock() + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_loader = CohortLoader(100, cohort_download_api, cohort_storage) + runner = DeploymentRunner( + LocalEvaluationConfig(), + flag_api, flag_config_storage, + cohort_storage, + cohort_loader + ) + flag_api.get_flag_configs.return_value = [self.flag] + cohort_download_api.get_cohort_description.side_effect = RuntimeError("test") + + with self.assertRaises(RuntimeError): + runner.start() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py new file mode 100644 index 0000000..63d6cf8 --- /dev/null +++ b/tests/util/flag_config_test.py @@ -0,0 +1,125 @@ +import unittest +from typing import List, Dict, Set, Any + +# Assuming the following utility functions are defined in a module named cohort_utils.py +from src.amplitude_experiment.util.flag_config import ( + get_all_cohort_ids_from_flags, + get_grouped_cohort_ids_from_flags, + get_all_cohort_ids, + get_grouped_cohort_ids, +) + + +class CohortUtilsTestCase(unittest.TestCase): + + def setUp(self): + self.flags = [ + { + 'key': 'flag-1', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 1 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'user', 'cohort_ids'], + 'values': ['cohort1', 'cohort2'] + } + ] + ], + 'metadata': {'segmentName': 'Segment A'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Users'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } + }, + { + 'key': 'flag-2', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 2 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'user', 'cohort_ids'], + 'values': ['cohort3', 'cohort4', 'cohort5', 'cohort6'] + } + ] + ], + 'metadata': {'segmentName': 'Segment B'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Users'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } + } + ] + + def test_get_all_cohort_ids(self): + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + for flag in self.flags: + cohort_ids = get_all_cohort_ids(flag) + self.assertTrue(cohort_ids.issubset(expected_cohort_ids)) + + def test_get_grouped_cohort_ids(self): + expected_grouped_cohort_ids = { + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + } + for flag in self.flags: + grouped_cohort_ids = get_grouped_cohort_ids(flag) + for key, values in grouped_cohort_ids.items(): + self.assertTrue(key in expected_grouped_cohort_ids) + self.assertTrue(values.issubset(expected_grouped_cohort_ids[key])) + + def test_get_all_cohort_ids_from_flags(self): + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + cohort_ids = get_all_cohort_ids_from_flags(self.flags) + self.assertEqual(cohort_ids, expected_cohort_ids) + + def test_get_grouped_cohort_ids_from_flags(self): + expected_grouped_cohort_ids = { + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + } + grouped_cohort_ids = get_grouped_cohort_ids_from_flags(self.flags) + self.assertEqual(grouped_cohort_ids, expected_grouped_cohort_ids) + + +if __name__ == '__main__': + unittest.main() From eac1cd3554d147779c4dc2f932bb86f5bd120e86 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 6 Jun 2024 16:20:34 -0700 Subject: [PATCH 03/44] update local eval client --- .../deployment/deployment_runner.py | 6 +-- src/amplitude_experiment/local/client.py | 50 +++++++++++++++++-- src/amplitude_experiment/user.py | 20 +++++++- src/amplitude_experiment/util/flag_config.py | 8 +-- src/amplitude_experiment/util/user.py | 6 +-- tests/util/flag_config_test.py | 13 +++-- 6 files changed, 81 insertions(+), 22 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index bd85c62..44777d6 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -9,7 +9,7 @@ from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi from src.amplitude_experiment.flag.flag_config_storage import FlagConfigStorage from src.amplitude_experiment.local.poller import Poller -from src.amplitude_experiment.util.flag_config import get_all_cohort_ids +from src.amplitude_experiment.util.flag_config import get_all_cohort_ids_from_flag class DeploymentRunner: @@ -56,7 +56,7 @@ def refresh(self, initial: bool): if initial: cached_futures = {} for flag_config in flag_configs: - cohort_ids = get_all_cohort_ids(flag_config) + cohort_ids = get_all_cohort_ids_from_flag(flag_config) if not self.cohort_loader or not cohort_ids: self.flag_config_storage.put_flag_config(flag_config) continue @@ -72,7 +72,7 @@ def refresh(self, initial: bool): futures = {} for flag_config in flag_configs: - cohort_ids = get_all_cohort_ids(flag_config) + cohort_ids = get_all_cohort_ids_from_flag(flag_config) if not self.cohort_loader or not cohort_ids: self.flag_config_storage.put_flag_config(flag_config) continue diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 4db9e5b..c9ba2ae 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -8,11 +8,19 @@ from .config import LocalEvaluationConfig from .topological_sort import topological_sort from ..assignment import Assignment, AssignmentFilter, AssignmentService +from ..cohort.cohort_description import USER_GROUP_TYPE +from ..cohort.cohort_download_api import DirectCohortDownloadApiV5 +from ..cohort.cohort_loader import CohortLoader +from ..cohort.cohort_storage import InMemoryCohortStorage +from ..deployment.deployment_runner import DeploymentRunner +from ..flag.flag_config_api import FlagConfigApiV2 +from ..flag.flag_config_storage import InMemoryFlagConfigStorage from ..user import User from ..connection_pool import HTTPConnectionPool from .poller import Poller from .evaluation.evaluation import evaluate from ..util import deprecated +from ..util.flag_config import get_grouped_cohort_ids_from_flags from ..util.user import user_to_evaluation_context from ..util.variant import evaluation_variants_json_to_variants from ..variant import Variant @@ -50,6 +58,17 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): self.flags = None self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__do_flags) self.lock = Lock() + self.cohort_storage = InMemoryCohortStorage() + self.flag_config_storage = InMemoryFlagConfigStorage() + if config and config.cohort_sync_config: + direct_cohort_download_api = DirectCohortDownloadApiV5(config.cohort_sync_config.api_key, + config.cohort_sync_config.secret_key) + cohort_loader = CohortLoader(config.cohort_sync_config.max_cohort_size, direct_cohort_download_api, + self.cohort_storage, direct_cohort_download_api) + flag_config_api = FlagConfigApiV2(api_key, config.server_url, + config.flag_config_poller_request_timeout_millis) + self.deployment_runner = DeploymentRunner(config, flag_config_api, self.flag_config_storage, + self.cohort_storage, cohort_loader) def start(self): """ @@ -77,8 +96,12 @@ def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Varia if self.flags is None or len(self.flags) == 0: return {} self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {self.flags}") - context = user_to_evaluation_context(user) + flag_configs = self.flag_config_storage.get_flag_configs() sorted_flags = topological_sort(self.flags, flag_keys) + if not sorted_flags: + return {} + enriched_user = self.enrich_user(user, flag_configs) + context = user_to_evaluation_context(enriched_user) flags_json = json.dumps(sorted_flags) context_json = json.dumps(context) result_json = evaluate(flags_json, context_json) @@ -167,5 +190,26 @@ def is_default_variant(variant: Variant) -> bool: return {key: variant for key, variant in variants.items() if not is_default_variant(variant)} - - + def enrich_user(self, user: User, flag_configs: Dict) -> User: + grouped_cohort_ids = get_grouped_cohort_ids_from_flags(list(flag_configs.values())) + + if USER_GROUP_TYPE in grouped_cohort_ids: + user_cohort_ids = grouped_cohort_ids[USER_GROUP_TYPE] + if user_cohort_ids and user.user_id: + user.cohort_ids = self.cohort_storage.get_cohorts_for_user(user.user_id, user_cohort_ids) + + if user.groups: + for group_type, group_names in user.groups.items(): + group_name = group_names[0] if group_names else None + if not group_name: + continue + cohort_ids = grouped_cohort_ids.get(group_type, []) + if not cohort_ids: + continue + user.add_group_cohort_ids( + group_type, + group_name, + self.cohort_storage.get_cohorts_for_group(group_type, group_name, cohort_ids) + ) + + return user diff --git a/src/amplitude_experiment/user.py b/src/amplitude_experiment/user.py index eed02c7..6dd6561 100644 --- a/src/amplitude_experiment/user.py +++ b/src/amplitude_experiment/user.py @@ -1,6 +1,6 @@ import json -from typing import Dict, Any +from typing import Dict, Any, Set class User: @@ -28,7 +28,8 @@ def __init__( library: str = None, user_properties: Dict[str, Any] = None, groups: Dict[str, str] = None, - group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None + group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None, + group_cohort_ids: Dict[str, Dict[str, Set[str]]] = None ): """ Initialize User instance @@ -73,6 +74,7 @@ def __init__( self.user_properties = user_properties self.groups = groups self.group_properties = group_properties + self.group_cohort_ids = group_cohort_ids def to_json(self): """Return user information as JSON string.""" @@ -81,3 +83,17 @@ def to_json(self): def __str__(self): """Return user as string""" return self.to_json() + + def add_group_cohort_ids(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> None: + """ + Add cohort IDs for a group. + Parameters: + group_type (str): The type of the group + group_name (str): The name of the group + cohort_ids (Set[str]): Set of cohort IDs associated with the group + """ + if self.group_cohort_ids is None: + self.group_cohort_ids = {} + + group_names = self.group_cohort_ids.setdefault(group_type, {}) + group_names[group_name] = cohort_ids diff --git a/src/amplitude_experiment/util/flag_config.py b/src/amplitude_experiment/util/flag_config.py index b18791f..9215879 100644 --- a/src/amplitude_experiment/util/flag_config.py +++ b/src/amplitude_experiment/util/flag_config.py @@ -29,7 +29,7 @@ def get_grouped_cohort_condition_ids(segment: Dict[str, Any]) -> Dict[str, Set[s return cohort_ids -def get_grouped_cohort_ids(flag: Dict[str, Any]) -> Dict[str, Set[str]]: +def get_grouped_cohort_ids_from_flag(flag: Dict[str, Any]) -> Dict[str, Set[str]]: cohort_ids = {} segments = flag.get('segments', []) for segment in segments: @@ -38,14 +38,14 @@ def get_grouped_cohort_ids(flag: Dict[str, Any]) -> Dict[str, Set[str]]: return cohort_ids -def get_all_cohort_ids(flag: Dict[str, Any]) -> Set[str]: - return {cohort_id for values in get_grouped_cohort_ids(flag).values() for cohort_id in values} +def get_all_cohort_ids_from_flag(flag: Dict[str, Any]) -> Set[str]: + return {cohort_id for values in get_grouped_cohort_ids_from_flag(flag).values() for cohort_id in values} def get_grouped_cohort_ids_from_flags(flags: List[Dict[str, Any]]) -> Dict[str, Set[str]]: cohort_ids = {} for flag in flags: - for key, values in get_grouped_cohort_ids(flag).items(): + for key, values in get_grouped_cohort_ids_from_flag(flag).items(): cohort_ids.setdefault(key, set()).update(values) return cohort_ids diff --git a/src/amplitude_experiment/util/user.py b/src/amplitude_experiment/util/user.py index 93e3fdd..01aa779 100644 --- a/src/amplitude_experiment/util/user.py +++ b/src/amplitude_experiment/util/user.py @@ -15,16 +15,16 @@ def user_to_evaluation_context(user: User) -> Dict[str, Any]: groups: Dict[str, Dict[str, Any]] = {} for group_type in user_groups: group_name = user_groups[group_type] - if type(group_name) == list and len(group_name) > 0: + if isinstance(group_name, list) and len(group_name) > 0: group_name = group_name[0] groups[group_type] = {'group_name': group_name} if user_group_properties is None: continue group_properties_type = user_group_properties[group_type] - if group_properties_type is None or type(group_properties_type) != dict: + if group_properties_type is None or isinstance(group_properties_type, dict): continue group_properties_name = group_properties_type[group_name] - if group_properties_name is None or type(group_properties_name) != dict: + if group_properties_name is None or isinstance(group_properties_name, dict): continue groups[group_type]['group_properties'] = group_properties_name context['groups'] = groups diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py index 63d6cf8..94da095 100644 --- a/tests/util/flag_config_test.py +++ b/tests/util/flag_config_test.py @@ -1,12 +1,11 @@ import unittest -from typing import List, Dict, Set, Any # Assuming the following utility functions are defined in a module named cohort_utils.py from src.amplitude_experiment.util.flag_config import ( get_all_cohort_ids_from_flags, get_grouped_cohort_ids_from_flags, - get_all_cohort_ids, - get_grouped_cohort_ids, + get_all_cohort_ids_from_flag, + get_grouped_cohort_ids_for_flag, ) @@ -95,15 +94,15 @@ def setUp(self): def test_get_all_cohort_ids(self): expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} for flag in self.flags: - cohort_ids = get_all_cohort_ids(flag) + cohort_ids = get_all_cohort_ids_from_flag(flag) self.assertTrue(cohort_ids.issubset(expected_cohort_ids)) - def test_get_grouped_cohort_ids(self): + def test_get_grouped_cohort_ids_for_flag(self): expected_grouped_cohort_ids = { 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} } for flag in self.flags: - grouped_cohort_ids = get_grouped_cohort_ids(flag) + grouped_cohort_ids = get_grouped_cohort_ids_for_flag(flag) for key, values in grouped_cohort_ids.items(): self.assertTrue(key in expected_grouped_cohort_ids) self.assertTrue(values.issubset(expected_grouped_cohort_ids[key])) @@ -113,7 +112,7 @@ def test_get_all_cohort_ids_from_flags(self): cohort_ids = get_all_cohort_ids_from_flags(self.flags) self.assertEqual(cohort_ids, expected_cohort_ids) - def test_get_grouped_cohort_ids_from_flags(self): + def test_get_grouped_cohort_ids_for_flag_from_flags(self): expected_grouped_cohort_ids = { 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} } From 7f864d554e6febe1ee179236f0eabb2f2eed4890 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 6 Jun 2024 16:36:54 -0700 Subject: [PATCH 04/44] fix imports --- .../cohort/cohort_download_api.py | 6 +++--- src/amplitude_experiment/cohort/cohort_loader.py | 6 +++--- src/amplitude_experiment/cohort/cohort_storage.py | 2 +- .../deployment/deployment_runner.py | 14 +++++++------- src/amplitude_experiment/flag/flag_config_api.py | 2 +- src/amplitude_experiment/util/flag_config.py | 2 +- tests/util/flag_config_test.py | 4 ++-- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index ff0894a..ee59fcb 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -6,9 +6,9 @@ from io import StringIO from typing import Set -from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE -from src.amplitude_experiment.connection_pool import HTTPConnectionPool -from src.amplitude_experiment.exception import CachedCohortDownloadException, HTTPErrorResponseException +from .cohort_description import CohortDescription, USER_GROUP_TYPE +from ..connection_pool import HTTPConnectionPool +from ..exception import CachedCohortDownloadException, HTTPErrorResponseException CDN_COHORT_SYNC_URL = 'https://cohort.lab.amplitude.com' diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 10ad7a1..bf68046 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -2,9 +2,9 @@ from concurrent.futures import ThreadPoolExecutor, Future import threading -from src.amplitude_experiment.cohort.cohort_description import CohortDescription -from src.amplitude_experiment.cohort.cohort_download_api import CohortDownloadApi, DirectCohortDownloadApiV5 -from src.amplitude_experiment.cohort.cohort_storage import CohortStorage +from .cohort_description import CohortDescription +from .cohort_download_api import CohortDownloadApi, DirectCohortDownloadApiV5 +from .cohort_storage import CohortStorage class CohortLoader: diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py index 479aba8..c8a3430 100644 --- a/src/amplitude_experiment/cohort/cohort_storage.py +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -1,7 +1,7 @@ from typing import Dict, Set, Optional from threading import RLock -from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE +from .cohort_description import CohortDescription, USER_GROUP_TYPE class CohortStorage: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 44777d6..f8c6179 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -3,13 +3,13 @@ import threading import time -from src.amplitude_experiment import LocalEvaluationConfig -from src.amplitude_experiment.cohort.cohort_loader import CohortLoader -from src.amplitude_experiment.cohort.cohort_storage import CohortStorage -from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi -from src.amplitude_experiment.flag.flag_config_storage import FlagConfigStorage -from src.amplitude_experiment.local.poller import Poller -from src.amplitude_experiment.util.flag_config import get_all_cohort_ids_from_flag +from ..local.config import LocalEvaluationConfig +from ..cohort.cohort_loader import CohortLoader +from ..cohort.cohort_storage import CohortStorage +from ..flag.flag_config_api import FlagConfigApi +from ..flag.flag_config_storage import FlagConfigStorage +from ..local.poller import Poller +from ..util.flag_config import get_all_cohort_ids_from_flag class DeploymentRunner: diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index 10b84bb..2c3c99d 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -3,7 +3,7 @@ from ..version import __version__ -from src.amplitude_experiment.connection_pool import HTTPConnectionPool +from ..connection_pool import HTTPConnectionPool class FlagConfigApi: diff --git a/src/amplitude_experiment/util/flag_config.py b/src/amplitude_experiment/util/flag_config.py index 9215879..1697856 100644 --- a/src/amplitude_experiment/util/flag_config.py +++ b/src/amplitude_experiment/util/flag_config.py @@ -1,6 +1,6 @@ from typing import List, Dict, Set, Any -from src.amplitude_experiment.cohort.cohort_description import USER_GROUP_TYPE +from ..cohort.cohort_description import USER_GROUP_TYPE def is_cohort_filter(condition: Dict[str, Any]) -> bool: diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py index 94da095..d4b2772 100644 --- a/tests/util/flag_config_test.py +++ b/tests/util/flag_config_test.py @@ -5,7 +5,7 @@ get_all_cohort_ids_from_flags, get_grouped_cohort_ids_from_flags, get_all_cohort_ids_from_flag, - get_grouped_cohort_ids_for_flag, + get_grouped_cohort_ids_from_flag, ) @@ -102,7 +102,7 @@ def test_get_grouped_cohort_ids_for_flag(self): 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} } for flag in self.flags: - grouped_cohort_ids = get_grouped_cohort_ids_for_flag(flag) + grouped_cohort_ids = get_grouped_cohort_ids_from_flag(flag) for key, values in grouped_cohort_ids.items(): self.assertTrue(key in expected_grouped_cohort_ids) self.assertTrue(values.issubset(expected_grouped_cohort_ids[key])) From 3ef92da236beecb83219d122e38cc24e5fceeb60 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 10 Jun 2024 15:44:06 -0700 Subject: [PATCH 05/44] refactor --- .../cohort/cohort_download_api.py | 167 ++++--------- .../cohort/cohort_loader.py | 58 ++--- .../cohort/cohort_sync_config.py | 1 - .../deployment/deployment_runner.py | 78 +++--- src/amplitude_experiment/exception.py | 2 +- .../flag/flag_config_api.py | 1 + src/amplitude_experiment/local/client.py | 54 ++-- src/amplitude_experiment/user.py | 8 +- tests/cohort/cohort_download_api_test.py | 232 ++++++------------ tests/cohort/cohort_loader_test.py | 34 +-- tests/deployment/deployment_runner_test.py | 7 +- tests/util/flag_config_test.py | 1 - 12 files changed, 225 insertions(+), 418 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index ee59fcb..cd15c08 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -2,15 +2,14 @@ import logging import base64 import json -import csv -from io import StringIO +from http.client import HTTPResponse from typing import Set -from .cohort_description import CohortDescription, USER_GROUP_TYPE +from .cohort_description import CohortDescription from ..connection_pool import HTTPConnectionPool -from ..exception import CachedCohortDownloadException, HTTPErrorResponseException +from ..exception import HTTPErrorResponseException, CohortTooLargeException -CDN_COHORT_SYNC_URL = 'https://cohort.lab.amplitude.com' +CDN_COHORT_SYNC_URL = 'https://api.lab.amplitude.com' class CohortDownloadApi: @@ -24,144 +23,76 @@ def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: raise NotImplementedError -class DirectCohortDownloadApiV5(CohortDownloadApi): - def __init__(self, api_key: str, secret_key: str): +class DirectCohortDownloadApi(CohortDownloadApi): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, request_status_delay: int = 5, debug: bool = False): super().__init__() self.api_key = api_key self.secret_key = secret_key + self.max_cohort_size = max_cohort_size self.__setup_connection_pool() - self.request_status_delay = 2 # seconds, adjust as necessary + self.request_status_delay = request_status_delay + self.logger = logging.getLogger("Amplitude") + self.logger.addHandler(logging.StreamHandler()) + if debug: + self.logger.setLevel(logging.DEBUG) def get_cohort_description(self, cohort_id: str) -> CohortDescription: response = self.get_cohort_info(cohort_id) - cohort_info = json.loads(response.read().decode("utf8")) + cohort_info = json.loads(response.read().decode("utf-8")) return CohortDescription( - id=cohort_info['cohort_id'], - last_computed=cohort_info['last_computed'], + id=cohort_info['cohortId'], + last_computed=cohort_info['lastComputed'], size=cohort_info['size'], - group_type=cohort_info['group_type'], + group_type=cohort_info['groupType'], ) - def get_cohort_info(self, cohort_id: str): + def get_cohort_info(self, cohort_id: str) -> HTTPResponse: conn = self._connection_pool.acquire() try: - return conn.request('GET', f'api/3/cohorts/info/{cohort_id}', + return conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?skipCohortDownload=true', headers={'Authorization': f'Basic {self._get_basic_auth()}'}) finally: self._connection_pool.release(conn) def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: - try: - logging.debug(f"getCohortMembers({cohort_description.id}): start - {cohort_description}") - initial_response = self._get_cohort_async_request(cohort_description) - request_id = json.loads(initial_response.read().decode('utf-8'))['request_id'] - logging.debug(f"getCohortMembers({cohort_description.id}): requestId={request_id}") - - errors = 0 - while True: - try: - status_response = self._get_cohort_async_request_status(request_id) - logging.debug(f"getCohortMembers({cohort_description.id}): status={status_response.status}") - if status_response.status == 200: - break - elif status_response.status != 202: - raise HTTPErrorResponseException(status_response.status, - f"Unexpected response code: {status_response.status}") - except Exception as e: - if not isinstance(e, HTTPErrorResponseException) or e.status_code != 429: - errors += 1 - logging.debug(f"getCohortMembers({cohort_description.id}): request-status error {errors} - {e}") - if errors >= 3: - raise e - time.sleep(self.request_status_delay) - - location = self._get_cohort_async_request_location(request_id) - members = self._get_cohort_async_request_members(cohort_description.id, cohort_description.group_type, - location) - logging.debug(f"getCohortMembers({cohort_description.id}): end - resultSize={len(members)}") - return members - except Exception as e1: + self.logger.debug(f"getCohortMembers({cohort_description.id}): start - {cohort_description}") + errors = 0 + while True: + response = None try: - cached_members = self._get_cached_cohort_members(cohort_description.id, cohort_description.group_type) - logging.debug( - f"getCohortMembers({cohort_description.id}): end cached fallback - resultSize={len(cached_members)}") - raise CachedCohortDownloadException(cached_members, e1) - except Exception as e2: - raise e2 - - def _get_cohort_async_request(self, cohort_description: CohortDescription): - conn = self._connection_pool.acquire() - try: - return conn.request('GET', f'api/5/cohorts/request/{cohort_description.id}', - headers={'Authorization': f'Basic {self._get_basic_auth()}'}, - queries={'lastComputed': str(cohort_description.last_computed)}) - finally: - self._connection_pool.release(conn) - - def _get_cohort_async_request_status(self, request_id: str): - conn = self._connection_pool.acquire() - try: - return conn.request('GET', f'api/5/cohorts/request-status/{request_id}', - headers={'Authorization': f'Basic {self._get_basic_auth()}'}) - finally: - self._connection_pool.release(conn) - - def _get_cohort_async_request_location(self, request_id: str): - conn = self._connection_pool.acquire() - try: - response = conn.request('GET', f'api/5/cohorts/request-status/{request_id}/file', - headers={'Authorization': f'Basic {self._get_basic_auth()}'}) - location_header = response.headers.get('location') - if not location_header: - raise ValueError('Cohort response location must not be null') - return location_header - finally: - self._connection_pool.release(conn) - - def _get_cohort_async_request_members(self, cohort_id: str, group_type: str, location: str) -> Set[str]: - headers = { - 'X-Amp-Authorization': f'Basic {self._get_basic_auth()}', - 'X-Cohort-ID': cohort_id, - } - conn = self._connection_pool.acquire() - try: - response = conn.request('GET', location, headers) - return self._parse_csv_response(response.read(), group_type) - finally: - self._connection_pool.release(conn) - - def get_cached_cohort_members(self, cohort_id: str, group_type: str) -> Set[str]: + response = self._get_cohort_members_request(cohort_description.id) + self.logger.debug(f"getCohortMembers({cohort_description.id}): status={response.status}") + if response.status == 200: + response_json = json.loads(response.read().decode("utf8")) + members = set(response_json['memberIds']) + self.logger.debug(f"getCohortMembers({cohort_description.id}): end - resultSize={len(members)}") + return members + elif response.status == 413: + raise CohortTooLargeException(response.status, + f"Cohort exceeds max cohort size: {response.status}") + elif response.status != 202: + raise HTTPErrorResponseException(response.status, + f"Unexpected response code: {response.status}") + except Exception as e: + if not isinstance(e, HTTPErrorResponseException) and response.status != 429: + errors += 1 + self.logger.debug(f"getCohortMembers({cohort_description.id}): request-status error {errors} - {e}") + if errors >= 3 or isinstance(e, CohortTooLargeException): + raise e + time.sleep(self.request_status_delay) + + def _get_cohort_members_request(self, cohort_id: str) -> HTTPResponse: headers = { - 'X-Amp-Authorization': f'Basic {self._get_basic_auth()}', - 'X-Cohort-ID': cohort_id, + 'Authorization': f'Basic {self._get_basic_auth()}', } conn = self._connection_pool.acquire() try: - response = conn.request('GET', 'cohorts', headers) - input_stream = response.read() - if not input_stream: - raise ValueError('Cohort response body must not be null') - return self._parse_csv_response(input_stream, group_type) + response = conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}', + headers=headers) + return response finally: self._connection_pool.release(conn) - @staticmethod - def _parse_csv_response(input_stream: bytes, group_type: str) -> Set[str]: - csv_file = StringIO(input_stream.decode('utf-8')) - csv_data = list(csv.DictReader(csv_file)) - if group_type == USER_GROUP_TYPE: - return {row['user_id'] for row in csv_data if row['user_id']} - else: - values = set() - for row in csv_data: - try: - value = row.get('\tgroup_value', row.get('group_value')) - if value: - values.add(value.lstrip("\t")) - except ValueError: - pass - return values - def _get_basic_auth(self) -> str: credentials = f'{self.api_key}:{self.secret_key}' return base64.b64encode(credentials.encode('utf-8')).decode('utf-8') @@ -169,5 +100,5 @@ def _get_basic_auth(self) -> str: def __setup_connection_pool(self): scheme, _, host = self.cdn_server_url.split('/', 3) timeout = 10 - self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, read_timeout=timeout, + self._connection_pool = HTTPConnectionPool(host, max_size=50, idle_timeout=30, read_timeout=timeout, scheme=scheme) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index bf68046..0a4f297 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -1,26 +1,18 @@ -from typing import Dict, Set, Optional +from typing import Dict, Set from concurrent.futures import ThreadPoolExecutor, Future import threading from .cohort_description import CohortDescription -from .cohort_download_api import CohortDownloadApi, DirectCohortDownloadApiV5 +from .cohort_download_api import CohortDownloadApi from .cohort_storage import CohortStorage class CohortLoader: - def __init__(self, max_cohort_size: int, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, - direct_cohort_download_api: Optional[DirectCohortDownloadApiV5] = None): - self.max_cohort_size = max_cohort_size + def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage): self.cohort_download_api = cohort_download_api self.cohort_storage = cohort_storage - self.direct_cohort_download_api = direct_cohort_download_api - self.jobs: Dict[str, Future] = {} - self.cached_jobs: Dict[str, Future] = {} - self.lock_jobs = threading.Lock() - self.lock_cached_jobs = threading.Lock() - self.executor = ThreadPoolExecutor( max_workers=32, thread_name_prefix='CohortLoaderExecutor' @@ -30,48 +22,30 @@ def load_cohort(self, cohort_id: str) -> Future: with self.lock_jobs: if cohort_id not in self.jobs: def task(): - print(f"Loading cohort {cohort_id}") - cohort_description = self.get_cohort_description(cohort_id) - if self.should_download_cohort(cohort_description): - cohort_members = self.download_cohort(cohort_description) - self.cohort_storage.put_cohort(cohort_description, cohort_members) + try: + cohort_description = self.get_cohort_description(cohort_id) + if self.should_download_cohort(cohort_description): + cohort_members = self.download_cohort(cohort_description) + self.cohort_storage.put_cohort(cohort_description, cohort_members) + except Exception as e: + print(f"Failed to load cohort {cohort_id}: {e}") future = self.executor.submit(task) - future.add_done_callback(lambda _: self.jobs.pop(cohort_id, None)) + future.add_done_callback(lambda f: self._remove_job(f, cohort_id)) self.jobs[cohort_id] = future return self.jobs[cohort_id] - def load_cached_cohort(self, cohort_id: str) -> Future: - with self.lock_cached_jobs: - if cohort_id not in self.cached_jobs: - def task(): - print(f"Loading cohort from cache {cohort_id}") - cohort_description = self.get_cohort_description(cohort_id) - cohort_description.last_computed = 0 - if self.should_download_cohort(cohort_description): - cohort_members = self.download_cached_cohort(cohort_description) - self.cohort_storage.put_cohort(cohort_description, cohort_members) - - future = self.executor.submit(task) - self.cached_jobs[cohort_id] = future - future.add_done_callback(lambda _: self.cached_jobs.pop(cohort_id, None)) - return future - else: - return self.cached_jobs[cohort_id] + def _remove_job(self, future: Future, cohort_id: str): + with self.lock_jobs: + if cohort_id in self.jobs: + del self.jobs[cohort_id] def get_cohort_description(self, cohort_id: str) -> CohortDescription: return self.cohort_download_api.get_cohort_description(cohort_id) def should_download_cohort(self, cohort_description: CohortDescription) -> bool: storage_description = self.cohort_storage.get_cohort_description(cohort_description.id) - return (cohort_description.size <= self.max_cohort_size and - cohort_description.last_computed > (storage_description.last_computed if storage_description else -1)) + return cohort_description.last_computed > (storage_description.last_computed if storage_description else -1) def download_cohort(self, cohort_description: CohortDescription) -> Set[str]: return self.cohort_download_api.get_cohort_members(cohort_description) - - def download_cached_cohort(self, cohort_description: CohortDescription) -> Set[str]: - return (self.direct_cohort_download_api.get_cached_cohort_members(cohort_description.id, - cohort_description.group_type) - if self.direct_cohort_download_api else - self.cohort_download_api.get_cohort_members(cohort_description)) diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py index b609bed..ba32659 100644 --- a/src/amplitude_experiment/cohort/cohort_sync_config.py +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -3,4 +3,3 @@ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000): self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size - diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index f8c6179..0e1c203 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +from concurrent.futures import Future +from typing import Optional, Set import threading import time @@ -29,6 +30,9 @@ def __init__( self.lock = threading.Lock() self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_refresh) self.logger = logging.getLogger("Amplitude") + self.logger.addHandler(logging.StreamHandler()) + if self.config.debug: + self.logger.setLevel(logging.DEBUG) def start(self): with self.lock: @@ -43,52 +47,60 @@ def __periodic_refresh(self): try: self.refresh(initial=False) except Exception as e: - self.logger.error("Refresh flag configs failed.", e) + self.logger.error("Refresh flag configs failed.", exc_info=e) time.sleep(self.config.flag_config_polling_interval_millis / 1000) - def refresh(self, initial: bool): + def refresh(self, initial: bool = False): self.logger.debug("Refreshing flag configs.") flag_configs = self.flag_config_api.get_flag_configs() - flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) - if initial: - cached_futures = {} - for flag_config in flag_configs: - cohort_ids = get_all_cohort_ids_from_flag(flag_config) - if not self.cohort_loader or not cohort_ids: - self.flag_config_storage.put_flag_config(flag_config) - continue - for cohort_id in cohort_ids: - future = self.cohort_loader.load_cached_cohort(cohort_id) - future.add_done_callback(lambda _: self.flag_config_storage.put_flag_config(flag_config)) - cached_futures[cohort_id] = future - try: - for future in cached_futures.values(): - future.result() - except Exception as e: - self.logger.warning("Failed to download a cohort from the cache", e) - - futures = {} + futures = [] for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) if not self.cohort_loader or not cohort_ids: + self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) continue - for cohort_id in cohort_ids: - future = self.cohort_loader.load_cohort(cohort_id) - future.add_done_callback(lambda _: self.flag_config_storage.put_flag_config(flag_config)) - futures[cohort_id] = future + future = self._load_cohorts_and_store_flag(flag_config, cohort_ids) + futures.append(future) + if initial: - for future in futures.values(): - future.result() + try: + for future in futures: + self.logger.debug(f"Waiting for future {future}") + future.result() + except Exception as e: + self.logger.warning("Failed to download cohort", exc_info=e) + + self._delete_unused_cohorts() + self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + + def _load_cohorts_and_store_flag(self, flag_config: dict, cohort_ids: Set[str]) -> Future: + def task(): + try: + for cohort_id in cohort_ids: + future = self.cohort_loader.load_cohort(cohort_id) + future.result() # Wait for cohort to load + self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") + self.flag_config_storage.put_flag_config(flag_config) + self.logger.debug(f"Flag config {flag_config['key']} stored successfully.") + except Exception as e: + self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}", exc_info=e) + + return self.cohort_loader.executor.submit(task) + + def _delete_unused_cohorts(self): + flag_cohort_ids = set() + for flag in self.flag_config_storage.get_flag_configs().values(): + flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) + + storage_cohorts = self.cohort_storage.get_cohort_descriptions() + deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids - flag_cohort_ids = {flag['key'] for flag in self.flag_config_storage.get_flag_configs().values()} - deleted_cohort_ids = set(self.cohort_storage.get_cohort_descriptions().keys()) - flag_cohort_ids for deleted_cohort_id in deleted_cohort_ids: - deleted_cohort_description = self.cohort_storage.get_cohort_description(deleted_cohort_id) - if deleted_cohort_description: + deleted_cohort_description = storage_cohorts.get(deleted_cohort_id) + if deleted_cohort_description is not None: self.cohort_storage.delete_cohort(deleted_cohort_description.group_type, deleted_cohort_id) - self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index defdca3..53b6e06 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -4,7 +4,7 @@ def __init__(self, status_code, message): self.status_code = status_code -class CachedCohortDownloadException(Exception): +class CohortTooLargeException(Exception): def __init__(self, cached_members, message): super().__init__(message) self.cached_members = cached_members diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py index 2c3c99d..15db645 100644 --- a/src/amplitude_experiment/flag/flag_config_api.py +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -16,6 +16,7 @@ def __init__(self, deployment_key: str, server_url: str, flag_config_poller_requ self.deployment_key = deployment_key self.server_url = server_url self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis + self.__setup_connection_pool() def get_flag_configs(self) -> List: return self._get_flag_configs() diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index c9ba2ae..63a973e 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -9,7 +9,7 @@ from .topological_sort import topological_sort from ..assignment import Assignment, AssignmentFilter, AssignmentService from ..cohort.cohort_description import USER_GROUP_TYPE -from ..cohort.cohort_download_api import DirectCohortDownloadApiV5 +from ..cohort.cohort_download_api import DirectCohortDownloadApi from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import InMemoryCohortStorage from ..deployment.deployment_runner import DeploymentRunner @@ -55,16 +55,15 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): if self.config.debug: self.logger.setLevel(logging.DEBUG) self.__setup_connection_pool() - self.flags = None - self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__do_flags) self.lock = Lock() self.cohort_storage = InMemoryCohortStorage() self.flag_config_storage = InMemoryFlagConfigStorage() if config and config.cohort_sync_config: - direct_cohort_download_api = DirectCohortDownloadApiV5(config.cohort_sync_config.api_key, - config.cohort_sync_config.secret_key) - cohort_loader = CohortLoader(config.cohort_sync_config.max_cohort_size, direct_cohort_download_api, - self.cohort_storage, direct_cohort_download_api) + cohort_download_api = DirectCohortDownloadApi(config.cohort_sync_config.api_key, + config.cohort_sync_config.secret_key, + config.cohort_sync_config.max_cohort_size, + self.config.debug) + cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, config.server_url, config.flag_config_poller_request_timeout_millis) self.deployment_runner = DeploymentRunner(config, flag_config_api, self.flag_config_storage, @@ -75,8 +74,7 @@ def start(self): Fetch initial flag configurations and start polling for updates. You must call this function to begin polling for flag config updates. """ - self.__do_flags() - self.poller.start() + self.deployment_runner.start() def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Variant]: """ @@ -93,11 +91,12 @@ def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Varia Returns: The evaluated variants. """ - if self.flags is None or len(self.flags) == 0: + flag_configs = self.flag_config_storage.get_flag_configs() + if flag_configs is None or len(flag_configs) == 0: return {} - self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {self.flags}") + self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {flag_configs}") flag_configs = self.flag_config_storage.get_flag_configs() - sorted_flags = topological_sort(self.flags, flag_keys) + sorted_flags = topological_sort(flag_configs, flag_keys) if not sorted_flags: return {} enriched_user = self.enrich_user(user, flag_configs) @@ -138,30 +137,6 @@ def evaluate(self, user: User, flag_keys: List[str] = None) -> Dict[str, Variant variants = self.evaluate_v2(user, flag_keys) return self.__filter_default_variants(variants) - def __do_flags(self): - conn = self._connection_pool.acquire() - headers = { - 'Authorization': f"Api-Key {self.api_key}", - 'Content-Type': 'application/json;charset=utf-8', - 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" - } - body = None - self.logger.debug('[Experiment] Get flag configs') - try: - response = conn.request('GET', '/sdk/v2/flags?v=0', body, headers) - response_body = response.read().decode("utf8") - if response.status != 200: - raise Exception( - f"[Experiment] Get flagConfigs - received error response: ${response.status}: ${response_body}") - flags = json.loads(response_body) - flags_dict = {flag['key']: flag for flag in flags} - self.logger.debug(f"[Experiment] Got flag configs: {flags}") - self.lock.acquire() - self.flags = flags_dict - self.lock.release() - finally: - self._connection_pool.release(conn) - def __setup_connection_pool(self): scheme, _, host = self.config.server_url.split('/', 3) timeout = self.config.flag_config_poller_request_timeout_millis / 1000 @@ -172,7 +147,7 @@ def stop(self) -> None: """ Stop polling for flag configurations. Close resource like connection pool with client """ - self.poller.stop() + self.deployment_runner.stop() self._connection_pool.close() def __enter__(self) -> 'LocalEvaluationClient': @@ -196,7 +171,7 @@ def enrich_user(self, user: User, flag_configs: Dict) -> User: if USER_GROUP_TYPE in grouped_cohort_ids: user_cohort_ids = grouped_cohort_ids[USER_GROUP_TYPE] if user_cohort_ids and user.user_id: - user.cohort_ids = self.cohort_storage.get_cohorts_for_user(user.user_id, user_cohort_ids) + user.cohort_ids = list(self.cohort_storage.get_cohorts_for_user(user.user_id, user_cohort_ids)) if user.groups: for group_type, group_names in user.groups.items(): @@ -209,7 +184,6 @@ def enrich_user(self, user: User, flag_configs: Dict) -> User: user.add_group_cohort_ids( group_type, group_name, - self.cohort_storage.get_cohorts_for_group(group_type, group_name, cohort_ids) + list(self.cohort_storage.get_cohorts_for_group(group_type, group_name, cohort_ids)) ) - return user diff --git a/src/amplitude_experiment/user.py b/src/amplitude_experiment/user.py index 6dd6561..3935a62 100644 --- a/src/amplitude_experiment/user.py +++ b/src/amplitude_experiment/user.py @@ -1,6 +1,6 @@ import json -from typing import Dict, Any, Set +from typing import Dict, Any, Set, List class User: @@ -29,7 +29,8 @@ def __init__( user_properties: Dict[str, Any] = None, groups: Dict[str, str] = None, group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None, - group_cohort_ids: Dict[str, Dict[str, Set[str]]] = None + group_cohort_ids: Dict[str, Dict[str, List[str]]] = None, + cohort_ids: List[str] = None ): """ Initialize User instance @@ -75,6 +76,7 @@ def __init__( self.groups = groups self.group_properties = group_properties self.group_cohort_ids = group_cohort_ids + self.cohort_ids = cohort_ids def to_json(self): """Return user information as JSON string.""" @@ -84,7 +86,7 @@ def __str__(self): """Return user as string""" return self.to_json() - def add_group_cohort_ids(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> None: + def add_group_cohort_ids(self, group_type: str, group_name: str, cohort_ids: List[str]) -> None: """ Add cohort IDs for a group. Parameters: diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 9ef0ed2..e2dc953 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -1,179 +1,107 @@ import json import unittest -from unittest.mock import MagicMock -from src.amplitude_experiment.cohort.cohort_description import CohortDescription, USER_GROUP_TYPE -from src.amplitude_experiment.exception import CachedCohortDownloadException -from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApiV5 -from urllib.parse import urlparse +from unittest.mock import MagicMock, patch +from src.amplitude_experiment.cohort.cohort_description import CohortDescription +from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi +from src.amplitude_experiment.exception import CohortTooLargeException -def response(code: int): +def response(code: int, body: dict = None): mock_response = MagicMock() mock_response.status = code - mock_response.headers = {'location': 'https://example.com/cohorts/Cohort_asdf?asdf=asdf#asdf'} + if body is not None: + mock_response.read.return_value = json.dumps(body).encode() return mock_response class CohortDownloadApiTest(unittest.TestCase): - location = 'https://example.com/cohorts/Cohort_asdf?asdf=asdf#asdf' + + def setUp(self): + self.api = DirectCohortDownloadApi('api', 'secret', 15000, 1) def test_cohort_download_success(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'user'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - api._get_cohort_async_request_status.assert_called_once_with('4321') - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + members_response = response(200, {'memberIds': ['user']}) + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', return_value=members_response): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) def test_cohort_download_many_202s_success(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_202_response = response(202) - async_request_status_200_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock( - side_effect=[async_request_status_202_response] * 9 + [async_request_status_200_response]) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'user'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - self.assertEqual(api._get_cohort_async_request_status.call_count, 10) - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + members_response = response(200, {'memberIds': ['user']}) + async_responses = [response(202)] * 9 + [members_response] + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) def test_cohort_request_status_with_two_failures_succeeds(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_503_response = response(503) - async_request_status_200_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock( - side_effect=[async_request_status_503_response, async_request_status_503_response, - async_request_status_200_response]) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'user'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - self.assertEqual(api._get_cohort_async_request_status.call_count, 3) - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) - - def test_cohort_request_status_throws_after_3_failures_cache_fallback_succeeds(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_response = response(503) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'user'}) - api._get_cached_cohort_members = MagicMock(return_value={'user2'}) - - with self.assertRaises(CachedCohortDownloadException) as e: - api.get_cohort_members(cohort) - - self.assertEqual({'user2'}, e.exception.cached_members) - api._get_cohort_async_request.assert_called_once_with(cohort) - self.assertEqual(api._get_cohort_async_request_status.call_count, 3) - api._get_cohort_async_request_location.assert_not_called() - api._get_cohort_async_request_members.assert_not_called() - api._get_cached_cohort_members.assert_called_once_with('1234', USER_GROUP_TYPE) + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + error_response = response(503) + members_response = response(200, {'memberIds': ['user']}) + async_responses = [error_response, error_response, members_response] + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) def test_cohort_request_status_429s_keep_retrying(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_429_response = response(429) - async_request_status_200_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock( - side_effect=[async_request_status_429_response] * 9 + [async_request_status_200_response]) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'user'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - self.assertEqual(api._get_cohort_async_request_status.call_count, 10) - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', USER_GROUP_TYPE, urlparse(self.location)) - - def test_cohort_async_request_download_failure_falls_back_on_cached_request(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(side_effect=Exception('fail')) - api._get_cached_cohort_members = MagicMock(return_value={'user'}) - - with self.assertRaises(CachedCohortDownloadException) as e: - api.get_cohort_members(cohort) - - self.assertEqual({'user'}, e.exception.cached_members) - api._get_cached_cohort_members.assert_called_once_with('1234', USER_GROUP_TYPE) + cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + error_response = response(429) + members_response = response(200, {'memberIds': ['user']}) + async_responses = [error_response] * 9 + [members_response] + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'user'}, members) def test_group_cohort_download_success(self): cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock(return_value=async_request_status_response) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'group'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'group'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - api._get_cohort_async_request_status.assert_called_once_with('4321') - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', 'org name', urlparse(self.location)) + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': "org name"}) + members_response = response(200, {'memberIds': ['group']}) + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', return_value=members_response): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'group'}, members) def test_group_cohort_request_status_429s_keep_retrying(self): cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") - async_request_response = MagicMock() - async_request_response.read.return_value = json.dumps({'cohort_id': '1234', 'request_id': '4321'}).encode() - - async_request_status_429_response = response(429) - async_request_status_200_response = response(200) - api = DirectCohortDownloadApiV5('api', 'secret') - api._get_cohort_async_request = MagicMock(return_value=async_request_response) - api._get_cohort_async_request_status = MagicMock( - side_effect=[async_request_status_429_response] * 9 + [async_request_status_200_response]) - api._get_cohort_async_request_location = MagicMock(return_value=urlparse(self.location)) - api._get_cohort_async_request_members = MagicMock(return_value={'group'}) - - members = api.get_cohort_members(cohort) - self.assertEqual({'group'}, members) - api._get_cohort_async_request.assert_called_once_with(cohort) - self.assertEqual(api._get_cohort_async_request_status.call_count, 10) - api._get_cohort_async_request_location.assert_called_once_with('4321') - api._get_cohort_async_request_members.assert_called_once_with('1234', 'org name', urlparse(self.location)) + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': "org name"}) + error_response = response(429) + members_response = response(200, {'memberIds': ['group']}) + async_responses = [error_response] * 9 + [members_response] + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + members = self.api.get_cohort_members(cohort) + self.assertEqual({'group'}, members) + + def test_cohort_size_too_large(self): + cohort = CohortDescription(id="1234", last_computed=0, size=16000, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 16000, 'groupType': 'user'}) + too_large_response = response(413) + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', return_value=too_large_response): + + with self.assertRaises(CohortTooLargeException): + self.api.get_cohort_members(cohort) if __name__ == '__main__': diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index 7d0212c..b7706cb 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -8,10 +8,9 @@ class CohortLoaderTest(unittest.TestCase): def setUp(self): - self.config = MagicMock() self.api = MagicMock() self.storage = InMemoryCohortStorage() - self.loader = CohortLoader(15000, self.api, self.storage) + self.loader = CohortLoader(self.api, self.storage) def test_load_success(self): self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b")] @@ -35,29 +34,13 @@ def test_load_success(self): self.assertEqual({"a", "b"}, storage_user1_cohorts) self.assertEqual({"b"}, storage_user2_cohorts) - def test_load_cohorts_greater_than_max_cohort_size_are_filtered(self): - self.api.get_cohort_description.side_effect = [cohort_description("a", size=float("inf")), - cohort_description("b", size=1)] - self.api.get_cohort_members.side_effect = [{"1", "2"}] - - self.loader.load_cohort("a").result() - self.loader.load_cohort("b").result() - - storage_description_a = self.storage.get_cohort_description("a") - storage_description_b = self.storage.get_cohort_description("b") - self.assertIsNone(storage_description_a) - self.assertEqual(cohort_description("b", size=1), storage_description_b) - - storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) - storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) - self.assertEqual({"b"}, storage_user1_cohorts) - self.assertEqual({"b"}, storage_user2_cohorts) - def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): self.storage.put_cohort(cohort_description("a", last_computed=0), set()) self.storage.put_cohort(cohort_description("b", last_computed=0), set()) - self.api.get_cohort_description.side_effect = [cohort_description("a", last_computed=0), - cohort_description("b", last_computed=1)] + self.api.get_cohort_description.side_effect = [ + cohort_description("a", last_computed=0), + cohort_description("b", last_computed=1) + ] self.api.get_cohort_members.side_effect = [{"1", "2"}] self.loader.load_cohort("a").result() @@ -74,8 +57,11 @@ def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): self.assertEqual({"b"}, storage_user2_cohorts) def test_load_download_failure_throws(self): - self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b"), - cohort_description("c")] + self.api.get_cohort_description.side_effect = [ + cohort_description("a"), + cohort_description("b"), + cohort_description("c") + ] self.api.get_cohort_members.side_effect = [{"1"}, Exception("Connection timed out"), {"1"}] self.loader.load_cohort("a").result() diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index 35c9e7d..7a1e985 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -1,8 +1,9 @@ import unittest from unittest import mock -from src.amplitude_experiment import LocalEvaluationConfig +from src.amplitude_experiment import LocalEvaluationConfig, LocalEvaluationClient, User from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner @@ -35,7 +36,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): cohort_download_api = mock.Mock() flag_config_storage = mock.Mock() cohort_storage = mock.Mock() - cohort_loader = CohortLoader(100, cohort_download_api, cohort_storage) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, @@ -52,7 +53,7 @@ def test_start_throws_if_first_cohort_load_fails(self): cohort_download_api = mock.Mock() flag_config_storage = mock.Mock() cohort_storage = mock.Mock() - cohort_loader = CohortLoader(100, cohort_download_api, cohort_storage) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py index d4b2772..2802f45 100644 --- a/tests/util/flag_config_test.py +++ b/tests/util/flag_config_test.py @@ -1,6 +1,5 @@ import unittest -# Assuming the following utility functions are defined in a module named cohort_utils.py from src.amplitude_experiment.util.flag_config import ( get_all_cohort_ids_from_flags, get_grouped_cohort_ids_from_flags, From 473fe7bdfa77d73a071dcd93c9f3457a50bc0865 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 10 Jun 2024 16:22:29 -0700 Subject: [PATCH 06/44] fix tests, add logging config --- .../cohort/cohort_download_api.py | 7 ++++--- src/amplitude_experiment/cohort/cohort_loader.py | 11 +++++------ .../deployment/deployment_runner.py | 5 +++-- tests/cohort/cohort_download_api_test.py | 2 +- tests/cohort/cohort_loader_test.py | 2 -- tests/deployment/deployment_runner_test.py | 14 +++++++------- 6 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index cd15c08..d042aa1 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -24,13 +24,14 @@ def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: class DirectCohortDownloadApi(CohortDownloadApi): - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, request_status_delay: int = 5, debug: bool = False): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, + debug: bool = False, cohort_request_delay_millis: int = 5000): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size self.__setup_connection_pool() - self.request_status_delay = request_status_delay + self.cohort_request_delay_millis = cohort_request_delay_millis self.logger = logging.getLogger("Amplitude") self.logger.addHandler(logging.StreamHandler()) if debug: @@ -79,7 +80,7 @@ def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: self.logger.debug(f"getCohortMembers({cohort_description.id}): request-status error {errors} - {e}") if errors >= 3 or isinstance(e, CohortTooLargeException): raise e - time.sleep(self.request_status_delay) + time.sleep(self.cohort_request_delay_millis/1000) def _get_cohort_members_request(self, cohort_id: str) -> HTTPResponse: headers = { diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 0a4f297..144a0c3 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -28,17 +28,16 @@ def task(): cohort_members = self.download_cohort(cohort_description) self.cohort_storage.put_cohort(cohort_description, cohort_members) except Exception as e: - print(f"Failed to load cohort {cohort_id}: {e}") + raise e future = self.executor.submit(task) - future.add_done_callback(lambda f: self._remove_job(f, cohort_id)) + future.add_done_callback(lambda f: self._remove_job(cohort_id)) self.jobs[cohort_id] = future return self.jobs[cohort_id] - def _remove_job(self, future: Future, cohort_id: str): - with self.lock_jobs: - if cohort_id in self.jobs: - del self.jobs[cohort_id] + def _remove_job(self, cohort_id: str): + if cohort_id in self.jobs: + del self.jobs[cohort_id] def get_cohort_description(self, cohort_id: str) -> CohortDescription: return self.cohort_download_api.get_cohort_description(cohort_id) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 0e1c203..c54cada 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -69,10 +69,10 @@ def refresh(self, initial: bool = False): if initial: try: for future in futures: - self.logger.debug(f"Waiting for future {future}") future.result() except Exception as e: - self.logger.warning("Failed to download cohort", exc_info=e) + self.logger.warning("Failed to download cohort", e) + raise e self._delete_unused_cohorts() self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") @@ -88,6 +88,7 @@ def task(): self.logger.debug(f"Flag config {flag_config['key']} stored successfully.") except Exception as e: self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}", exc_info=e) + raise e return self.cohort_loader.executor.submit(task) diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index e2dc953..877498a 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -17,7 +17,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, 1) + self.api = DirectCohortDownloadApi('api', 'secret', 15000, False, 100) def test_cohort_download_success(self): cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index b7706cb..66a613b 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -16,11 +16,9 @@ def test_load_success(self): self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b")] self.api.get_cohort_members.side_effect = [{"1"}, {"1", "2"}] - # Submitting tasks asynchronously future_a = self.loader.load_cohort("a") future_b = self.loader.load_cohort("b") - # Asserting after tasks complete future_a.result() future_b.result() diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index 7a1e985..02542b9 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -1,9 +1,9 @@ import unittest from unittest import mock +from unittest.mock import patch -from src.amplitude_experiment import LocalEvaluationConfig, LocalEvaluationClient, User +from src.amplitude_experiment import LocalEvaluationConfig from src.amplitude_experiment.cohort.cohort_loader import CohortLoader -from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner @@ -60,11 +60,11 @@ def test_start_throws_if_first_cohort_load_fails(self): cohort_storage, cohort_loader ) - flag_api.get_flag_configs.return_value = [self.flag] - cohort_download_api.get_cohort_description.side_effect = RuntimeError("test") - - with self.assertRaises(RuntimeError): - runner.start() + with patch.object(runner, '_delete_unused_cohorts'): + flag_api.get_flag_configs.return_value = [self.flag] + cohort_download_api.get_cohort_description.side_effect = RuntimeError("test") + with self.assertRaises(RuntimeError): + runner.start() if __name__ == '__main__': From 3d5dfe99b92bf44dbc411093501e93f19df9c005 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 12 Jun 2024 12:24:01 -0700 Subject: [PATCH 07/44] add CohortNotModifiedException --- .../cohort/cohort_download_api.py | 23 +++++---- .../cohort/cohort_loader.py | 11 ++--- src/amplitude_experiment/exception.py | 8 +++- src/amplitude_experiment/local/client.py | 17 +++---- tests/cohort/cohort_download_api_test.py | 13 ++++- tests/util/flag_config_test.py | 48 +++++++++++++++++-- 6 files changed, 89 insertions(+), 31 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index d042aa1..7ee3a97 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -7,9 +7,9 @@ from .cohort_description import CohortDescription from ..connection_pool import HTTPConnectionPool -from ..exception import HTTPErrorResponseException, CohortTooLargeException +from ..exception import HTTPErrorResponseException, CohortTooLargeException, CohortNotModifiedException -CDN_COHORT_SYNC_URL = 'https://api.lab.amplitude.com' +CDN_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com' class CohortDownloadApi: @@ -19,7 +19,7 @@ def __init__(self): def get_cohort_description(self, cohort_id: str) -> CohortDescription: raise NotImplementedError - def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: + def get_cohort_members(self, cohort_description: CohortDescription, last_modified: int) -> Set[str]: raise NotImplementedError @@ -55,22 +55,24 @@ def get_cohort_info(self, cohort_id: str) -> HTTPResponse: finally: self._connection_pool.release(conn) - def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: + def get_cohort_members(self, cohort_description: CohortDescription, should_download_cohort: bool = False) -> Set[str]: self.logger.debug(f"getCohortMembers({cohort_description.id}): start - {cohort_description}") errors = 0 while True: response = None try: - response = self._get_cohort_members_request(cohort_description.id) + last_modified = -1 if should_download_cohort else cohort_description.last_computed + response = self._get_cohort_members_request(cohort_description.id, last_modified) self.logger.debug(f"getCohortMembers({cohort_description.id}): status={response.status}") if response.status == 200: response_json = json.loads(response.read().decode("utf8")) members = set(response_json['memberIds']) self.logger.debug(f"getCohortMembers({cohort_description.id}): end - resultSize={len(members)}") return members + elif response.status == 204: + raise CohortNotModifiedException(f"Cohort not modified: {response.status}") elif response.status == 413: - raise CohortTooLargeException(response.status, - f"Cohort exceeds max cohort size: {response.status}") + raise CohortTooLargeException(f"Cohort exceeds max cohort size: {response.status}") elif response.status != 202: raise HTTPErrorResponseException(response.status, f"Unexpected response code: {response.status}") @@ -82,13 +84,14 @@ def get_cohort_members(self, cohort_description: CohortDescription) -> Set[str]: raise e time.sleep(self.cohort_request_delay_millis/1000) - def _get_cohort_members_request(self, cohort_id: str) -> HTTPResponse: + def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse: headers = { 'Authorization': f'Basic {self._get_basic_auth()}', } conn = self._connection_pool.acquire() try: - response = conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}', + response = conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}' + f'&lastModified={last_modified}', headers=headers) return response finally: @@ -101,5 +104,5 @@ def _get_basic_auth(self) -> str: def __setup_connection_pool(self): scheme, _, host = self.cdn_server_url.split('/', 3) timeout = 10 - self._connection_pool = HTTPConnectionPool(host, max_size=50, idle_timeout=30, read_timeout=timeout, + self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout, scheme=scheme) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 144a0c3..e9b9bfd 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -24,9 +24,8 @@ def load_cohort(self, cohort_id: str) -> Future: def task(): try: cohort_description = self.get_cohort_description(cohort_id) - if self.should_download_cohort(cohort_description): - cohort_members = self.download_cohort(cohort_description) - self.cohort_storage.put_cohort(cohort_description, cohort_members) + cohort_members = self.download_cohort(cohort_description) + self.cohort_storage.put_cohort(cohort_description, cohort_members) except Exception as e: raise e @@ -43,8 +42,8 @@ def get_cohort_description(self, cohort_id: str) -> CohortDescription: return self.cohort_download_api.get_cohort_description(cohort_id) def should_download_cohort(self, cohort_description: CohortDescription) -> bool: - storage_description = self.cohort_storage.get_cohort_description(cohort_description.id) - return cohort_description.last_computed > (storage_description.last_computed if storage_description else -1) + return self.cohort_storage.get_cohort_description(cohort_description.id) is None def download_cohort(self, cohort_description: CohortDescription) -> Set[str]: - return self.cohort_download_api.get_cohort_members(cohort_description) + return self.cohort_download_api.get_cohort_members(cohort_description, + self.should_download_cohort(cohort_description)) diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 53b6e06..15d9ac7 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -4,10 +4,14 @@ def __init__(self, status_code, message): self.status_code = status_code +class CohortNotModifiedException(Exception): + def __init__(self, message): + super().__init__(message) + + class CohortTooLargeException(Exception): - def __init__(self, cached_members, message): + def __init__(self, message): super().__init__(message) - self.cached_members = cached_members class HTTPErrorResponseException(Exception): diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 63a973e..4b507f0 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -58,16 +58,17 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): self.lock = Lock() self.cohort_storage = InMemoryCohortStorage() self.flag_config_storage = InMemoryFlagConfigStorage() - if config and config.cohort_sync_config: - cohort_download_api = DirectCohortDownloadApi(config.cohort_sync_config.api_key, - config.cohort_sync_config.secret_key, - config.cohort_sync_config.max_cohort_size, + cohort_loader = None + if self.config.cohort_sync_config: + cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key, + self.config.cohort_sync_config.secret_key, + self.config.cohort_sync_config.max_cohort_size, self.config.debug) cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) - flag_config_api = FlagConfigApiV2(api_key, config.server_url, - config.flag_config_poller_request_timeout_millis) - self.deployment_runner = DeploymentRunner(config, flag_config_api, self.flag_config_storage, - self.cohort_storage, cohort_loader) + flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, + self.config.flag_config_poller_request_timeout_millis) + self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, + self.cohort_storage, cohort_loader) def start(self): """ diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 877498a..3da98a7 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch from src.amplitude_experiment.cohort.cohort_description import CohortDescription from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi -from src.amplitude_experiment.exception import CohortTooLargeException +from src.amplitude_experiment.exception import CohortTooLargeException, CohortNotModifiedException def response(code: int, body: dict = None): @@ -103,6 +103,17 @@ def test_cohort_size_too_large(self): with self.assertRaises(CohortTooLargeException): self.api.get_cohort_members(cohort) + def test_cohort_not_modified_exception(self): + cohort = CohortDescription(id="1234", last_computed=1000, size=1, group_type='user') + cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 1000, 'size': 1, 'groupType': 'user'}) + not_modified_response = response(204) + + with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ + patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): + + with self.assertRaises(CohortNotModifiedException): + self.api.get_cohort_members(cohort, should_download_cohort=False) + if __name__ == '__main__': unittest.main() diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py index 2802f45..0b16ab9 100644 --- a/tests/util/flag_config_test.py +++ b/tests/util/flag_config_test.py @@ -87,18 +87,57 @@ def setUp(self): 'value': 'on' } } + }, + { + 'key': 'flag-3', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 3 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'groups', 'group_name', 'cohort_ids'], + 'values': ['cohort7', 'cohort8'] + } + ] + ], + 'metadata': {'segmentName': 'Segment C'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Groups'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } } ] def test_get_all_cohort_ids(self): - expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6', 'cohort7', 'cohort8'} for flag in self.flags: cohort_ids = get_all_cohort_ids_from_flag(flag) self.assertTrue(cohort_ids.issubset(expected_cohort_ids)) def test_get_grouped_cohort_ids_for_flag(self): expected_grouped_cohort_ids = { - 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'}, + 'group_name': {'cohort7', 'cohort8'} } for flag in self.flags: grouped_cohort_ids = get_grouped_cohort_ids_from_flag(flag) @@ -107,13 +146,14 @@ def test_get_grouped_cohort_ids_for_flag(self): self.assertTrue(values.issubset(expected_grouped_cohort_ids[key])) def test_get_all_cohort_ids_from_flags(self): - expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6', 'cohort7', 'cohort8'} cohort_ids = get_all_cohort_ids_from_flags(self.flags) self.assertEqual(cohort_ids, expected_cohort_ids) def test_get_grouped_cohort_ids_for_flag_from_flags(self): expected_grouped_cohort_ids = { - 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'} + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'}, + 'group_name': {'cohort7', 'cohort8'} } grouped_cohort_ids = get_grouped_cohort_ids_from_flags(self.flags) self.assertEqual(grouped_cohort_ids, expected_grouped_cohort_ids) From 93eb012a7b4dbd09ce7ab80fd35d8b514a92568d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 12 Jun 2024 17:43:29 -0700 Subject: [PATCH 08/44] update user transformation to evaluation context --- .../deployment/deployment_runner.py | 12 +++--- src/amplitude_experiment/user.py | 2 +- src/amplitude_experiment/util/user.py | 40 +++++++++++++------ 3 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index c54cada..9d57787 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -43,12 +43,10 @@ def stop(self): self.poller.stop() def __periodic_refresh(self): - while True: - try: - self.refresh(initial=False) - except Exception as e: - self.logger.error("Refresh flag configs failed.", exc_info=e) - time.sleep(self.config.flag_config_polling_interval_millis / 1000) + try: + self.refresh(initial=False) + except Exception as e: + self.logger.error("Refresh flag configs failed.", exc_info=e) def refresh(self, initial: bool = False): self.logger.debug("Refreshing flag configs.") @@ -82,7 +80,7 @@ def task(): try: for cohort_id in cohort_ids: future = self.cohort_loader.load_cohort(cohort_id) - future.result() # Wait for cohort to load + future.result() self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Flag config {flag_config['key']} stored successfully.") diff --git a/src/amplitude_experiment/user.py b/src/amplitude_experiment/user.py index 3935a62..8cbdcea 100644 --- a/src/amplitude_experiment/user.py +++ b/src/amplitude_experiment/user.py @@ -27,7 +27,7 @@ def __init__( carrier: str = None, library: str = None, user_properties: Dict[str, Any] = None, - groups: Dict[str, str] = None, + groups: Dict[str, List[str]] = None, group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None, group_cohort_ids: Dict[str, Dict[str, List[str]]] = None, cohort_ids: List[str] = None diff --git a/src/amplitude_experiment/util/user.py b/src/amplitude_experiment/util/user.py index 01aa779..f4e83cf 100644 --- a/src/amplitude_experiment/util/user.py +++ b/src/amplitude_experiment/util/user.py @@ -6,26 +6,40 @@ def user_to_evaluation_context(user: User) -> Dict[str, Any]: user_groups = user.groups user_group_properties = user.group_properties + user_group_cohort_ids = user.group_cohort_ids # Assuming this property exists on the User object user_dict = {key: value for key, value in user.__dict__.copy().items() if value is not None} user_dict.pop('groups', None) user_dict.pop('group_properties', None) + user_dict.pop('group_cohort_ids', None) # Removing the group_cohort_ids from the user dictionary context = {'user': user_dict} if len(user_dict) > 0 else {} + if user_groups is None: return context + groups: Dict[str, Dict[str, Any]] = {} - for group_type in user_groups: - group_name = user_groups[group_type] - if isinstance(group_name, list) and len(group_name) > 0: - group_name = group_name[0] - groups[group_type] = {'group_name': group_name} - if user_group_properties is None: - continue - group_properties_type = user_group_properties[group_type] - if group_properties_type is None or isinstance(group_properties_type, dict): + for group_type, group_names in user_groups.items(): + if isinstance(group_names, list) and len(group_names) > 0: + group_name = group_names[0] + else: continue - group_properties_name = group_properties_type[group_name] - if group_properties_name is None or isinstance(group_properties_name, dict): - continue - groups[group_type]['group_properties'] = group_properties_name + + group_name_map = {'group_name': group_name} + + if user_group_properties: + group_properties_type = user_group_properties.get(group_type) + if group_properties_type: + group_properties_name = group_properties_type.get(group_name) + if group_properties_name: + group_name_map['group_properties'] = group_properties_name + + if user_group_cohort_ids: + group_cohort_ids_type = user_group_cohort_ids.get(group_type) + if group_cohort_ids_type: + group_cohort_ids_name = group_cohort_ids_type.get(group_name) + if group_cohort_ids_name: + group_name_map['cohort_ids'] = group_cohort_ids_name + + groups[group_type] = group_name_map + context['groups'] = groups return context From 981de9ffc18430f49c2fc3112c0e568e18921333 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Fri, 14 Jun 2024 15:09:33 -0700 Subject: [PATCH 09/44] refactor and simplify to not use cohort_description --- .../{cohort_description.py => cohort.py} | 5 +- .../cohort/cohort_download_api.py | 51 +++++----- .../cohort/cohort_loader.py | 19 ++-- .../cohort/cohort_storage.py | 64 ++++++------ .../deployment/deployment_runner.py | 55 ++++++----- .../flag/flag_config_storage.py | 6 ++ src/amplitude_experiment/local/client.py | 2 +- src/amplitude_experiment/util/flag_config.py | 2 +- tests/cohort/cohort_download_api_test.py | 97 ++++++++----------- tests/cohort/cohort_loader_test.py | 50 +++++----- 10 files changed, 167 insertions(+), 184 deletions(-) rename src/amplitude_experiment/cohort/{cohort_description.py => cohort.py} (73%) diff --git a/src/amplitude_experiment/cohort/cohort_description.py b/src/amplitude_experiment/cohort/cohort.py similarity index 73% rename from src/amplitude_experiment/cohort/cohort_description.py rename to src/amplitude_experiment/cohort/cohort.py index d4d882b..99af861 100644 --- a/src/amplitude_experiment/cohort/cohort_description.py +++ b/src/amplitude_experiment/cohort/cohort.py @@ -1,12 +1,13 @@ from dataclasses import dataclass, field -from typing import ClassVar +from typing import ClassVar, Set USER_GROUP_TYPE: ClassVar[str] = "User" @dataclass -class CohortDescription: +class Cohort: id: str last_computed: int size: int + member_ids: Set[str] group_type: str = field(default=USER_GROUP_TYPE) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 7ee3a97..15b87b5 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -5,7 +5,7 @@ from http.client import HTTPResponse from typing import Set -from .cohort_description import CohortDescription +from .cohort import Cohort from ..connection_pool import HTTPConnectionPool from ..exception import HTTPErrorResponseException, CohortTooLargeException, CohortNotModifiedException @@ -16,10 +16,7 @@ class CohortDownloadApi: def __init__(self): self.cdn_server_url = CDN_COHORT_SYNC_URL - def get_cohort_description(self, cohort_id: str) -> CohortDescription: - raise NotImplementedError - - def get_cohort_members(self, cohort_description: CohortDescription, last_modified: int) -> Set[str]: + def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: raise NotImplementedError @@ -37,16 +34,6 @@ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, if debug: self.logger.setLevel(logging.DEBUG) - def get_cohort_description(self, cohort_id: str) -> CohortDescription: - response = self.get_cohort_info(cohort_id) - cohort_info = json.loads(response.read().decode("utf-8")) - return CohortDescription( - id=cohort_info['cohortId'], - last_computed=cohort_info['lastComputed'], - size=cohort_info['size'], - group_type=cohort_info['groupType'], - ) - def get_cohort_info(self, cohort_id: str) -> HTTPResponse: conn = self._connection_pool.acquire() try: @@ -55,20 +42,25 @@ def get_cohort_info(self, cohort_id: str) -> HTTPResponse: finally: self._connection_pool.release(conn) - def get_cohort_members(self, cohort_description: CohortDescription, should_download_cohort: bool = False) -> Set[str]: - self.logger.debug(f"getCohortMembers({cohort_description.id}): start - {cohort_description}") + def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: + self.logger.debug(f"getCohortMembers({cohort_id}): start") errors = 0 while True: response = None try: - last_modified = -1 if should_download_cohort else cohort_description.last_computed - response = self._get_cohort_members_request(cohort_description.id, last_modified) - self.logger.debug(f"getCohortMembers({cohort_description.id}): status={response.status}") + last_modified = None if cohort is None else cohort.last_computed + response = self._get_cohort_members_request(cohort_id, last_modified) + self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}") if response.status == 200: - response_json = json.loads(response.read().decode("utf8")) - members = set(response_json['memberIds']) - self.logger.debug(f"getCohortMembers({cohort_description.id}): end - resultSize={len(members)}") - return members + cohort_info = json.loads(response.read().decode("utf8")) + self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}") + return Cohort( + id=cohort_info['cohortId'], + last_computed=cohort_info['lastComputed'], + size=cohort_info['size'], + member_ids=set(cohort_info['memberIds']), + group_type=cohort_info['groupType'], + ) elif response.status == 204: raise CohortNotModifiedException(f"Cohort not modified: {response.status}") elif response.status == 413: @@ -79,8 +71,8 @@ def get_cohort_members(self, cohort_description: CohortDescription, should_downl except Exception as e: if not isinstance(e, HTTPErrorResponseException) and response.status != 429: errors += 1 - self.logger.debug(f"getCohortMembers({cohort_description.id}): request-status error {errors} - {e}") - if errors >= 3 or isinstance(e, CohortTooLargeException): + self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") + if errors >= 3 or isinstance(e, CohortNotModifiedException) or isinstance(e, CohortTooLargeException): raise e time.sleep(self.cohort_request_delay_millis/1000) @@ -90,9 +82,10 @@ def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTT } conn = self._connection_pool.acquire() try: - response = conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}' - f'&lastModified={last_modified}', - headers=headers) + url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}' + if last_modified is not None: + url += f'&lastModified={last_modified}' + response = conn.request('GET', url, headers=headers) return response finally: self._connection_pool.release(conn) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index e9b9bfd..8d74a03 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, Future import threading -from .cohort_description import CohortDescription +from .cohort import Cohort from .cohort_download_api import CohortDownloadApi from .cohort_storage import CohortStorage @@ -23,9 +23,8 @@ def load_cohort(self, cohort_id: str) -> Future: if cohort_id not in self.jobs: def task(): try: - cohort_description = self.get_cohort_description(cohort_id) - cohort_members = self.download_cohort(cohort_description) - self.cohort_storage.put_cohort(cohort_description, cohort_members) + cohort = self.download_cohort(cohort_id) + self.cohort_storage.put_cohort(cohort) except Exception as e: raise e @@ -38,12 +37,6 @@ def _remove_job(self, cohort_id: str): if cohort_id in self.jobs: del self.jobs[cohort_id] - def get_cohort_description(self, cohort_id: str) -> CohortDescription: - return self.cohort_download_api.get_cohort_description(cohort_id) - - def should_download_cohort(self, cohort_description: CohortDescription) -> bool: - return self.cohort_storage.get_cohort_description(cohort_description.id) is None - - def download_cohort(self, cohort_description: CohortDescription) -> Set[str]: - return self.cohort_download_api.get_cohort_members(cohort_description, - self.should_download_cohort(cohort_description)) + def download_cohort(self, cohort_id: str) -> Cohort: + cohort = self.cohort_storage.get_cohort(cohort_id) + return self.cohort_download_api.get_cohort(cohort_id, cohort) diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py index c8a3430..e76c875 100644 --- a/src/amplitude_experiment/cohort/cohort_storage.py +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -1,34 +1,41 @@ from typing import Dict, Set, Optional from threading import RLock -from .cohort_description import CohortDescription, USER_GROUP_TYPE +from .cohort import Cohort, USER_GROUP_TYPE class CohortStorage: - def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: - raise NotImplementedError() + def get_cohort(self, cohort_id: str): + pass - def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: - raise NotImplementedError() + def get_cohorts(self): + pass - def get_cohort_description(self, cohort_id: str) -> Optional[CohortDescription]: - raise NotImplementedError() + def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: + pass - def get_cohort_descriptions(self) -> Dict[str, CohortDescription]: - raise NotImplementedError() + def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: + pass - def put_cohort(self, cohort_description: CohortDescription, members: Set[str]): - raise NotImplementedError() + def put_cohort(self, cohort_description: Cohort): + pass def delete_cohort(self, group_type: str, cohort_id: str): - raise NotImplementedError() + pass class InMemoryCohortStorage(CohortStorage): def __init__(self): self.lock = RLock() - self.cohort_store: Dict[str, Dict[str, Set[str]]] = {} - self.description_store: Dict[str, CohortDescription] = {} + self.group_to_cohort_store: Dict[str, Set[str]] = {} + self.cohort_store: Dict[str, Cohort] = {} + + def get_cohort(self, cohort_id: str): + with self.lock: + return self.cohort_store.get(cohort_id) + + def get_cohorts(self): + return self.cohort_store.copy() def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids) @@ -36,29 +43,24 @@ def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: result = set() with self.lock: - group_type_cohorts = self.cohort_store.get(group_type, {}) - for cohort_id, members in group_type_cohorts.items(): + group_type_cohorts = self.group_to_cohort_store.get(group_type, {}) + for cohort_id in group_type_cohorts: + members = self.cohort_store.get(cohort_id).member_ids if cohort_id in cohort_ids and group_name in members: result.add(cohort_id) return result - def get_cohort_description(self, cohort_id: str) -> Optional[CohortDescription]: - with self.lock: - return self.description_store.get(cohort_id) - - def get_cohort_descriptions(self) -> Dict[str, CohortDescription]: - with self.lock: - return self.description_store.copy() - - def put_cohort(self, cohort_description: CohortDescription, members: Set[str]): + def put_cohort(self, cohort: Cohort): with self.lock: - self.cohort_store.setdefault(cohort_description.group_type, {})[cohort_description.id] = members - self.description_store[cohort_description.id] = cohort_description + if cohort.group_type not in self.group_to_cohort_store: + self.group_to_cohort_store[cohort.group_type] = set() + self.group_to_cohort_store[cohort.group_type].add(cohort.id) + self.cohort_store[cohort.id] = cohort def delete_cohort(self, group_type: str, cohort_id: str): with self.lock: - group_cohorts = self.cohort_store.get(group_type, {}) + group_cohorts = self.group_to_cohort_store.get(group_type, {}) if cohort_id in group_cohorts: - del group_cohorts[cohort_id] - if cohort_id in self.description_store: - del self.description_store[cohort_id] + group_cohorts.remove(cohort_id) + if cohort_id in self.cohort_store: + del self.cohort_store[cohort_id] diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 9d57787..5f80a38 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -1,8 +1,6 @@ import logging -from concurrent.futures import Future from typing import Optional, Set import threading -import time from ..local.config import LocalEvaluationConfig from ..cohort.cohort_loader import CohortLoader @@ -44,9 +42,9 @@ def stop(self): def __periodic_refresh(self): try: - self.refresh(initial=False) + self.refresh() except Exception as e: - self.logger.error("Refresh flag configs failed.", exc_info=e) + self.logger.error(f"Refresh flag and cohort configs failed: {e}") def refresh(self, initial: bool = False): self.logger.debug("Refreshing flag configs.") @@ -54,52 +52,65 @@ def refresh(self, initial: bool = False): flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) - futures = [] for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) if not self.cohort_loader or not cohort_ids: self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) continue - future = self._load_cohorts_and_store_flag(flag_config, cohort_ids) - futures.append(future) - if initial: + # Keep track of old flag and cohort for each flag + old_flag_config = self.flag_config_storage.get_flag_config(flag_config['key']) + try: - for future in futures: - future.result() + flag_loaded = self._load_cohorts_and_store_flag(flag_config, cohort_ids, initial) + if flag_loaded: + self.flag_config_storage.put_flag_config(flag_config) # Store new flag config + self.logger.debug(f"Stored flag config {flag_config['key']}") + else: + self.logger.warning(f"Failed to load all cohorts for flag {flag_config['key']}. Using the old flag config.") + self.flag_config_storage.put_flag_config(old_flag_config) except Exception as e: - self.logger.warning("Failed to download cohort", e) - raise e + self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}:{e}") + if initial: + raise e self._delete_unused_cohorts() self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") - def _load_cohorts_and_store_flag(self, flag_config: dict, cohort_ids: Set[str]) -> Future: + def _load_cohorts_and_store_flag(self, flag_config: dict, cohort_ids: Set[str], initial: bool): def task(): try: for cohort_id in cohort_ids: future = self.cohort_loader.load_cohort(cohort_id) future.result() self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") - self.flag_config_storage.put_flag_config(flag_config) - self.logger.debug(f"Flag config {flag_config['key']} stored successfully.") + return True # All cohorts loaded successfully except Exception as e: - self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}", exc_info=e) - raise e + self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}") + if initial: + raise e + return False # Cohort loading failed + + cohort_fetched = self.cohort_loader.executor.submit(task) + flag_fetched = True + + # Wait for both flag and cohort loading to complete + if initial: + flag_fetched = cohort_fetched.result() - return self.cohort_loader.executor.submit(task) + return flag_fetched def _delete_unused_cohorts(self): flag_cohort_ids = set() for flag in self.flag_config_storage.get_flag_configs().values(): flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) - storage_cohorts = self.cohort_storage.get_cohort_descriptions() + storage_cohorts = self.cohort_storage.get_cohorts() deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids for deleted_cohort_id in deleted_cohort_ids: - deleted_cohort_description = storage_cohorts.get(deleted_cohort_id) - if deleted_cohort_description is not None: - self.cohort_storage.delete_cohort(deleted_cohort_description.group_type, deleted_cohort_id) + deleted_cohort = storage_cohorts.get(deleted_cohort_id) + if deleted_cohort is not None: + self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) diff --git a/src/amplitude_experiment/flag/flag_config_storage.py b/src/amplitude_experiment/flag/flag_config_storage.py index c949f7a..14c765f 100644 --- a/src/amplitude_experiment/flag/flag_config_storage.py +++ b/src/amplitude_experiment/flag/flag_config_storage.py @@ -3,6 +3,8 @@ class FlagConfigStorage: + def get_flag_config(self, key: str) -> Dict: + pass def get_flag_configs(self) -> Dict: pass @@ -18,6 +20,10 @@ def __init__(self): self.flag_configs = {} self.flag_configs_lock = Lock() + def get_flag_config(self, key: str) -> Dict: + with self.flag_configs_lock: + return self.flag_configs.get(key) + def get_flag_configs(self) -> Dict[str, Dict]: with self.flag_configs_lock: return self.flag_configs.copy() diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 4b507f0..929d0b0 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -8,7 +8,7 @@ from .config import LocalEvaluationConfig from .topological_sort import topological_sort from ..assignment import Assignment, AssignmentFilter, AssignmentService -from ..cohort.cohort_description import USER_GROUP_TYPE +from ..cohort.cohort import USER_GROUP_TYPE from ..cohort.cohort_download_api import DirectCohortDownloadApi from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import InMemoryCohortStorage diff --git a/src/amplitude_experiment/util/flag_config.py b/src/amplitude_experiment/util/flag_config.py index 1697856..c64ac2c 100644 --- a/src/amplitude_experiment/util/flag_config.py +++ b/src/amplitude_experiment/util/flag_config.py @@ -1,6 +1,6 @@ from typing import List, Dict, Set, Any -from ..cohort.cohort_description import USER_GROUP_TYPE +from ..cohort.cohort import USER_GROUP_TYPE def is_cohort_filter(condition: Dict[str, Any]) -> bool: diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 3da98a7..da8b400 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -1,7 +1,7 @@ import json import unittest from unittest.mock import MagicMock, patch -from src.amplitude_experiment.cohort.cohort_description import CohortDescription +from src.amplitude_experiment.cohort.cohort import Cohort from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi from src.amplitude_experiment.exception import CohortTooLargeException, CohortNotModifiedException @@ -20,99 +20,82 @@ def setUp(self): self.api = DirectCohortDownloadApi('api', 'secret', 15000, False, 100) def test_cohort_download_success(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) - members_response = response(200, {'memberIds': ['user']}) + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) + cohort_info_response = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', return_value=members_response): + with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_cohort_download_many_202s_success(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) - members_response = response(200, {'memberIds': ['user']}) - async_responses = [response(202)] * 9 + [members_response] + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) + async_responses = [response(202)] * 9 + [response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']})] - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_cohort_request_status_with_two_failures_succeeds(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) error_response = response(503) - members_response = response(200, {'memberIds': ['user']}) - async_responses = [error_response, error_response, members_response] + success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + async_responses = [error_response, error_response, success_response] - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_cohort_request_status_429s_keep_retrying(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'user'}) + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) error_response = response(429) - members_response = response(200, {'memberIds': ['user']}) - async_responses = [error_response] * 9 + [members_response] + success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + async_responses = [error_response] * 9 + [success_response] - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'user'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_group_cohort_download_success(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': "org name"}) - members_response = response(200, {'memberIds': ['group']}) + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") + cohort_info_response = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', return_value=members_response): + with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'group'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_group_cohort_request_status_429s_keep_retrying(self): - cohort = CohortDescription(id="1234", last_computed=0, size=1, group_type="org name") - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': "org name"}) + cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") error_response = response(429) - members_response = response(200, {'memberIds': ['group']}) - async_responses = [error_response] * 9 + [members_response] + success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'org name', 'memberIds': ['group']}) + async_responses = [error_response] * 9 + [success_response] - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): - members = self.api.get_cohort_members(cohort) - self.assertEqual({'group'}, members) + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) def test_cohort_size_too_large(self): - cohort = CohortDescription(id="1234", last_computed=0, size=16000, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 16000, 'groupType': 'user'}) + cohort = Cohort(id="1234", last_computed=0, size=16000, member_ids=set()) too_large_response = response(413) - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', return_value=too_large_response): + with patch.object(self.api, '_get_cohort_members_request', return_value=too_large_response): with self.assertRaises(CohortTooLargeException): - self.api.get_cohort_members(cohort) + self.api.get_cohort("1234", cohort) def test_cohort_not_modified_exception(self): - cohort = CohortDescription(id="1234", last_computed=1000, size=1, group_type='user') - cohort_info_response = response(200, {'cohortId': '1234', 'lastComputed': 1000, 'size': 1, 'groupType': 'user'}) + cohort = Cohort(id="1234", last_computed=1000, size=1, member_ids=set()) not_modified_response = response(204) - with patch.object(self.api, 'get_cohort_info', return_value=cohort_info_response), \ - patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): + with patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): with self.assertRaises(CohortNotModifiedException): - self.api.get_cohort_members(cohort, should_download_cohort=False) + self.api.get_cohort("1234", cohort) if __name__ == '__main__': diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index 66a613b..4ea819f 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -1,11 +1,10 @@ import unittest from unittest.mock import MagicMock -from src.amplitude_experiment.cohort.cohort_description import CohortDescription +from src.amplitude_experiment.cohort.cohort import Cohort from src.amplitude_experiment.cohort.cohort_loader import CohortLoader from src.amplitude_experiment.cohort.cohort_storage import InMemoryCohortStorage - class CohortLoaderTest(unittest.TestCase): def setUp(self): self.api = MagicMock() @@ -13,8 +12,10 @@ def setUp(self): self.loader = CohortLoader(self.api, self.storage) def test_load_success(self): - self.api.get_cohort_description.side_effect = [cohort_description("a"), cohort_description("b")] - self.api.get_cohort_members.side_effect = [{"1"}, {"1", "2"}] + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), + Cohort(id="b", last_computed=0, size=2, member_ids={"1", "2"}) + ] future_a = self.loader.load_cohort("a") future_b = self.loader.load_cohort("b") @@ -22,10 +23,10 @@ def test_load_success(self): future_a.result() future_b.result() - storage_description_a = self.storage.get_cohort_description("a") - storage_description_b = self.storage.get_cohort_description("b") - self.assertEqual(cohort_description("a"), storage_description_a) - self.assertEqual(cohort_description("b"), storage_description_b) + storage_description_a = self.storage.get_cohort("a") + storage_description_b = self.storage.get_cohort("b") + self.assertEqual(Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), storage_description_a) + self.assertEqual(Cohort(id="b", last_computed=0, size=2, member_ids={"1", "2"}), storage_description_b) storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) @@ -33,21 +34,20 @@ def test_load_success(self): self.assertEqual({"b"}, storage_user2_cohorts) def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): - self.storage.put_cohort(cohort_description("a", last_computed=0), set()) - self.storage.put_cohort(cohort_description("b", last_computed=0), set()) - self.api.get_cohort_description.side_effect = [ - cohort_description("a", last_computed=0), - cohort_description("b", last_computed=1) + self.storage.put_cohort(Cohort("a", last_computed=0, size=0, member_ids=set())) + self.storage.put_cohort(Cohort("b", last_computed=0, size=0, member_ids=set())) + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_computed=0, size=0, member_ids=set()), + Cohort(id="b", last_computed=1, size=2, member_ids={"1", "2"}) ] - self.api.get_cohort_members.side_effect = [{"1", "2"}] self.loader.load_cohort("a").result() self.loader.load_cohort("b").result() - storage_description_a = self.storage.get_cohort_description("a") - storage_description_b = self.storage.get_cohort_description("b") - self.assertEqual(cohort_description("a", last_computed=0), storage_description_a) - self.assertEqual(cohort_description("b", last_computed=1), storage_description_b) + storage_description_a = self.storage.get_cohort("a") + storage_description_b = self.storage.get_cohort("b") + self.assertEqual(Cohort(id="a", last_computed=0, size=0, member_ids=set()), storage_description_a) + self.assertEqual(Cohort(id="b", last_computed=1, size=2, member_ids={"1", "2"}), storage_description_b) storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) @@ -55,12 +55,11 @@ def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): self.assertEqual({"b"}, storage_user2_cohorts) def test_load_download_failure_throws(self): - self.api.get_cohort_description.side_effect = [ - cohort_description("a"), - cohort_description("b"), - cohort_description("c") + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), + Exception("Connection timed out"), + Cohort(id="c", last_computed=0, size=1, member_ids={"1"}) ] - self.api.get_cohort_members.side_effect = [{"1"}, Exception("Connection timed out"), {"1"}] self.loader.load_cohort("a").result() with self.assertRaises(Exception): @@ -69,10 +68,5 @@ def test_load_download_failure_throws(self): self.assertEqual({"a", "c"}, self.storage.get_cohorts_for_user("1", {"a", "b", "c"})) - -def cohort_description(cohort_id, last_computed=0, size=0): - return CohortDescription(id=cohort_id, last_computed=last_computed, size=size) - - if __name__ == "__main__": unittest.main() From 080e8f0d7a8ba97711fbdec6d015e245a4caad3d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Fri, 14 Jun 2024 15:17:46 -0700 Subject: [PATCH 10/44] nit: fix formatting --- .../cohort/cohort_download_api.py | 9 --------- src/amplitude_experiment/cohort/cohort_storage.py | 12 ++++++------ src/amplitude_experiment/flag/flag_config_storage.py | 9 +++++---- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 15b87b5..4ea2efc 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -3,7 +3,6 @@ import base64 import json from http.client import HTTPResponse -from typing import Set from .cohort import Cohort from ..connection_pool import HTTPConnectionPool @@ -34,14 +33,6 @@ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, if debug: self.logger.setLevel(logging.DEBUG) - def get_cohort_info(self, cohort_id: str) -> HTTPResponse: - conn = self._connection_pool.acquire() - try: - return conn.request('GET', f'/sdk/v1/cohort/{cohort_id}?skipCohortDownload=true', - headers={'Authorization': f'Basic {self._get_basic_auth()}'}) - finally: - self._connection_pool.release(conn) - def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: self.logger.debug(f"getCohortMembers({cohort_id}): start") errors = 0 diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py index e76c875..5ba6018 100644 --- a/src/amplitude_experiment/cohort/cohort_storage.py +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -6,22 +6,22 @@ class CohortStorage: def get_cohort(self, cohort_id: str): - pass + raise NotImplementedError def get_cohorts(self): - pass + raise NotImplementedError def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: - pass + raise NotImplementedError def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: - pass + raise NotImplementedError def put_cohort(self, cohort_description: Cohort): - pass + raise NotImplementedError def delete_cohort(self, group_type: str, cohort_id: str): - pass + raise NotImplementedError class InMemoryCohortStorage(CohortStorage): diff --git a/src/amplitude_experiment/flag/flag_config_storage.py b/src/amplitude_experiment/flag/flag_config_storage.py index 14c765f..68b1a73 100644 --- a/src/amplitude_experiment/flag/flag_config_storage.py +++ b/src/amplitude_experiment/flag/flag_config_storage.py @@ -4,15 +4,16 @@ class FlagConfigStorage: def get_flag_config(self, key: str) -> Dict: - pass + raise NotImplementedError + def get_flag_configs(self) -> Dict: - pass + raise NotImplementedError def put_flag_config(self, flag_config: Dict): - pass + raise NotImplementedError def remove_if(self, condition: Callable[[Dict], bool]): - pass + raise NotImplementedError class InMemoryFlagConfigStorage(FlagConfigStorage): From 23c8b25d6a24f3f34a3e82f3d2de159a1126d6e7 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Fri, 14 Jun 2024 15:28:29 -0700 Subject: [PATCH 11/44] handle flag fetch fail --- src/amplitude_experiment/deployment/deployment_runner.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 5f80a38..ce38a15 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -48,7 +48,13 @@ def __periodic_refresh(self): def refresh(self, initial: bool = False): self.logger.debug("Refreshing flag configs.") - flag_configs = self.flag_config_api.get_flag_configs() + try: + flag_configs = self.flag_config_api.get_flag_configs() + except Exception as e: + self.logger.error(f'Failed to fetch flag configs: {e}') + if initial: + raise Exception + return flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) From f249b4481486c6d9766f52d5f646f91c0682d1b7 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 17 Jun 2024 11:39:48 -0700 Subject: [PATCH 12/44] Use lastModified instead of lastComputed --- src/amplitude_experiment/cohort/cohort.py | 2 +- .../cohort/cohort_download_api.py | 4 +-- tests/cohort/cohort_download_api_test.py | 28 +++++++++---------- tests/cohort/cohort_loader_test.py | 24 ++++++++-------- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort.py b/src/amplitude_experiment/cohort/cohort.py index 99af861..ccc5013 100644 --- a/src/amplitude_experiment/cohort/cohort.py +++ b/src/amplitude_experiment/cohort/cohort.py @@ -7,7 +7,7 @@ @dataclass class Cohort: id: str - last_computed: int + last_modified: int size: int member_ids: Set[str] group_type: str = field(default=USER_GROUP_TYPE) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 4ea2efc..daf44d9 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -39,7 +39,7 @@ def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: while True: response = None try: - last_modified = None if cohort is None else cohort.last_computed + last_modified = None if cohort is None else cohort.last_modified response = self._get_cohort_members_request(cohort_id, last_modified) self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}") if response.status == 200: @@ -47,7 +47,7 @@ def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}") return Cohort( id=cohort_info['cohortId'], - last_computed=cohort_info['lastComputed'], + last_modified=cohort_info['lastModified'], size=cohort_info['size'], member_ids=set(cohort_info['memberIds']), group_type=cohort_info['groupType'], diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index da8b400..e34297a 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -20,8 +20,8 @@ def setUp(self): self.api = DirectCohortDownloadApi('api', 'secret', 15000, False, 100) def test_cohort_download_success(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) - cohort_info_response = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + cohort_info_response = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): @@ -29,8 +29,8 @@ def test_cohort_download_success(self): self.assertEqual(cohort, result_cohort) def test_cohort_download_many_202s_success(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) - async_responses = [response(202)] * 9 + [response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']})] + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + async_responses = [response(202)] * 9 + [response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']})] with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): @@ -38,9 +38,9 @@ def test_cohort_download_many_202s_success(self): self.assertEqual(cohort, result_cohort) def test_cohort_request_status_with_two_failures_succeeds(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) error_response = response(503) - success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) async_responses = [error_response, error_response, success_response] with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): @@ -49,9 +49,9 @@ def test_cohort_request_status_with_two_failures_succeeds(self): self.assertEqual(cohort, result_cohort) def test_cohort_request_status_429s_keep_retrying(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'user'}) + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) error_response = response(429) - success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) async_responses = [error_response] * 9 + [success_response] with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): @@ -60,8 +60,8 @@ def test_cohort_request_status_429s_keep_retrying(self): self.assertEqual(cohort, result_cohort) def test_group_cohort_download_success(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") - cohort_info_response = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") + cohort_info_response = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): @@ -69,9 +69,9 @@ def test_group_cohort_download_success(self): self.assertEqual(cohort, result_cohort) def test_group_cohort_request_status_429s_keep_retrying(self): - cohort = Cohort(id="1234", last_computed=0, size=1, member_ids={'group'}, group_type="org name") + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") error_response = response(429) - success_response = response(200, {'cohortId': '1234', 'lastComputed': 0, 'size': 1, 'groupType': 'org name', 'memberIds': ['group']}) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'org name', 'memberIds': ['group']}) async_responses = [error_response] * 9 + [success_response] with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): @@ -80,7 +80,7 @@ def test_group_cohort_request_status_429s_keep_retrying(self): self.assertEqual(cohort, result_cohort) def test_cohort_size_too_large(self): - cohort = Cohort(id="1234", last_computed=0, size=16000, member_ids=set()) + cohort = Cohort(id="1234", last_modified=0, size=16000, member_ids=set()) too_large_response = response(413) with patch.object(self.api, '_get_cohort_members_request', return_value=too_large_response): @@ -89,7 +89,7 @@ def test_cohort_size_too_large(self): self.api.get_cohort("1234", cohort) def test_cohort_not_modified_exception(self): - cohort = Cohort(id="1234", last_computed=1000, size=1, member_ids=set()) + cohort = Cohort(id="1234", last_modified=1000, size=1, member_ids=set()) not_modified_response = response(204) with patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index 4ea819f..309a8ff 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -13,8 +13,8 @@ def setUp(self): def test_load_success(self): self.api.get_cohort.side_effect = [ - Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), - Cohort(id="b", last_computed=0, size=2, member_ids={"1", "2"}) + Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), + Cohort(id="b", last_modified=0, size=2, member_ids={"1", "2"}) ] future_a = self.loader.load_cohort("a") @@ -25,8 +25,8 @@ def test_load_success(self): storage_description_a = self.storage.get_cohort("a") storage_description_b = self.storage.get_cohort("b") - self.assertEqual(Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), storage_description_a) - self.assertEqual(Cohort(id="b", last_computed=0, size=2, member_ids={"1", "2"}), storage_description_b) + self.assertEqual(Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), storage_description_a) + self.assertEqual(Cohort(id="b", last_modified=0, size=2, member_ids={"1", "2"}), storage_description_b) storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) @@ -34,11 +34,11 @@ def test_load_success(self): self.assertEqual({"b"}, storage_user2_cohorts) def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): - self.storage.put_cohort(Cohort("a", last_computed=0, size=0, member_ids=set())) - self.storage.put_cohort(Cohort("b", last_computed=0, size=0, member_ids=set())) + self.storage.put_cohort(Cohort("a", last_modified=0, size=0, member_ids=set())) + self.storage.put_cohort(Cohort("b", last_modified=0, size=0, member_ids=set())) self.api.get_cohort.side_effect = [ - Cohort(id="a", last_computed=0, size=0, member_ids=set()), - Cohort(id="b", last_computed=1, size=2, member_ids={"1", "2"}) + Cohort(id="a", last_modified=0, size=0, member_ids=set()), + Cohort(id="b", last_modified=1, size=2, member_ids={"1", "2"}) ] self.loader.load_cohort("a").result() @@ -46,8 +46,8 @@ def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): storage_description_a = self.storage.get_cohort("a") storage_description_b = self.storage.get_cohort("b") - self.assertEqual(Cohort(id="a", last_computed=0, size=0, member_ids=set()), storage_description_a) - self.assertEqual(Cohort(id="b", last_computed=1, size=2, member_ids={"1", "2"}), storage_description_b) + self.assertEqual(Cohort(id="a", last_modified=0, size=0, member_ids=set()), storage_description_a) + self.assertEqual(Cohort(id="b", last_modified=1, size=2, member_ids={"1", "2"}), storage_description_b) storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) @@ -56,9 +56,9 @@ def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): def test_load_download_failure_throws(self): self.api.get_cohort.side_effect = [ - Cohort(id="a", last_computed=0, size=1, member_ids={"1"}), + Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), Exception("Connection timed out"), - Cohort(id="c", last_computed=0, size=1, member_ids={"1"}) + Cohort(id="c", last_modified=0, size=1, member_ids={"1"}) ] self.loader.load_cohort("a").result() From 04d35f650248ed8ca6c82981e11e2268a4dffe70 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 17 Jun 2024 13:50:14 -0700 Subject: [PATCH 13/44] add cohort_request_delay_millis to config --- src/amplitude_experiment/cohort/cohort_download_api.py | 2 +- src/amplitude_experiment/cohort/cohort_sync_config.py | 4 +++- src/amplitude_experiment/local/client.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index daf44d9..dcef209 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -20,7 +20,7 @@ def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, debug: bool = False, cohort_request_delay_millis: int = 5000): super().__init__() self.api_key = api_key diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py index ba32659..adb6a01 100644 --- a/src/amplitude_experiment/cohort/cohort_sync_config.py +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -1,5 +1,7 @@ class CohortSyncConfig: - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, + cohort_request_delay_millis: int = 5000): self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size + self.cohort_request_delay_millis = cohort_request_delay_millis diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 929d0b0..94c6943 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -63,7 +63,8 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key, self.config.cohort_sync_config.secret_key, self.config.cohort_sync_config.max_cohort_size, - self.config.debug) + self.config.debug, + self.config.cohort_sync_config.cohort_request_delay_millis) cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) From b4402eb4cf9de03531daad67f1a22d92ee6fd029 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 17 Jun 2024 14:03:46 -0700 Subject: [PATCH 14/44] fix DirectCohortDownloadApi constructor --- src/amplitude_experiment/cohort/cohort_download_api.py | 4 ++-- src/amplitude_experiment/local/client.py | 5 +++-- tests/cohort/cohort_download_api_test.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index dcef209..a389451 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -20,8 +20,8 @@ def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, - debug: bool = False, cohort_request_delay_millis: int = 5000): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, + debug: bool): super().__init__() self.api_key = api_key self.secret_key = secret_key diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 94c6943..12114fb 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -63,8 +63,9 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key, self.config.cohort_sync_config.secret_key, self.config.cohort_sync_config.max_cohort_size, - self.config.debug, - self.config.cohort_sync_config.cohort_request_delay_millis) + self.config.cohort_sync_config.cohort_request_delay_millis, + self.config.debug) + cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index e34297a..c2482a2 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -17,7 +17,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, False, 100) + self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, False) def test_cohort_download_success(self): cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) From a7ffbb6a019cca3c6277951633e631713de631fd Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 25 Jun 2024 10:58:38 -0700 Subject: [PATCH 15/44] Simplify deployment_runner, clean up comments --- .../cohort/cohort_download_api.py | 5 ++- .../deployment/deployment_runner.py | 41 +++++++------------ src/amplitude_experiment/util/user.py | 4 +- 3 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index a389451..495c71a 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -3,6 +3,7 @@ import base64 import json from http.client import HTTPResponse +from typing import Optional from .cohort import Cohort from ..connection_pool import HTTPConnectionPool @@ -15,7 +16,7 @@ class CohortDownloadApi: def __init__(self): self.cdn_server_url = CDN_COHORT_SYNC_URL - def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: + def get_cohort(self, cohort_id: str, cohort: Cohort) -> Optional[Cohort]: raise NotImplementedError @@ -33,7 +34,7 @@ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_r if debug: self.logger.setLevel(logging.DEBUG) - def get_cohort(self, cohort_id: str, cohort: Cohort) -> Cohort: + def get_cohort(self, cohort_id: str, cohort: Cohort) -> Optional[Cohort]: self.logger.debug(f"getCohortMembers({cohort_id}): start") errors = 0 while True: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index ce38a15..af27510 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -34,7 +34,7 @@ def __init__( def start(self): with self.lock: - self.refresh(initial=True) + self.refresh() self.poller.start() def stop(self): @@ -46,15 +46,14 @@ def __periodic_refresh(self): except Exception as e: self.logger.error(f"Refresh flag and cohort configs failed: {e}") - def refresh(self, initial: bool = False): + def refresh(self): self.logger.debug("Refreshing flag configs.") try: flag_configs = self.flag_config_api.get_flag_configs() except Exception as e: self.logger.error(f'Failed to fetch flag configs: {e}') - if initial: - raise Exception - return + raise Exception + flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) @@ -69,43 +68,33 @@ def refresh(self, initial: bool = False): old_flag_config = self.flag_config_storage.get_flag_config(flag_config['key']) try: - flag_loaded = self._load_cohorts_and_store_flag(flag_config, cohort_ids, initial) - if flag_loaded: - self.flag_config_storage.put_flag_config(flag_config) # Store new flag config - self.logger.debug(f"Stored flag config {flag_config['key']}") - else: - self.logger.warning(f"Failed to load all cohorts for flag {flag_config['key']}. Using the old flag config.") - self.flag_config_storage.put_flag_config(old_flag_config) + self._load_cohorts(flag_config, cohort_ids) + self.flag_config_storage.put_flag_config(flag_config) # Store new flag config + self.logger.debug(f"Stored flag config {flag_config['key']}") + except Exception as e: - self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}:{e}") - if initial: - raise e + self.logger.warning(f"Failed to load all cohorts for flag {flag_config['key']}. " + f"Using the old flag config.") + self.flag_config_storage.put_flag_config(old_flag_config) + raise e self._delete_unused_cohorts() self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") - def _load_cohorts_and_store_flag(self, flag_config: dict, cohort_ids: Set[str], initial: bool): + def _load_cohorts(self, flag_config: dict, cohort_ids: Set[str]): def task(): try: for cohort_id in cohort_ids: future = self.cohort_loader.load_cohort(cohort_id) future.result() self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") - return True # All cohorts loaded successfully except Exception as e: self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}") - if initial: - raise e - return False # Cohort loading failed + raise e cohort_fetched = self.cohort_loader.executor.submit(task) - flag_fetched = True - # Wait for both flag and cohort loading to complete - if initial: - flag_fetched = cohort_fetched.result() - - return flag_fetched + cohort_fetched.result() def _delete_unused_cohorts(self): flag_cohort_ids = set() diff --git a/src/amplitude_experiment/util/user.py b/src/amplitude_experiment/util/user.py index f4e83cf..2b0b48f 100644 --- a/src/amplitude_experiment/util/user.py +++ b/src/amplitude_experiment/util/user.py @@ -6,11 +6,11 @@ def user_to_evaluation_context(user: User) -> Dict[str, Any]: user_groups = user.groups user_group_properties = user.group_properties - user_group_cohort_ids = user.group_cohort_ids # Assuming this property exists on the User object + user_group_cohort_ids = user.group_cohort_ids user_dict = {key: value for key, value in user.__dict__.copy().items() if value is not None} user_dict.pop('groups', None) user_dict.pop('group_properties', None) - user_dict.pop('group_cohort_ids', None) # Removing the group_cohort_ids from the user dictionary + user_dict.pop('group_cohort_ids', None) context = {'user': user_dict} if len(user_dict) > 0 else {} if user_groups is None: From f88d56ef851b8c58041e4cfa9d77d9f2253dfa11 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 25 Jun 2024 14:50:15 -0700 Subject: [PATCH 16/44] revert default deployment changes --- src/amplitude_experiment/factory.py | 20 ++++++++------------ src/amplitude_experiment/local/config.py | 4 ---- src/amplitude_experiment/remote/config.py | 6 +----- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/amplitude_experiment/factory.py b/src/amplitude_experiment/factory.py index b118325..841c55c 100644 --- a/src/amplitude_experiment/factory.py +++ b/src/amplitude_experiment/factory.py @@ -15,17 +15,15 @@ def initialize_remote(api_key: str, config: RemoteEvaluationConfig = None) -> Re """ Initializes a remote evaluation client. Parameters: - api_key (str): The Amplitude Project API Key used in the client. If a deployment key is provided in the - config, it will be used instead + api_key (str): The Amplitude API Key config (RemoteEvaluationConfig): Optional Config Returns: A remote evaluation client. """ - used_key = config.deployment_key if config and config.deployment_key else api_key - if remote_evaluation_instances.get(used_key) is None: - remote_evaluation_instances[used_key] = RemoteEvaluationClient(used_key, config) - return remote_evaluation_instances[used_key] + if remote_evaluation_instances.get(api_key) is None: + remote_evaluation_instances[api_key] = RemoteEvaluationClient(api_key, config) + return remote_evaluation_instances[api_key] @staticmethod def initialize_local(api_key: str, config: LocalEvaluationConfig = None) -> LocalEvaluationClient: @@ -34,14 +32,12 @@ def initialize_local(api_key: str, config: LocalEvaluationConfig = None) -> Loca user without requiring a remote call to the amplitude evaluation server. In order to best leverage local evaluation, all flags, and experiments being evaluated server side should be configured as local. Parameters: - api_key (str): The Amplitude Project API Key used in the client. If a deployment key is provided in the - config, it will be used instead + api_key (str): The Amplitude API Key config (RemoteEvaluationConfig): Optional Config Returns: A local evaluation client. """ - used_key = config.deployment_key if config and config.deployment_key else api_key - if local_evaluation_instances.get(used_key) is None: - local_evaluation_instances[used_key] = LocalEvaluationClient(used_key, config) - return local_evaluation_instances[used_key] + if local_evaluation_instances.get(api_key) is None: + local_evaluation_instances[api_key] = LocalEvaluationClient(api_key, config) + return local_evaluation_instances[api_key] diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index 1183a9c..d34421b 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -12,7 +12,6 @@ def __init__(self, debug: bool = False, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, assignment_config: AssignmentConfig = None, - deployment_key: str = None, cohort_sync_config: CohortSyncConfig = None): """ Initialize a config @@ -24,8 +23,6 @@ def __init__(self, debug: bool = False, to perform local evaluation. flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, used when fetching variants. - deployment_key (str): The Experiment deployment key. If provided, it is used - instead of the project API key Returns: The config object @@ -35,5 +32,4 @@ def __init__(self, debug: bool = False, self.flag_config_polling_interval_millis = flag_config_polling_interval_millis self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis self.assignment_config = assignment_config - self.deployment_key = deployment_key self.cohort_sync_config = cohort_sync_config diff --git a/src/amplitude_experiment/remote/config.py b/src/amplitude_experiment/remote/config.py index b0a217c..7e84bf5 100644 --- a/src/amplitude_experiment/remote/config.py +++ b/src/amplitude_experiment/remote/config.py @@ -10,8 +10,7 @@ def __init__(self, debug=False, fetch_retry_backoff_min_millis=500, fetch_retry_backoff_max_millis=10000, fetch_retry_backoff_scalar=1.5, - fetch_retry_timeout_millis=10000, - deployment_key=None): + fetch_retry_timeout_millis=10000): """ Initialize a config Parameters: @@ -26,8 +25,6 @@ def __init__(self, debug=False, greater than the max, the max is used for all subsequent retries. fetch_retry_backoff_scalar (float): Scales the minimum backoff exponentially. fetch_retry_timeout_millis (int): The request timeout for retrying fetch requests. - deployment_key (str): The Experiment deployment key. If provided, it is used - instead of the project API key Returns: The config object @@ -40,4 +37,3 @@ def __init__(self, debug=False, self.fetch_retry_backoff_max_millis = fetch_retry_backoff_max_millis self.fetch_retry_backoff_scalar = fetch_retry_backoff_scalar self.fetch_retry_timeout_millis = fetch_retry_timeout_millis - self.deployment_key = deployment_key From f98f09fc252cc707933bc3ee667f43c8e1f99d66 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 26 Jun 2024 13:39:41 -0700 Subject: [PATCH 17/44] Update cohort sync config with comments and server_url config --- .../cohort/cohort_download_api.py | 15 +++---- .../cohort/cohort_sync_config.py | 18 ++++++++- src/amplitude_experiment/local/client.py | 1 + src/amplitude_experiment/local/config.py | 40 ++++++++++++++----- 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 495c71a..ed7ad08 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -9,32 +9,29 @@ from ..connection_pool import HTTPConnectionPool from ..exception import HTTPErrorResponseException, CohortTooLargeException, CohortNotModifiedException -CDN_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com' - class CohortDownloadApi: - def __init__(self): - self.cdn_server_url = CDN_COHORT_SYNC_URL - def get_cohort(self, cohort_id: str, cohort: Cohort) -> Optional[Cohort]: + def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: raise NotImplementedError class DirectCohortDownloadApi(CohortDownloadApi): def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, - debug: bool): + server_url: str, debug: bool): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size - self.__setup_connection_pool() self.cohort_request_delay_millis = cohort_request_delay_millis self.logger = logging.getLogger("Amplitude") self.logger.addHandler(logging.StreamHandler()) + self.server_url = server_url if debug: self.logger.setLevel(logging.DEBUG) + self.__setup_connection_pool() - def get_cohort(self, cohort_id: str, cohort: Cohort) -> Optional[Cohort]: + def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: self.logger.debug(f"getCohortMembers({cohort_id}): start") errors = 0 while True: @@ -87,7 +84,7 @@ def _get_basic_auth(self) -> str: return base64.b64encode(credentials.encode('utf-8')).decode('utf-8') def __setup_connection_pool(self): - scheme, _, host = self.cdn_server_url.split('/', 3) + scheme, _, host = self.server_url.split('/', 3) timeout = 10 self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout, scheme=scheme) diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py index adb6a01..5a7a580 100644 --- a/src/amplitude_experiment/cohort/cohort_sync_config.py +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -1,7 +1,23 @@ +DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com' +EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com' + + class CohortSyncConfig: + """Experiment Cohort Sync Configuration + This configuration is used to set up the cohort loader. The cohort loader is responsible for + downloading cohorts from the server and storing them locally. + Parameters: + api_key (str): The project API Key + secret_key (str): The project Secret Key + max_cohort_size (int): The maximum cohort size that can be downloaded + cohort_request_delay_millis (int): The delay in milliseconds between cohort download requests + cohort_server_url (str): The server endpoint from which to request cohorts + """ + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, - cohort_request_delay_millis: int = 5000): + cohort_request_delay_millis: int = 5000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL): self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size self.cohort_request_delay_millis = cohort_request_delay_millis + self.cohort_server_url = cohort_server_url diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 12114fb..b6f5a59 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -64,6 +64,7 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): self.config.cohort_sync_config.secret_key, self.config.cohort_sync_config.max_cohort_size, self.config.cohort_sync_config.cohort_request_delay_millis, + self.config.cohort_sync_config.cohort_server_url, self.config.debug) cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index d34421b..a408722 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -1,14 +1,23 @@ +from enum import Enum + from ..assignment import AssignmentConfig -from ..cohort.cohort_sync_config import CohortSyncConfig +from ..cohort.cohort_sync_config import CohortSyncConfig, DEFAULT_COHORT_SYNC_URL, EU_COHORT_SYNC_URL + +DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' +EU_SERVER_URL = 'https://api.eu.lab.amplitude.com' + + +class ServerZone(Enum): + US = "US" + EU = "EU" class LocalEvaluationConfig: """Experiment Local Client Configuration""" - DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' - def __init__(self, debug: bool = False, server_url: str = DEFAULT_SERVER_URL, + server_zone: ServerZone = ServerZone.US, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, assignment_config: AssignmentConfig = None, @@ -16,20 +25,29 @@ def __init__(self, debug: bool = False, """ Initialize a config Parameters: - debug (bool): Set to true to log some extra information to the console. - server_url (str): The server endpoint from which to request variants. - flag_config_polling_interval_millis (int): The interval in milliseconds to poll the amplitude server for - flag config updates. These rules are stored in memory and used when calling evaluate() - to perform local evaluation. - flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, - used when fetching variants. + debug (bool): Set to true to log some extra information to the console. + server_url (str): The server endpoint from which to request variants. + server_zone (ServerZone): Location of the Amplitude data center to get flags and cohorts from, US or EU + flag_config_polling_interval_millis (int): The interval, in milliseconds, at which to poll for flag + configurations. + flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, used when + fetching flag configurations. + assignment_config (AssignmentConfig): The assignment configuration. + cohort_sync_config (CohortSyncConfig): The cohort sync configuration. Returns: The config object """ self.debug = debug self.server_url = server_url + self.server_zone = server_zone + self.cohort_sync_config = cohort_sync_config + if server_url == DEFAULT_SERVER_URL and server_zone == ServerZone.EU: + self.server_url = EU_SERVER_URL + if (cohort_sync_config is not None and + cohort_sync_config.cohort_server_url == DEFAULT_COHORT_SYNC_URL): + self.cohort_sync_config.cohort_server_url = EU_COHORT_SYNC_URL + self.flag_config_polling_interval_millis = flag_config_polling_interval_millis self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis self.assignment_config = assignment_config - self.cohort_sync_config = cohort_sync_config From 7243d9b6bda8295942336d5b6b294572462278b4 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 26 Jun 2024 13:47:32 -0700 Subject: [PATCH 18/44] fix EU flag url --- src/amplitude_experiment/local/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index a408722..c729e36 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -4,7 +4,7 @@ from ..cohort.cohort_sync_config import CohortSyncConfig, DEFAULT_COHORT_SYNC_URL, EU_COHORT_SYNC_URL DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' -EU_SERVER_URL = 'https://api.eu.lab.amplitude.com' +EU_SERVER_URL = 'https://flag.lab.eu.amplitude.com' class ServerZone(Enum): @@ -26,7 +26,7 @@ def __init__(self, debug: bool = False, Initialize a config Parameters: debug (bool): Set to true to log some extra information to the console. - server_url (str): The server endpoint from which to request variants. + server_url (str): The server endpoint from which to request flag configs. server_zone (ServerZone): Location of the Amplitude data center to get flags and cohorts from, US or EU flag_config_polling_interval_millis (int): The interval, in milliseconds, at which to poll for flag configurations. From 4fe1facdffd15d0f30df3a56faeef8a007cbf416 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 26 Jun 2024 14:04:37 -0700 Subject: [PATCH 19/44] export CohortSyncConfig and ServerZone --- src/amplitude_experiment/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/amplitude_experiment/__init__.py b/src/amplitude_experiment/__init__.py index 482e490..aede23b 100644 --- a/src/amplitude_experiment/__init__.py +++ b/src/amplitude_experiment/__init__.py @@ -12,4 +12,6 @@ from .cookie import AmplitudeCookie from .local.client import LocalEvaluationClient from .local.config import LocalEvaluationConfig +from .local.config import ServerZone from .assignment import AssignmentConfig +from .cohort.cohort_sync_config import CohortSyncConfig From a51067bc1a7aea5471c9fbd5601515c80fc60bea Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 27 Jun 2024 09:33:29 -0700 Subject: [PATCH 20/44] nit: simplify logic --- src/amplitude_experiment/local/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index b6f5a59..79cd76e 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -99,7 +99,6 @@ def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Varia if flag_configs is None or len(flag_configs) == 0: return {} self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {flag_configs}") - flag_configs = self.flag_config_storage.get_flag_configs() sorted_flags = topological_sort(flag_configs, flag_keys) if not sorted_flags: return {} From f0e899b85763507d1a936897dba72777fdf3ff44 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Fri, 28 Jun 2024 14:56:58 -0700 Subject: [PATCH 21/44] Handle 204 errors --- src/amplitude_experiment/cohort/cohort_download_api.py | 2 +- src/amplitude_experiment/deployment/deployment_runner.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index ed7ad08..74027ab 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -58,7 +58,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: raise HTTPErrorResponseException(response.status, f"Unexpected response code: {response.status}") except Exception as e: - if not isinstance(e, HTTPErrorResponseException) and response.status != 429: + if response and not isinstance(e, HTTPErrorResponseException) and response.status != 429: errors += 1 self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") if errors >= 3 or isinstance(e, CohortNotModifiedException) or isinstance(e, CohortTooLargeException): diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index af27510..4e04c30 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -2,6 +2,7 @@ from typing import Optional, Set import threading +from ..exception import CohortNotModifiedException from ..local.config import LocalEvaluationConfig from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import CohortStorage @@ -55,7 +56,7 @@ def refresh(self): raise Exception flag_keys = {flag['key'] for flag in flag_configs} - self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys) + self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) @@ -89,8 +90,9 @@ def task(): future.result() self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") except Exception as e: - self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}") - raise e + if not isinstance(e, CohortNotModifiedException): + self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}") + raise e cohort_fetched = self.cohort_loader.executor.submit(task) # Wait for both flag and cohort loading to complete From 5130cddc797d3401627a741cdb70769ff1ebc273 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 2 Jul 2024 16:45:00 -0700 Subject: [PATCH 22/44] update deployment_runner flag/cohort update logic, update tests, fix logger initialization --- .../cohort/cohort_download_api.py | 9 +- .../cohort/cohort_loader.py | 18 ++- .../cohort/cohort_storage.py | 7 ++ .../cohort/cohort_sync_config.py | 2 +- .../deployment/deployment_runner.py | 103 +++++++++--------- src/amplitude_experiment/local/client.py | 6 +- tests/cohort/cohort_download_api_test.py | 2 +- tests/deployment/deployment_runner_test.py | 13 ++- 8 files changed, 95 insertions(+), 65 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 74027ab..349094d 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -18,17 +18,14 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, - server_url: str, debug: bool): + server_url: str, logger: logging.Logger = None): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size self.cohort_request_delay_millis = cohort_request_delay_millis - self.logger = logging.getLogger("Amplitude") - self.logger.addHandler(logging.StreamHandler()) self.server_url = server_url - if debug: - self.logger.setLevel(logging.DEBUG) + self.logger = logger or logging.getLogger("Amplitude") self.__setup_connection_pool() def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: @@ -58,7 +55,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: raise HTTPErrorResponseException(response.status, f"Unexpected response code: {response.status}") except Exception as e: - if response and not isinstance(e, HTTPErrorResponseException) and response.status != 429: + if response and not (isinstance(e, HTTPErrorResponseException) and response.status == 429): errors += 1 self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") if errors >= 3 or isinstance(e, CohortNotModifiedException) or isinstance(e, CohortTooLargeException): diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 8d74a03..17b0181 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -1,5 +1,6 @@ +import logging from typing import Dict, Set -from concurrent.futures import ThreadPoolExecutor, Future +from concurrent.futures import ThreadPoolExecutor, Future, as_completed import threading from .cohort import Cohort @@ -8,7 +9,8 @@ class CohortLoader: - def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage): + def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, + logger: logging.Logger = None): self.cohort_download_api = cohort_download_api self.cohort_storage = cohort_storage self.jobs: Dict[str, Future] = {} @@ -17,6 +19,7 @@ def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: Cohor max_workers=32, thread_name_prefix='CohortLoaderExecutor' ) + self.logger = logger or logging.getLogger("Amplitude") def load_cohort(self, cohort_id: str) -> Future: with self.lock_jobs: @@ -40,3 +43,14 @@ def _remove_job(self, cohort_id: str): def download_cohort(self, cohort_id: str) -> Cohort: cohort = self.cohort_storage.get_cohort(cohort_id) return self.cohort_download_api.get_cohort(cohort_id, cohort) + + def update_stored_cohorts(self) -> Future: + def task(): + futures = [self.load_cohort(cohort_id) for cohort_id in self.cohort_storage.get_cohort_ids()] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + self.logger.error(f"Error updating cohort: {e}") + + return self.executor.submit(task) diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py index 5ba6018..9c6d4c5 100644 --- a/src/amplitude_experiment/cohort/cohort_storage.py +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -23,6 +23,9 @@ def put_cohort(self, cohort_description: Cohort): def delete_cohort(self, group_type: str, cohort_id: str): raise NotImplementedError + def get_cohort_ids(self) -> Set[str]: + raise NotImplementedError + class InMemoryCohortStorage(CohortStorage): def __init__(self): @@ -64,3 +67,7 @@ def delete_cohort(self, group_type: str, cohort_id: str): group_cohorts.remove(cohort_id) if cohort_id in self.cohort_store: del self.cohort_store[cohort_id] + + def get_cohort_ids(self): + with self.lock: + return set(self.cohort_store.keys()) diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py index 5a7a580..0ee25b2 100644 --- a/src/amplitude_experiment/cohort/cohort_sync_config.py +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -14,7 +14,7 @@ class CohortSyncConfig: cohort_server_url (str): The server endpoint from which to request cohorts """ - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000, + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647, cohort_request_delay_millis: int = 5000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL): self.api_key = api_key self.secret_key = secret_key diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 4e04c30..39ef62d 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -2,7 +2,6 @@ from typing import Optional, Set import threading -from ..exception import CohortNotModifiedException from ..local.config import LocalEvaluationConfig from ..cohort.cohort_loader import CohortLoader from ..cohort.cohort_storage import CohortStorage @@ -20,6 +19,7 @@ def __init__( flag_config_storage: FlagConfigStorage, cohort_storage: CohortStorage, cohort_loader: Optional[CohortLoader] = None, + logger: logging.Logger = None ): self.config = config self.flag_config_api = flag_config_api @@ -27,76 +27,82 @@ def __init__( self.cohort_storage = cohort_storage self.cohort_loader = cohort_loader self.lock = threading.Lock() - self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_refresh) - self.logger = logging.getLogger("Amplitude") - self.logger.addHandler(logging.StreamHandler()) - if self.config.debug: - self.logger.setLevel(logging.DEBUG) + self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) + if self.cohort_loader: + self.cohort_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, + self.__update_cohorts) + self.logger = logger def start(self): with self.lock: - self.refresh() - self.poller.start() + self.__update_flag_configs() + self.flag_poller.start() + if self.cohort_loader: + self.cohort_poller.start() def stop(self): - self.poller.stop() + self.flag_poller.stop() - def __periodic_refresh(self): + def __periodic_flag_update(self): try: - self.refresh() + self.__update_flag_configs() except Exception as e: - self.logger.error(f"Refresh flag and cohort configs failed: {e}") + self.logger.error(f"Error while updating flags: {e}") - def refresh(self): - self.logger.debug("Refreshing flag configs.") + def __update_flag_configs(self): try: flag_configs = self.flag_config_api.get_flag_configs() except Exception as e: self.logger.error(f'Failed to fetch flag configs: {e}') - raise Exception - + raise e flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) + if not self.cohort_loader: + for flag_config in flag_configs: + self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") + self.flag_config_storage.put_flag_config(flag_config) + return + new_cohort_ids = set() + for flag_config in flag_configs: + new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) + existing_cohort_ids = self.cohort_storage.get_cohort_ids() + cohort_ids_to_download = new_cohort_ids - existing_cohort_ids + cohort_download_error = None + # download all new cohorts + for cohort_id in cohort_ids_to_download: + try: + self.cohort_loader.load_cohort(cohort_id).result() + except Exception as e: + cohort_download_error = e + self.logger.error(f"Download cohort {cohort_id} failed: {e}") + # get updated set of cohort ids + updated_cohort_ids = self.cohort_storage.get_cohort_ids() + # iterate through new flag configs and check if their required cohorts exist + failed_flag_count = 0 for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) - if not self.cohort_loader or not cohort_ids: + if not cohort_ids or not self.cohort_loader: + self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") + # only store flag config if all required cohorts exist + elif cohort_ids.issubset(updated_cohort_ids): self.flag_config_storage.put_flag_config(flag_config) - continue - - # Keep track of old flag and cohort for each flag - old_flag_config = self.flag_config_storage.get_flag_config(flag_config['key']) - - try: - self._load_cohorts(flag_config, cohort_ids) - self.flag_config_storage.put_flag_config(flag_config) # Store new flag config - self.logger.debug(f"Stored flag config {flag_config['key']}") - - except Exception as e: - self.logger.warning(f"Failed to load all cohorts for flag {flag_config['key']}. " - f"Using the old flag config.") - self.flag_config_storage.put_flag_config(old_flag_config) - raise e + self.logger.debug(f"Putting flag {flag_config['key']}") + else: + self.logger.error(f"Flag {flag_config['key']} not updated because " + f"not all required cohorts could be loaded") + failed_flag_count += 1 + # delete unused cohorts self._delete_unused_cohorts() - self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + self.logger.debug(f"Refreshed {len(flag_configs) - failed_flag_count} flag configs.") + # if not all required cohorts exist, throw an error + if cohort_download_error: + raise cohort_download_error - def _load_cohorts(self, flag_config: dict, cohort_ids: Set[str]): - def task(): - try: - for cohort_id in cohort_ids: - future = self.cohort_loader.load_cohort(cohort_id) - future.result() - self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}") - except Exception as e: - if not isinstance(e, CohortNotModifiedException): - self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}") - raise e - - cohort_fetched = self.cohort_loader.executor.submit(task) - # Wait for both flag and cohort loading to complete - cohort_fetched.result() + def __update_cohorts(self): + self.cohort_loader.update_stored_cohorts().result() def _delete_unused_cohorts(self): flag_cohort_ids = set() @@ -110,4 +116,3 @@ def _delete_unused_cohorts(self): deleted_cohort = storage_cohorts.get(deleted_cohort_id) if deleted_cohort is not None: self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) - diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 79cd76e..e2b45f7 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -65,13 +65,13 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): self.config.cohort_sync_config.max_cohort_size, self.config.cohort_sync_config.cohort_request_delay_millis, self.config.cohort_sync_config.cohort_server_url, - self.config.debug) + self.logger) - cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) + cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage, self.logger) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, - self.cohort_storage, cohort_loader) + self.cohort_storage, cohort_loader, self.logger) def start(self): """ diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index c2482a2..03c6600 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -17,7 +17,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, False) + self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com") def test_cohort_download_success(self): cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index 02542b9..c4c0c8f 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -1,6 +1,7 @@ import unittest from unittest import mock from unittest.mock import patch +import logging from src.amplitude_experiment import LocalEvaluationConfig from src.amplitude_experiment.cohort.cohort_loader import CohortLoader @@ -36,13 +37,16 @@ def test_start_throws_if_first_flag_config_load_fails(self): cohort_download_api = mock.Mock() flag_config_storage = mock.Mock() cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() cohort_loader = CohortLoader(cohort_download_api, cohort_storage) + logger = mock.create_autospec(logging.Logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, - cohort_loader + cohort_loader, + logger # Pass the logger mock here ) flag_api.get_flag_configs.side_effect = RuntimeError("test") with self.assertRaises(RuntimeError): @@ -53,16 +57,19 @@ def test_start_throws_if_first_cohort_load_fails(self): cohort_download_api = mock.Mock() flag_config_storage = mock.Mock() cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() cohort_loader = CohortLoader(cohort_download_api, cohort_storage) + logger = mock.create_autospec(logging.Logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, - cohort_loader + cohort_loader, + logger # Pass the logger mock here ) with patch.object(runner, '_delete_unused_cohorts'): flag_api.get_flag_configs.return_value = [self.flag] - cohort_download_api.get_cohort_description.side_effect = RuntimeError("test") + cohort_download_api.get_cohort.side_effect = RuntimeError("test") with self.assertRaises(RuntimeError): runner.start() From 5ed0a98b58add65ad5e2871421820bcb830e0966 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 3 Jul 2024 10:25:54 -0700 Subject: [PATCH 23/44] Update logger requirement for classes --- .../cohort/cohort_download_api.py | 4 ++-- .../cohort/cohort_loader.py | 4 ++-- .../deployment/deployment_runner.py | 21 ++++++++++++------- src/amplitude_experiment/local/client.py | 2 +- tests/cohort/cohort_download_api_test.py | 4 +++- tests/cohort/cohort_loader_test.py | 4 +++- tests/deployment/deployment_runner_test.py | 10 ++++----- 7 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 349094d..5616917 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -18,14 +18,14 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, - server_url: str, logger: logging.Logger = None): + server_url: str, logger: logging.Logger): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size self.cohort_request_delay_millis = cohort_request_delay_millis self.server_url = server_url - self.logger = logger or logging.getLogger("Amplitude") + self.logger = logger self.__setup_connection_pool() def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 17b0181..d525085 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -10,7 +10,7 @@ class CohortLoader: def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, - logger: logging.Logger = None): + logger: logging.Logger): self.cohort_download_api = cohort_download_api self.cohort_storage = cohort_storage self.jobs: Dict[str, Future] = {} @@ -19,7 +19,7 @@ def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: Cohor max_workers=32, thread_name_prefix='CohortLoaderExecutor' ) - self.logger = logger or logging.getLogger("Amplitude") + self.logger = logger def load_cohort(self, cohort_id: str) -> Future: with self.lock_jobs: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 39ef62d..b273c0d 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -18,8 +18,8 @@ def __init__( flag_config_api: FlagConfigApi, flag_config_storage: FlagConfigStorage, cohort_storage: CohortStorage, + logger: logging.Logger, cohort_loader: Optional[CohortLoader] = None, - logger: logging.Logger = None ): self.config = config self.flag_config_api = flag_config_api @@ -55,25 +55,30 @@ def __update_flag_configs(self): except Exception as e: self.logger.error(f'Failed to fetch flag configs: {e}') raise e + flag_keys = {flag['key'] for flag in flag_configs} self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) + if not self.cohort_loader: for flag_config in flag_configs: self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) return + new_cohort_ids = set() for flag_config in flag_configs: new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) + existing_cohort_ids = self.cohort_storage.get_cohort_ids() cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - cohort_download_error = None + cohort_download_errors = [] + # download all new cohorts for cohort_id in cohort_ids_to_download: try: self.cohort_loader.load_cohort(cohort_id).result() except Exception as e: - cohort_download_error = e + cohort_download_errors.append((cohort_id, str(e))) self.logger.error(f"Download cohort {cohort_id} failed: {e}") # get updated set of cohort ids @@ -85,7 +90,6 @@ def __update_flag_configs(self): if not cohort_ids or not self.cohort_loader: self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") - # only store flag config if all required cohorts exist elif cohort_ids.issubset(updated_cohort_ids): self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting flag {flag_config['key']}") @@ -97,9 +101,12 @@ def __update_flag_configs(self): # delete unused cohorts self._delete_unused_cohorts() self.logger.debug(f"Refreshed {len(flag_configs) - failed_flag_count} flag configs.") - # if not all required cohorts exist, throw an error - if cohort_download_error: - raise cohort_download_error + + # if there are any download errors, raise an aggregated exception + if cohort_download_errors: + error_count = len(cohort_download_errors) + error_messages = "\n".join([f"Cohort {cohort_id}: {error}" for cohort_id, error in cohort_download_errors]) + raise Exception(f"{error_count} cohort(s) failed to download:\n{error_messages}") def __update_cohorts(self): self.cohort_loader.update_stored_cohorts().result() diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index e2b45f7..657c415 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -71,7 +71,7 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, - self.cohort_storage, cohort_loader, self.logger) + self.cohort_storage, self.logger, cohort_loader) def start(self): """ diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 03c6600..3af9c8e 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -1,5 +1,7 @@ import json +import logging import unittest +from unittest import mock from unittest.mock import MagicMock, patch from src.amplitude_experiment.cohort.cohort import Cohort from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi @@ -17,7 +19,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com") + self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com", mock.create_autospec(logging.Logger)) def test_cohort_download_success(self): cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index 309a8ff..a739b9d 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -1,4 +1,6 @@ +import logging import unittest +from unittest import mock from unittest.mock import MagicMock from src.amplitude_experiment.cohort.cohort import Cohort @@ -9,7 +11,7 @@ class CohortLoaderTest(unittest.TestCase): def setUp(self): self.api = MagicMock() self.storage = InMemoryCohortStorage() - self.loader = CohortLoader(self.api, self.storage) + self.loader = CohortLoader(self.api, self.storage, mock.create_autospec(logging.Logger)) def test_load_success(self): self.api.get_cohort.side_effect = [ diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index c4c0c8f..c5052db 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -38,15 +38,15 @@ def test_start_throws_if_first_flag_config_load_fails(self): flag_config_storage = mock.Mock() cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() - cohort_loader = CohortLoader(cohort_download_api, cohort_storage) logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, + logger, cohort_loader, - logger # Pass the logger mock here ) flag_api.get_flag_configs.side_effect = RuntimeError("test") with self.assertRaises(RuntimeError): @@ -58,19 +58,19 @@ def test_start_throws_if_first_cohort_load_fails(self): flag_config_storage = mock.Mock() cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() - cohort_loader = CohortLoader(cohort_download_api, cohort_storage) logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, cohort_storage, + logger, cohort_loader, - logger # Pass the logger mock here ) with patch.object(runner, '_delete_unused_cohorts'): flag_api.get_flag_configs.return_value = [self.flag] cohort_download_api.get_cohort.side_effect = RuntimeError("test") - with self.assertRaises(RuntimeError): + with self.assertRaises(Exception): runner.start() From 916e5c11f79beefea09df6e2816649db8117a13b Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 3 Jul 2024 11:44:55 -0700 Subject: [PATCH 24/44] Refactor cohort_loader update_storage_cohorts --- .../cohort/cohort_loader.py | 47 +++++++++++-------- .../deployment/deployment_runner.py | 6 ++- src/amplitude_experiment/exception.py | 21 ++++++++- src/amplitude_experiment/local/client.py | 2 +- tests/cohort/cohort_loader_test.py | 4 +- 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index d525085..2cf31e9 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -6,11 +6,11 @@ from .cohort import Cohort from .cohort_download_api import CohortDownloadApi from .cohort_storage import CohortStorage +from ..exception import CohortUpdateException class CohortLoader: - def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage, - logger: logging.Logger): + def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage): self.cohort_download_api = cohort_download_api self.cohort_storage = cohort_storage self.jobs: Dict[str, Future] = {} @@ -19,19 +19,11 @@ def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: Cohor max_workers=32, thread_name_prefix='CohortLoaderExecutor' ) - self.logger = logger def load_cohort(self, cohort_id: str) -> Future: with self.lock_jobs: if cohort_id not in self.jobs: - def task(): - try: - cohort = self.download_cohort(cohort_id) - self.cohort_storage.put_cohort(cohort) - except Exception as e: - raise e - - future = self.executor.submit(task) + future = self.executor.submit(self.__load_cohort_internal, cohort_id) future.add_done_callback(lambda f: self._remove_job(cohort_id)) self.jobs[cohort_id] = future return self.jobs[cohort_id] @@ -45,12 +37,27 @@ def download_cohort(self, cohort_id: str) -> Cohort: return self.cohort_download_api.get_cohort(cohort_id, cohort) def update_stored_cohorts(self) -> Future: - def task(): - futures = [self.load_cohort(cohort_id) for cohort_id in self.cohort_storage.get_cohort_ids()] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - self.logger.error(f"Error updating cohort: {e}") - - return self.executor.submit(task) + def update_task(): + errors = [] + try: + futures = [self.executor.submit(self.__load_cohort_internal, cohort_id) for cohort_id in self.cohort_storage.get_cohort_ids()] + + for future, cohort_id in zip(as_completed(futures), self.cohort_storage.get_cohort_ids()): + try: + future.result() + except Exception as e: + errors.append((cohort_id, e)) + except Exception as e: + errors.append(e) + + if errors: + raise CohortUpdateException(errors) + + return self.executor.submit(update_task) + + def __load_cohort_internal(self, cohort_id): + try: + cohort = self.download_cohort(cohort_id) + self.cohort_storage.put_cohort(cohort) + except Exception as e: + raise e diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index b273c0d..6141116 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -109,7 +109,11 @@ def __update_flag_configs(self): raise Exception(f"{error_count} cohort(s) failed to download:\n{error_messages}") def __update_cohorts(self): - self.cohort_loader.update_stored_cohorts().result() + try: + self.cohort_loader.update_stored_cohorts().result() + except Exception as e: + self.logger.error(f"Error while updating cohorts: {e}") + def _delete_unused_cohorts(self): flag_cohort_ids = set() diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 15d9ac7..51d7534 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -1,3 +1,6 @@ +from typing import List, Tuple + + class FetchException(Exception): def __init__(self, status_code, message): super().__init__(message) @@ -8,7 +11,6 @@ class CohortNotModifiedException(Exception): def __init__(self, message): super().__init__(message) - class CohortTooLargeException(Exception): def __init__(self, message): super().__init__(message) @@ -18,3 +20,20 @@ class HTTPErrorResponseException(Exception): def __init__(self, status_code, message): super().__init__(message) self.status_code = status_code + +class CohortUpdateException(Exception): + def __init__(self, errors): + self.errors = errors + super().__init__(self.__str__()) + + def __str__(self): + error_messages = [] + for item in self.errors: + if isinstance(item, tuple) and len(item) == 2: + cohort_id, error = item + error_messages.append(f"Cohort {cohort_id}: {error}") + else: + error_messages.append(str(item)) + return f"One or more cohorts failed to update:\n" + "\n".join(error_messages) + + diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 657c415..0976ac8 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -67,7 +67,7 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): self.config.cohort_sync_config.cohort_server_url, self.logger) - cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage, self.logger) + cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, self.config.flag_config_poller_request_timeout_millis) self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py index a739b9d..309a8ff 100644 --- a/tests/cohort/cohort_loader_test.py +++ b/tests/cohort/cohort_loader_test.py @@ -1,6 +1,4 @@ -import logging import unittest -from unittest import mock from unittest.mock import MagicMock from src.amplitude_experiment.cohort.cohort import Cohort @@ -11,7 +9,7 @@ class CohortLoaderTest(unittest.TestCase): def setUp(self): self.api = MagicMock() self.storage = InMemoryCohortStorage() - self.loader = CohortLoader(self.api, self.storage, mock.create_autospec(logging.Logger)) + self.loader = CohortLoader(self.api, self.storage) def test_load_success(self): self.api.get_cohort.side_effect = [ From 013ffc94a394acc7b9e0d77c1d8d93bc03db47ac Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 3 Jul 2024 11:46:54 -0700 Subject: [PATCH 25/44] fix lint --- src/amplitude_experiment/exception.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 51d7534..7c8fb12 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -1,6 +1,3 @@ -from typing import List, Tuple - - class FetchException(Exception): def __init__(self, status_code, message): super().__init__(message) @@ -11,6 +8,7 @@ class CohortNotModifiedException(Exception): def __init__(self, message): super().__init__(message) + class CohortTooLargeException(Exception): def __init__(self, message): super().__init__(message) @@ -21,6 +19,7 @@ def __init__(self, status_code, message): super().__init__(message) self.status_code = status_code + class CohortUpdateException(Exception): def __init__(self, errors): self.errors = errors @@ -35,5 +34,3 @@ def __str__(self): else: error_messages.append(str(item)) return f"One or more cohorts failed to update:\n" + "\n".join(error_messages) - - From 57e1cc229fe1b66a64c19e17015e1cd386c0398d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 22 Jul 2024 16:40:10 -0700 Subject: [PATCH 26/44] remove unnecessary import --- src/amplitude_experiment/deployment/deployment_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 6141116..831754e 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Set +from typing import Optional import threading from ..local.config import LocalEvaluationConfig @@ -114,7 +114,6 @@ def __update_cohorts(self): except Exception as e: self.logger.error(f"Error while updating cohorts: {e}") - def _delete_unused_cohorts(self): flag_cohort_ids = set() for flag in self.flag_config_storage.get_flag_configs().values(): From 93d2f15cf082380ed31ff6e7fb280babeef01017 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 23 Jul 2024 11:58:26 -0700 Subject: [PATCH 27/44] update test.yml --- .github/workflows/test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3072f54..ccd6e97 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,8 +19,7 @@ jobs: cache: 'pip' - name: Install requirements - run: pip install -r requirements.txt - pip install -r requirements-dev.txt + run: pip install -r requirements.txt && pip install -r requirements-dev.txt - name: Unit Test run: python -m unittest discover -s ./tests -p '*_test.py' From 5b30cb77b79228d37c14b73b94aa30d1377b73af Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 24 Jul 2024 14:55:30 -0700 Subject: [PATCH 28/44] add client cohort ci tests --- .github/workflows/test-arm.yml | 5 +++++ .github/workflows/test.yml | 5 +++++ README.md | 9 ++++++++ requirements-dev.txt | 1 + tests/local/client_eu_test.py | 39 ++++++++++++++++++++++++++++++++++ tests/local/client_test.py | 28 ++++++++++++++++++++++-- 6 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 tests/local/client_eu_test.py diff --git a/.github/workflows/test-arm.yml b/.github/workflows/test-arm.yml index 30eac88..6166b17 100644 --- a/.github/workflows/test-arm.yml +++ b/.github/workflows/test-arm.yml @@ -23,3 +23,8 @@ jobs: pip install -r requirements.txt pip install -r requirements-dev.txt python3 -m unittest discover -s ./tests -p '*_test.py' + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ccd6e97..ba29fa8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,3 +23,8 @@ jobs: - name: Unit Test run: python -m unittest discover -s ./tests -p '*_test.py' + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} diff --git a/README.md b/README.md index cc9a36e..5d8bb3d 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,15 @@ user = User( variants = experiment.evaluate(user) ``` +# Running unit tests suite +To setup for running test on local, create a `.env` file with following +contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test: + +``` +API_KEY={API_KEY} +SECRET_KEY={SECRET_KEY} +``` + ## More Information Please visit our :100:[Developer Center](https://www.docs.developers.amplitude.com/experiment/sdks/python-sdk/) for more instructions on using our the SDK. diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f5c19d..2a4cf18 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1 +1,2 @@ parameterized~=0.9.0 +python-dotenv~=1.0.1 diff --git a/tests/local/client_eu_test.py b/tests/local/client_eu_test.py new file mode 100644 index 0000000..4037902 --- /dev/null +++ b/tests/local/client_eu_test.py @@ -0,0 +1,39 @@ +import unittest +from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig +from src.amplitude_experiment.local.config import ServerZone +from dotenv import load_dotenv +import os + +SERVER_API_KEY = 'server-Qlp7XiSu6JtP2S3JzA95PnP27duZgQCF' + + +class LocalEvaluationClientTestCase(unittest.TestCase): + _local_evaluation_client: LocalEvaluationClient = None + + @classmethod + def setUpClass(cls) -> None: + load_dotenv() + api_key = os.getenv('EU_API_KEY') + secret_key = os.getenv('EU_SECRET_KEY') + cohort_sync_config = CohortSyncConfig(api_key=api_key, + secret_key=secret_key, + cohort_request_delay_millis=100) + cls._local_evaluation_client = ( + LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, server_zone=ServerZone.EU, + cohort_sync_config=cohort_sync_config))) + cls._local_evaluation_client.start() + + @classmethod + def tearDownClass(cls) -> None: + cls._local_evaluation_client.stop() + + def test_evaluate_with_cohort_eu(self): + user = User(user_id='1', device_id='0') + variant = self._local_evaluation_client.evaluate_v2(user).get('sdk-local-evaluation-user-cohort') + expected_variant = Variant(key='on', value='on') + self.assertEqual(expected_variant, variant) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 00293db..b09e612 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -1,7 +1,11 @@ import unittest from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig +from src.amplitude_experiment.local.config import ServerZone +from dotenv import load_dotenv +import os -API_KEY = 'server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz' +SERVER_API_KEY = 'server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz' test_user = User(user_id='test_user') test_user_2 = User(user_id='user_id', device_id='device_id') @@ -11,7 +15,15 @@ class LocalEvaluationClientTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls._local_evaluation_client = LocalEvaluationClient(API_KEY, LocalEvaluationConfig(debug=False)) + load_dotenv() + api_key = os.getenv('API_KEY') + secret_key = os.getenv('SECRET_KEY') + cohort_sync_config = CohortSyncConfig(api_key=api_key, + secret_key=secret_key, + cohort_request_delay_millis=100) + cls._local_evaluation_client = ( + LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, + cohort_sync_config=cohort_sync_config))) cls._local_evaluation_client.start() @classmethod @@ -56,6 +68,18 @@ def test_evaluate_with_dependencies_variant_holdout(self): expected_variant = None self.assertEqual(expected_variant, variants.get('sdk-local-evaluation-ci-test-holdout')) + def test_evaluate_with_cohort(self): + user = User(user_id='12345', device_id='device_id') + variant = self._local_evaluation_client.evaluate(user).get('sdk-local-evaluation-user-cohort-ci-test') + expected_variant = Variant(key='on', value='on') + self.assertEqual(expected_variant, variant) + + def test_evaluate_with_group_cohort(self): + user = User(user_id='12345', device_id='device_id', groups={'org id': ['1']}) + variant = self._local_evaluation_client.evaluate(user).get('sdk-local-evaluation-group-cohort-ci-test') + expected_variant = Variant(key='on', value='on') + self.assertEqual(expected_variant, variant) + if __name__ == '__main__': unittest.main() From 12267a0a3e578c99003d1a21358d9c1af253620d Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 24 Jul 2024 14:58:20 -0700 Subject: [PATCH 29/44] update requirements-dev dotenv version --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 2a4cf18..8e4ee4b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,2 @@ parameterized~=0.9.0 -python-dotenv~=1.0.1 +python-dotenv~=0.21.1 From 03b9081155fd5ae75d635ee3ea1139e4ad1ea19e Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 24 Jul 2024 15:06:25 -0700 Subject: [PATCH 30/44] debug env vars --- tests/local/client_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/local/client_test.py b/tests/local/client_test.py index b09e612..d9dc491 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -1,7 +1,6 @@ import unittest from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig -from src.amplitude_experiment.local.config import ServerZone from dotenv import load_dotenv import os @@ -18,6 +17,8 @@ def setUpClass(cls) -> None: load_dotenv() api_key = os.getenv('API_KEY') secret_key = os.getenv('SECRET_KEY') + if not api_key or not secret_key: + raise ValueError("API_KEY or SECRET_KEY is not set in the environment variables") cohort_sync_config = CohortSyncConfig(api_key=api_key, secret_key=secret_key, cohort_request_delay_millis=100) From 832a00c8ae56e68c032d581d621587be69a54890 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 24 Jul 2024 15:11:21 -0700 Subject: [PATCH 31/44] test yml set env vars --- .github/workflows/test-arm.yml | 11 ++++++----- .github/workflows/test.yml | 11 ++++++----- tests/local/client_test.py | 2 -- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test-arm.yml b/.github/workflows/test-arm.yml index 6166b17..2a40d8e 100644 --- a/.github/workflows/test-arm.yml +++ b/.github/workflows/test-arm.yml @@ -1,6 +1,12 @@ name: Unit Test on Arm on: [pull_request] +env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} + jobs: aarch_job: runs-on: ubuntu-latest @@ -23,8 +29,3 @@ jobs: pip install -r requirements.txt pip install -r requirements-dev.txt python3 -m unittest discover -s ./tests -p '*_test.py' - env: - API_KEY: ${{ secrets.API_KEY }} - SECRET_KEY: ${{ secrets.SECRET_KEY }} - EU_API_KEY: ${{ secrets.EU_API_KEY }} - EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ba29fa8..c1fb907 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,6 +2,12 @@ name: Unit Test on: [pull_request] +env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} + jobs: test: runs-on: ubuntu-latest @@ -23,8 +29,3 @@ jobs: - name: Unit Test run: python -m unittest discover -s ./tests -p '*_test.py' - env: - API_KEY: ${{ secrets.API_KEY }} - SECRET_KEY: ${{ secrets.SECRET_KEY }} - EU_API_KEY: ${{ secrets.EU_API_KEY }} - EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} diff --git a/tests/local/client_test.py b/tests/local/client_test.py index d9dc491..9d41f5c 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -17,8 +17,6 @@ def setUpClass(cls) -> None: load_dotenv() api_key = os.getenv('API_KEY') secret_key = os.getenv('SECRET_KEY') - if not api_key or not secret_key: - raise ValueError("API_KEY or SECRET_KEY is not set in the environment variables") cohort_sync_config = CohortSyncConfig(api_key=api_key, secret_key=secret_key, cohort_request_delay_millis=100) From 135a286b78267f7d016b98683d38779d4d8b4b75 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 24 Jul 2024 15:26:43 -0700 Subject: [PATCH 32/44] test cases use os.environ for secrets --- .github/workflows/test-arm.yml | 11 +++++------ .github/workflows/test.yml | 11 +++++------ tests/local/client_eu_test.py | 4 ++-- tests/local/client_test.py | 4 ++-- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test-arm.yml b/.github/workflows/test-arm.yml index 2a40d8e..dd42938 100644 --- a/.github/workflows/test-arm.yml +++ b/.github/workflows/test-arm.yml @@ -1,16 +1,15 @@ name: Unit Test on Arm on: [pull_request] -env: - API_KEY: ${{ secrets.API_KEY }} - SECRET_KEY: ${{ secrets.SECRET_KEY }} - EU_API_KEY: ${{ secrets.EU_API_KEY }} - EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} - jobs: aarch_job: runs-on: ubuntu-latest name: Test on ubuntu aarch64 + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} steps: - uses: actions/checkout@v3 - uses: uraimo/run-on-arch-action@v2 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c1fb907..647da7a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,12 +2,6 @@ name: Unit Test on: [pull_request] -env: - API_KEY: ${{ secrets.API_KEY }} - SECRET_KEY: ${{ secrets.SECRET_KEY }} - EU_API_KEY: ${{ secrets.EU_API_KEY }} - EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} - jobs: test: runs-on: ubuntu-latest @@ -28,4 +22,9 @@ jobs: run: pip install -r requirements.txt && pip install -r requirements-dev.txt - name: Unit Test + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} run: python -m unittest discover -s ./tests -p '*_test.py' diff --git a/tests/local/client_eu_test.py b/tests/local/client_eu_test.py index 4037902..84faf94 100644 --- a/tests/local/client_eu_test.py +++ b/tests/local/client_eu_test.py @@ -14,8 +14,8 @@ class LocalEvaluationClientTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: load_dotenv() - api_key = os.getenv('EU_API_KEY') - secret_key = os.getenv('EU_SECRET_KEY') + api_key = os.environ['EU_API_KEY'] + secret_key = os.environ['EU_SECRET_KEY'] cohort_sync_config = CohortSyncConfig(api_key=api_key, secret_key=secret_key, cohort_request_delay_millis=100) diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 9d41f5c..d76e94b 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -15,8 +15,8 @@ class LocalEvaluationClientTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: load_dotenv() - api_key = os.getenv('API_KEY') - secret_key = os.getenv('SECRET_KEY') + api_key = os.environ['API_KEY'] + secret_key = os.environ['SECRET_KEY'] cohort_sync_config = CohortSyncConfig(api_key=api_key, secret_key=secret_key, cohort_request_delay_millis=100) From e1ff4a2e5f250e3fd9fa5a35a57348b0ccee59ab Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 13:33:49 -0700 Subject: [PATCH 33/44] test-arm.yml env syntax fix --- .github/workflows/test-arm.yml | 21 ++++++++++++--------- .github/workflows/test.yml | 1 + 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-arm.yml b/.github/workflows/test-arm.yml index dd42938..7ac803a 100644 --- a/.github/workflows/test-arm.yml +++ b/.github/workflows/test-arm.yml @@ -4,18 +4,21 @@ on: [pull_request] jobs: aarch_job: runs-on: ubuntu-latest - name: Test on ubuntu aarch64 - env: - API_KEY: ${{ secrets.API_KEY }} - SECRET_KEY: ${{ secrets.SECRET_KEY }} - EU_API_KEY: ${{ secrets.EU_API_KEY }} - EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} + environment: Unit Test + name: Test on Ubuntu aarch64 steps: - - uses: actions/checkout@v3 - - uses: uraimo/run-on-arch-action@v2 - name: Run Unit Test + - name: Checkout source code + uses: actions/checkout@v3 + + - name: Set up and run unit test on aarch64 + uses: uraimo/run-on-arch-action@v2 id: runcmd with: + env: | + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} arch: aarch64 distro: ubuntu20.04 githubToken: ${{ github.token }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 647da7a..d78efe6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,7 @@ on: [pull_request] jobs: test: runs-on: ubuntu-latest + environment: Unit Test strategy: matrix: python-version: [ "3.7" ] From 9d6d62f5fb243a2623498d657a6662220076c316 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 25 Jul 2024 13:58:29 -0700 Subject: [PATCH 34/44] update client tests --- tests/local/client_eu_test.py | 14 ++++++++++---- tests/local/client_test.py | 28 ++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/tests/local/client_eu_test.py b/tests/local/client_eu_test.py index 84faf94..ed8b3af 100644 --- a/tests/local/client_eu_test.py +++ b/tests/local/client_eu_test.py @@ -29,10 +29,16 @@ def tearDownClass(cls) -> None: cls._local_evaluation_client.stop() def test_evaluate_with_cohort_eu(self): - user = User(user_id='1', device_id='0') - variant = self._local_evaluation_client.evaluate_v2(user).get('sdk-local-evaluation-user-cohort') - expected_variant = Variant(key='on', value='on') - self.assertEqual(expected_variant, variant) + targeted_user = User(user_id='1', device_id='0') + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + .get('sdk-local-evaluation-user-cohort')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='not_targeted') + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-user-cohort')) + expected_off_variant = Variant(key='off') + self.assertEqual(non_targeted_variant, expected_off_variant) if __name__ == '__main__': diff --git a/tests/local/client_test.py b/tests/local/client_test.py index d76e94b..3a1468d 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -68,16 +68,28 @@ def test_evaluate_with_dependencies_variant_holdout(self): self.assertEqual(expected_variant, variants.get('sdk-local-evaluation-ci-test-holdout')) def test_evaluate_with_cohort(self): - user = User(user_id='12345', device_id='device_id') - variant = self._local_evaluation_client.evaluate(user).get('sdk-local-evaluation-user-cohort-ci-test') - expected_variant = Variant(key='on', value='on') - self.assertEqual(expected_variant, variant) + targeted_user = User(user_id='12345', device_id='device_id') + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + .get('sdk-local-evaluation-user-cohort-ci-test')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='not_targeted') + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-user-cohort-ci-test')) + expected_off_variant = Variant(key='off') + self.assertEqual(expected_off_variant, non_targeted_variant) def test_evaluate_with_group_cohort(self): - user = User(user_id='12345', device_id='device_id', groups={'org id': ['1']}) - variant = self._local_evaluation_client.evaluate(user).get('sdk-local-evaluation-group-cohort-ci-test') - expected_variant = Variant(key='on', value='on') - self.assertEqual(expected_variant, variant) + targeted_user = User(user_id='12345', device_id='device_id', groups={'org id': ['1']}) + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + .get('sdk-local-evaluation-group-cohort-ci-test')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='12345', device_id='device_id', groups={'org id': ['not_targeted']}) + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-group-cohort-ci-test')) + expected_off_variant = Variant(key='off') + self.assertEqual(expected_off_variant, non_targeted_variant) if __name__ == '__main__': From 7fddc9600fb83befc44d7591371ebb2a1c5c6bc8 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 30 Jul 2024 16:48:35 -0700 Subject: [PATCH 35/44] cohort not modified should not throw exception --- .../cohort/cohort_download_api.py | 11 ++++++----- src/amplitude_experiment/cohort/cohort_loader.py | 4 +++- .../deployment/deployment_runner.py | 10 +++++----- src/amplitude_experiment/exception.py | 5 ----- tests/cohort/cohort_download_api_test.py | 6 +++--- tests/deployment/deployment_runner_test.py | 4 ++-- 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 5616917..6f9aae0 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -7,7 +7,7 @@ from .cohort import Cohort from ..connection_pool import HTTPConnectionPool -from ..exception import HTTPErrorResponseException, CohortTooLargeException, CohortNotModifiedException +from ..exception import HTTPErrorResponseException, CohortTooLargeException class CohortDownloadApi: @@ -28,7 +28,7 @@ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_r self.logger = logger self.__setup_connection_pool() - def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: + def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None: self.logger.debug(f"getCohortMembers({cohort_id}): start") errors = 0 while True: @@ -48,9 +48,10 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: group_type=cohort_info['groupType'], ) elif response.status == 204: - raise CohortNotModifiedException(f"Cohort not modified: {response.status}") + self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified" ) + return elif response.status == 413: - raise CohortTooLargeException(f"Cohort exceeds max cohort size: {response.status}") + raise CohortTooLargeException(f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}") elif response.status != 202: raise HTTPErrorResponseException(response.status, f"Unexpected response code: {response.status}") @@ -58,7 +59,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: if response and not (isinstance(e, HTTPErrorResponseException) and response.status == 429): errors += 1 self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") - if errors >= 3 or isinstance(e, CohortNotModifiedException) or isinstance(e, CohortTooLargeException): + if errors >= 3 or isinstance(e, CohortTooLargeException): raise e time.sleep(self.cohort_request_delay_millis/1000) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index 2cf31e9..e951cd9 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -58,6 +58,8 @@ def update_task(): def __load_cohort_internal(self, cohort_id): try: cohort = self.download_cohort(cohort_id) - self.cohort_storage.put_cohort(cohort) + # None is returned when cohort is not modified + if cohort is not None: + self.cohort_storage.put_cohort(cohort) except Exception as e: raise e diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 831754e..38af235 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -47,13 +47,13 @@ def __periodic_flag_update(self): try: self.__update_flag_configs() except Exception as e: - self.logger.error(f"Error while updating flags: {e}") + self.logger.warning(f"Error while updating flags: {e}") def __update_flag_configs(self): try: flag_configs = self.flag_config_api.get_flag_configs() except Exception as e: - self.logger.error(f'Failed to fetch flag configs: {e}') + self.logger.warning(f'Failed to fetch flag configs: {e}') raise e flag_keys = {flag['key'] for flag in flag_configs} @@ -79,7 +79,7 @@ def __update_flag_configs(self): self.cohort_loader.load_cohort(cohort_id).result() except Exception as e: cohort_download_errors.append((cohort_id, str(e))) - self.logger.error(f"Download cohort {cohort_id} failed: {e}") + self.logger.warning(f"Download cohort {cohort_id} failed: {e}") # get updated set of cohort ids updated_cohort_ids = self.cohort_storage.get_cohort_ids() @@ -94,7 +94,7 @@ def __update_flag_configs(self): self.flag_config_storage.put_flag_config(flag_config) self.logger.debug(f"Putting flag {flag_config['key']}") else: - self.logger.error(f"Flag {flag_config['key']} not updated because " + self.logger.warning(f"Flag {flag_config['key']} not updated because " f"not all required cohorts could be loaded") failed_flag_count += 1 @@ -112,7 +112,7 @@ def __update_cohorts(self): try: self.cohort_loader.update_stored_cohorts().result() except Exception as e: - self.logger.error(f"Error while updating cohorts: {e}") + self.logger.warning(f"Error while updating cohorts: {e}") def _delete_unused_cohorts(self): flag_cohort_ids = set() diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 7c8fb12..7281a0e 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -4,11 +4,6 @@ def __init__(self, status_code, message): self.status_code = status_code -class CohortNotModifiedException(Exception): - def __init__(self, message): - super().__init__(message) - - class CohortTooLargeException(Exception): def __init__(self, message): super().__init__(message) diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 3af9c8e..313f29a 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch from src.amplitude_experiment.cohort.cohort import Cohort from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi -from src.amplitude_experiment.exception import CohortTooLargeException, CohortNotModifiedException +from src.amplitude_experiment.exception import CohortTooLargeException def response(code: int, body: dict = None): @@ -95,9 +95,9 @@ def test_cohort_not_modified_exception(self): not_modified_response = response(204) with patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(None, result_cohort) - with self.assertRaises(CohortNotModifiedException): - self.api.get_cohort("1234", cohort) if __name__ == '__main__': diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index c5052db..28a375b 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -39,7 +39,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() logger = mock.create_autospec(logging.Logger) - cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, @@ -59,7 +59,7 @@ def test_start_throws_if_first_cohort_load_fails(self): cohort_storage = mock.Mock() cohort_storage.get_cohort_ids.return_value = set() logger = mock.create_autospec(logging.Logger) - cohort_loader = CohortLoader(cohort_download_api, cohort_storage, logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( LocalEvaluationConfig(), flag_api, flag_config_storage, From 8cfb1289ea346b5b857f58b31dab19b52badd233 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 30 Jul 2024 17:09:20 -0700 Subject: [PATCH 36/44] nit: update test name --- tests/cohort/cohort_download_api_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index 313f29a..e8b856e 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -90,7 +90,7 @@ def test_cohort_size_too_large(self): with self.assertRaises(CohortTooLargeException): self.api.get_cohort("1234", cohort) - def test_cohort_not_modified_exception(self): + def test_cohort_not_modified(self): cohort = Cohort(id="1234", last_modified=1000, size=1, member_ids=set()) not_modified_response = response(204) From c71e0c78139162c3fd02ae8769d9e7b54b25387c Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 14:24:11 -0700 Subject: [PATCH 37/44] do not throw exception upon start() if cohort download fails, log warning if evaluating flag without required cohorts --- .../deployment/deployment_runner.py | 24 +++++-------------- src/amplitude_experiment/exception.py | 5 ++++ src/amplitude_experiment/local/client.py | 24 ++++++++++++++----- tests/deployment/deployment_runner_test.py | 10 ++++++-- tests/local/client_test.py | 23 ++++++++++++++++-- 5 files changed, 58 insertions(+), 28 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 38af235..6d74d0c 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -84,29 +84,17 @@ def __update_flag_configs(self): # get updated set of cohort ids updated_cohort_ids = self.cohort_storage.get_cohort_ids() # iterate through new flag configs and check if their required cohorts exist - failed_flag_count = 0 for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) - if not cohort_ids or not self.cohort_loader: - self.flag_config_storage.put_flag_config(flag_config) - self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") - elif cohort_ids.issubset(updated_cohort_ids): - self.flag_config_storage.put_flag_config(flag_config) - self.logger.debug(f"Putting flag {flag_config['key']}") - else: - self.logger.warning(f"Flag {flag_config['key']} not updated because " - f"not all required cohorts could be loaded") - failed_flag_count += 1 + self.logger.debug(f"Putting non-cohort flag {flag_config['key']} with cohorts {cohort_ids}") + self.flag_config_storage.put_flag_config(flag_config) + missing_cohorts = cohort_ids - updated_cohort_ids + if missing_cohorts: + self.logger.warning(f"Flag {flag_config['key']} - failed to load cohorts: {missing_cohorts}") # delete unused cohorts self._delete_unused_cohorts() - self.logger.debug(f"Refreshed {len(flag_configs) - failed_flag_count} flag configs.") - - # if there are any download errors, raise an aggregated exception - if cohort_download_errors: - error_count = len(cohort_download_errors) - error_messages = "\n".join([f"Cohort {cohort_id}: {error}" for cohort_id, error in cohort_download_errors]) - raise Exception(f"{error_count} cohort(s) failed to download:\n{error_messages}") + self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") def __update_cohorts(self): try: diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 7281a0e..ee0a712 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -9,6 +9,11 @@ def __init__(self, message): super().__init__(message) +class EvaluationCohortsNotInStorageException(Exception): + def __init__(self, message): + super().__init__(message) + + class HTTPErrorResponseException(Exception): def __init__(self, status_code, message): super().__init__(message) diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 0976ac8..6810e9a 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -17,14 +17,12 @@ from ..flag.flag_config_storage import InMemoryFlagConfigStorage from ..user import User from ..connection_pool import HTTPConnectionPool -from .poller import Poller from .evaluation.evaluation import evaluate from ..util import deprecated -from ..util.flag_config import get_grouped_cohort_ids_from_flags +from ..util.flag_config import get_grouped_cohort_ids_from_flags, get_all_cohort_ids_from_flag from ..util.user import user_to_evaluation_context from ..util.variant import evaluation_variants_json_to_variants from ..variant import Variant -from ..version import __version__ class LocalEvaluationClient: @@ -102,8 +100,13 @@ def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Varia sorted_flags = topological_sort(flag_configs, flag_keys) if not sorted_flags: return {} - enriched_user = self.enrich_user(user, flag_configs) - context = user_to_evaluation_context(enriched_user) + + # Check if all required cohorts are in storage, if not log a warning + self._required_cohorts_in_storage(sorted_flags) + if self.config.cohort_sync_config: + user = self._enrich_user_with_cohorts(user, flag_configs) + + context = user_to_evaluation_context(user) flags_json = json.dumps(sorted_flags) context_json = json.dumps(context) result_json = evaluate(flags_json, context_json) @@ -168,7 +171,16 @@ def is_default_variant(variant: Variant) -> bool: return {key: variant for key, variant in variants.items() if not is_default_variant(variant)} - def enrich_user(self, user: User, flag_configs: Dict) -> User: + def _required_cohorts_in_storage(self, flag_configs: List) -> None: + stored_cohort_ids = self.cohort_storage.get_cohort_ids() + for flag in flag_configs: + flag_cohort_ids = get_all_cohort_ids_from_flag(flag) + missing_cohorts = flag_cohort_ids - stored_cohort_ids + if self.config.cohort_sync_config and missing_cohorts: + self.logger.warning(f"Evaluating flag {flag['key']} with cohorts {flag_cohort_ids} without " + f"cohort syncing configured") + + def _enrich_user_with_cohorts(self, user: User, flag_configs: Dict) -> User: grouped_cohort_ids = get_grouped_cohort_ids_from_flags(list(flag_configs.values())) if USER_GROUP_TYPE in grouped_cohort_ids: diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index 28a375b..447c190 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -52,7 +52,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): with self.assertRaises(RuntimeError): runner.start() - def test_start_throws_if_first_cohort_load_fails(self): + def test_start_does_not_throw_if_cohort_load_fails(self): flag_api = mock.create_autospec(FlagConfigApi) cohort_download_api = mock.Mock() flag_config_storage = mock.Mock() @@ -67,11 +67,17 @@ def test_start_throws_if_first_cohort_load_fails(self): logger, cohort_loader, ) + + # Mock methods as needed with patch.object(runner, '_delete_unused_cohorts'): flag_api.get_flag_configs.return_value = [self.flag] cohort_download_api.get_cohort.side_effect = RuntimeError("test") - with self.assertRaises(Exception): + + # Simply call the method and let the test pass if no exception is raised + try: runner.start() + except Exception as e: + self.fail(f"runner.start() raised an exception unexpectedly: {e}") if __name__ == '__main__': diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 3a1468d..138eb97 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -1,9 +1,13 @@ +import re import unittest +from unittest import mock + from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig from dotenv import load_dotenv import os + SERVER_API_KEY = 'server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz' test_user = User(user_id='test_user') test_user_2 = User(user_id='user_id', device_id='device_id') @@ -69,7 +73,8 @@ def test_evaluate_with_dependencies_variant_holdout(self): def test_evaluate_with_cohort(self): targeted_user = User(user_id='12345', device_id='device_id') - targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user, + {'sdk-local-evaluation-user-cohort-ci-test'}) .get('sdk-local-evaluation-user-cohort-ci-test')) expected_on_variant = Variant(key='on', value='on') self.assertEqual(expected_on_variant, targeted_variant) @@ -81,7 +86,8 @@ def test_evaluate_with_cohort(self): def test_evaluate_with_group_cohort(self): targeted_user = User(user_id='12345', device_id='device_id', groups={'org id': ['1']}) - targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user, + {'sdk-local-evaluation-group-cohort-ci-test'}) .get('sdk-local-evaluation-group-cohort-ci-test')) expected_on_variant = Variant(key='on', value='on') self.assertEqual(expected_on_variant, targeted_variant) @@ -91,6 +97,19 @@ def test_evaluate_with_group_cohort(self): expected_off_variant = Variant(key='off') self.assertEqual(expected_off_variant, non_targeted_variant) + def test_evaluation_cohorts_not_in_storage_exception(self): + with mock.patch.object(self._local_evaluation_client.cohort_storage, 'put_cohort', return_value=None): + self._local_evaluation_client.cohort_storage.get_cohort_ids = mock.Mock(return_value=set()) + targeted_user = User(user_id='12345') + + with self.assertLogs(self._local_evaluation_client.logger, level='WARNING') as log: + self._local_evaluation_client.evaluate_v2(targeted_user, {'sdk-local-evaluation-user-cohort-ci-test'}) + log_message = ( + "WARNING:Amplitude:Evaluating flag sdk-local-evaluation-user-cohort-ci-test with cohorts " + "{.*} without cohort syncing configured" + ) + self.assertTrue(any(re.match(log_message, message) for message in log.output)) + if __name__ == '__main__': unittest.main() From 85c6cf3b985de6495cd064c299ad91d07232d7f2 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 14:40:11 -0700 Subject: [PATCH 38/44] fix deployment runner logging --- src/amplitude_experiment/deployment/deployment_runner.py | 2 +- src/amplitude_experiment/exception.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index 6d74d0c..e826aed 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -86,7 +86,7 @@ def __update_flag_configs(self): # iterate through new flag configs and check if their required cohorts exist for flag_config in flag_configs: cohort_ids = get_all_cohort_ids_from_flag(flag_config) - self.logger.debug(f"Putting non-cohort flag {flag_config['key']} with cohorts {cohort_ids}") + self.logger.debug(f"Storing flag {flag_config['key']}") self.flag_config_storage.put_flag_config(flag_config) missing_cohorts = cohort_ids - updated_cohort_ids if missing_cohorts: diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index ee0a712..7281a0e 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -9,11 +9,6 @@ def __init__(self, message): super().__init__(message) -class EvaluationCohortsNotInStorageException(Exception): - def __init__(self, message): - super().__init__(message) - - class HTTPErrorResponseException(Exception): def __init__(self, status_code, message): super().__init__(message) From 646dd5a06b2469ee6c1905929bd3d63f34e722ce Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 14:47:08 -0700 Subject: [PATCH 39/44] nit: fix test name --- tests/local/client_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 138eb97..93a6571 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -97,7 +97,7 @@ def test_evaluate_with_group_cohort(self): expected_off_variant = Variant(key='off') self.assertEqual(expected_off_variant, non_targeted_variant) - def test_evaluation_cohorts_not_in_storage_exception(self): + def test_evaluation_cohorts_not_in_storage(self): with mock.patch.object(self._local_evaluation_client.cohort_storage, 'put_cohort', return_value=None): self._local_evaluation_client.cohort_storage.get_cohort_ids = mock.Mock(return_value=set()) targeted_user = User(user_id='12345') From 9864e4630d4f4288a3dcb0030661926dea38f0aa Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Thu, 1 Aug 2024 15:41:29 -0700 Subject: [PATCH 40/44] update error log and test --- src/amplitude_experiment/local/client.py | 12 +++++++++--- tests/local/client_test.py | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 6810e9a..147de5d 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -176,9 +176,15 @@ def _required_cohorts_in_storage(self, flag_configs: List) -> None: for flag in flag_configs: flag_cohort_ids = get_all_cohort_ids_from_flag(flag) missing_cohorts = flag_cohort_ids - stored_cohort_ids - if self.config.cohort_sync_config and missing_cohorts: - self.logger.warning(f"Evaluating flag {flag['key']} with cohorts {flag_cohort_ids} without " - f"cohort syncing configured") + if missing_cohorts: + message = ( + f"Evaluating flag {flag['key']} dependent on cohorts {flag_cohort_ids} " + f"without {missing_cohorts} in storage" + if self.config.cohort_sync_config + else f"Evaluating flag {flag['key']} dependent on cohorts {flag_cohort_ids} without " + f"cohort syncing configured" + ) + self.logger.warning(message) def _enrich_user_with_cohorts(self, user: User, flag_configs: Dict) -> User: grouped_cohort_ids = get_grouped_cohort_ids_from_flags(list(flag_configs.values())) diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 93a6571..87085c6 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -97,7 +97,7 @@ def test_evaluate_with_group_cohort(self): expected_off_variant = Variant(key='off') self.assertEqual(expected_off_variant, non_targeted_variant) - def test_evaluation_cohorts_not_in_storage(self): + def test_evaluation_cohorts_not_in_storage_with_sync_config(self): with mock.patch.object(self._local_evaluation_client.cohort_storage, 'put_cohort', return_value=None): self._local_evaluation_client.cohort_storage.get_cohort_ids = mock.Mock(return_value=set()) targeted_user = User(user_id='12345') @@ -105,8 +105,8 @@ def test_evaluation_cohorts_not_in_storage(self): with self.assertLogs(self._local_evaluation_client.logger, level='WARNING') as log: self._local_evaluation_client.evaluate_v2(targeted_user, {'sdk-local-evaluation-user-cohort-ci-test'}) log_message = ( - "WARNING:Amplitude:Evaluating flag sdk-local-evaluation-user-cohort-ci-test with cohorts " - "{.*} without cohort syncing configured" + "WARNING:Amplitude:Evaluating flag sdk-local-evaluation-user-cohort-ci-test dependent on cohorts " + "{.*} without {.*} in storage" ) self.assertTrue(any(re.match(log_message, message) for message in log.output)) From 70437327260d74efa5c03078b500690b6fa64edd Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 5 Aug 2024 16:35:55 -0700 Subject: [PATCH 41/44] update_stored_cohorts using load_cohort --- .../cohort/cohort_loader.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index e951cd9..e5d6639 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -39,16 +39,20 @@ def download_cohort(self, cohort_id: str) -> Cohort: def update_stored_cohorts(self) -> Future: def update_task(): errors = [] - try: - futures = [self.executor.submit(self.__load_cohort_internal, cohort_id) for cohort_id in self.cohort_storage.get_cohort_ids()] + cohort_ids = self.cohort_storage.get_cohort_ids() - for future, cohort_id in zip(as_completed(futures), self.cohort_storage.get_cohort_ids()): - try: - future.result() - except Exception as e: - errors.append((cohort_id, e)) - except Exception as e: - errors.append(e) + futures = [] + with self.lock_jobs: + for cohort_id in cohort_ids: + future = self.load_cohort(cohort_id) + futures.append(future) + + for future in as_completed(futures): + cohort_id = next(c_id for c_id, f in self.jobs.items() if f == future) + try: + future.result() + except Exception as e: + errors.append((cohort_id, e)) if errors: raise CohortUpdateException(errors) From 1d974f1c983bba5bd94eef8ae0d40ecde0c99773 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Mon, 5 Aug 2024 23:26:47 -0700 Subject: [PATCH 42/44] refresh cohorts based on flag configs in storage --- .../cohort/cohort_loader.py | 28 +++++++++---------- .../deployment/deployment_runner.py | 20 ++++++------- src/amplitude_experiment/exception.py | 2 +- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index e5d6639..14f639e 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -6,7 +6,7 @@ from .cohort import Cohort from .cohort_download_api import CohortDownloadApi from .cohort_storage import CohortStorage -from ..exception import CohortUpdateException +from ..exception import CohortsDownloadException class CohortLoader: @@ -30,39 +30,37 @@ def load_cohort(self, cohort_id: str) -> Future: def _remove_job(self, cohort_id: str): if cohort_id in self.jobs: - del self.jobs[cohort_id] + with self.lock_jobs: + self.jobs.pop(cohort_id, None) def download_cohort(self, cohort_id: str) -> Cohort: cohort = self.cohort_storage.get_cohort(cohort_id) return self.cohort_download_api.get_cohort(cohort_id, cohort) - def update_stored_cohorts(self) -> Future: - def update_task(): + def download_cohorts(self, cohort_ids: Set[str]) -> Future: + def update_task(task_cohort_ids): errors = [] - cohort_ids = self.cohort_storage.get_cohort_ids() - futures = [] - with self.lock_jobs: - for cohort_id in cohort_ids: - future = self.load_cohort(cohort_id) - futures.append(future) + for cohort_id in task_cohort_ids: + future = self.load_cohort(cohort_id) + futures.append(future) for future in as_completed(futures): - cohort_id = next(c_id for c_id, f in self.jobs.items() if f == future) try: future.result() except Exception as e: - errors.append((cohort_id, e)) + cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None) + if cohort_id: + errors.append((cohort_id, e)) if errors: - raise CohortUpdateException(errors) + raise CohortsDownloadException(errors) - return self.executor.submit(update_task) + return self.executor.submit(update_task, cohort_ids) def __load_cohort_internal(self, cohort_id): try: cohort = self.download_cohort(cohort_id) - # None is returned when cohort is not modified if cohort is not None: self.cohort_storage.put_cohort(cohort) except Exception as e: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index e826aed..f61b657 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -8,7 +8,9 @@ from ..flag.flag_config_api import FlagConfigApi from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller -from ..util.flag_config import get_all_cohort_ids_from_flag +from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags + +COHORT_POLLING_INTERVAL_MILLIS = 60000 class DeploymentRunner: @@ -29,7 +31,7 @@ def __init__( self.lock = threading.Lock() self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) if self.cohort_loader: - self.cohort_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, + self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000, self.__update_cohorts) self.logger = logger @@ -71,15 +73,12 @@ def __update_flag_configs(self): existing_cohort_ids = self.cohort_storage.get_cohort_ids() cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - cohort_download_errors = [] # download all new cohorts - for cohort_id in cohort_ids_to_download: - try: - self.cohort_loader.load_cohort(cohort_id).result() - except Exception as e: - cohort_download_errors.append((cohort_id, str(e))) - self.logger.warning(f"Download cohort {cohort_id} failed: {e}") + try: + self.cohort_loader.download_cohorts(cohort_ids_to_download).result() + except Exception as e: + self.logger.warning(f"Error while downloading cohorts: {e}") # get updated set of cohort ids updated_cohort_ids = self.cohort_storage.get_cohort_ids() @@ -97,8 +96,9 @@ def __update_flag_configs(self): self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") def __update_cohorts(self): + cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values())) try: - self.cohort_loader.update_stored_cohorts().result() + self.cohort_loader.download_cohorts(cohort_ids).result() except Exception as e: self.logger.warning(f"Error while updating cohorts: {e}") diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 7281a0e..92ed0e9 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -15,7 +15,7 @@ def __init__(self, status_code, message): self.status_code = status_code -class CohortUpdateException(Exception): +class CohortsDownloadException(Exception): def __init__(self, errors): self.errors = errors super().__init__(self.__str__()) From 06e693e0f995708f7140a48308c91bce7b0476c6 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Tue, 6 Aug 2024 13:58:53 -0700 Subject: [PATCH 43/44] update cohort_sync_config fields: include polling and remove request delay, use enum for serverzone, update tests accordingly --- .../cohort/cohort_download_api.py | 13 +++++++------ .../cohort/cohort_sync_config.py | 7 ++++--- .../deployment/deployment_runner.py | 4 +--- src/amplitude_experiment/local/client.py | 1 - tests/cohort/cohort_download_api_test.py | 2 +- tests/deployment/deployment_runner_test.py | 5 +++-- tests/local/client_eu_test.py | 3 +-- tests/local/client_test.py | 3 +-- 8 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 6f9aae0..65da244 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -9,6 +9,8 @@ from ..connection_pool import HTTPConnectionPool from ..exception import HTTPErrorResponseException, CohortTooLargeException +COHORT_REQUEST_RETRY_DELAY_MILLIS = 100 + class CohortDownloadApi: @@ -17,13 +19,11 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: class DirectCohortDownloadApi(CohortDownloadApi): - def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int, - server_url: str, logger: logging.Logger): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger): super().__init__() self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size - self.cohort_request_delay_millis = cohort_request_delay_millis self.server_url = server_url self.logger = logger self.__setup_connection_pool() @@ -48,10 +48,11 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None group_type=cohort_info['groupType'], ) elif response.status == 204: - self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified" ) + self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified") return elif response.status == 413: - raise CohortTooLargeException(f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}") + raise CohortTooLargeException( + f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}") elif response.status != 202: raise HTTPErrorResponseException(response.status, f"Unexpected response code: {response.status}") @@ -61,7 +62,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") if errors >= 3 or isinstance(e, CohortTooLargeException): raise e - time.sleep(self.cohort_request_delay_millis/1000) + time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000) def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse: headers = { diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py index 0ee25b2..e3b0090 100644 --- a/src/amplitude_experiment/cohort/cohort_sync_config.py +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -10,14 +10,15 @@ class CohortSyncConfig: api_key (str): The project API Key secret_key (str): The project Secret Key max_cohort_size (int): The maximum cohort size that can be downloaded - cohort_request_delay_millis (int): The delay in milliseconds between cohort download requests + cohort_polling_interval_millis (int): The interval, in milliseconds, at which to poll for + cohort updates, minimum 60000 cohort_server_url (str): The server endpoint from which to request cohorts """ def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647, - cohort_request_delay_millis: int = 5000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL): + cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL): self.api_key = api_key self.secret_key = secret_key self.max_cohort_size = max_cohort_size - self.cohort_request_delay_millis = cohort_request_delay_millis + self.cohort_polling_interval_millis = max(cohort_polling_interval_millis, 60000) self.cohort_server_url = cohort_server_url diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index f61b657..aa8aa64 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -10,8 +10,6 @@ from ..local.poller import Poller from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags -COHORT_POLLING_INTERVAL_MILLIS = 60000 - class DeploymentRunner: def __init__( @@ -31,7 +29,7 @@ def __init__( self.lock = threading.Lock() self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) if self.cohort_loader: - self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000, + self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000, self.__update_cohorts) self.logger = logger diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 147de5d..ba917d4 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -61,7 +61,6 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key, self.config.cohort_sync_config.secret_key, self.config.cohort_sync_config.max_cohort_size, - self.config.cohort_sync_config.cohort_request_delay_millis, self.config.cohort_sync_config.cohort_server_url, self.logger) diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py index e8b856e..2a42394 100644 --- a/tests/cohort/cohort_download_api_test.py +++ b/tests/cohort/cohort_download_api_test.py @@ -19,7 +19,7 @@ def response(code: int, body: dict = None): class CohortDownloadApiTest(unittest.TestCase): def setUp(self): - self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com", mock.create_autospec(logging.Logger)) + self.api = DirectCohortDownloadApi('api', 'secret', 15000, "https://example.amplitude.com", mock.create_autospec(logging.Logger)) def test_cohort_download_success(self): cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py index 447c190..f64a332 100644 --- a/tests/deployment/deployment_runner_test.py +++ b/tests/deployment/deployment_runner_test.py @@ -5,6 +5,7 @@ from src.amplitude_experiment import LocalEvaluationConfig from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner @@ -41,7 +42,7 @@ def test_start_throws_if_first_flag_config_load_fails(self): logger = mock.create_autospec(logging.Logger) cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( - LocalEvaluationConfig(), + LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), flag_api, flag_config_storage, cohort_storage, @@ -61,7 +62,7 @@ def test_start_does_not_throw_if_cohort_load_fails(self): logger = mock.create_autospec(logging.Logger) cohort_loader = CohortLoader(cohort_download_api, cohort_storage) runner = DeploymentRunner( - LocalEvaluationConfig(), + LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), flag_api, flag_config_storage, cohort_storage, logger, diff --git a/tests/local/client_eu_test.py b/tests/local/client_eu_test.py index ed8b3af..91e792d 100644 --- a/tests/local/client_eu_test.py +++ b/tests/local/client_eu_test.py @@ -17,8 +17,7 @@ def setUpClass(cls) -> None: api_key = os.environ['EU_API_KEY'] secret_key = os.environ['EU_SECRET_KEY'] cohort_sync_config = CohortSyncConfig(api_key=api_key, - secret_key=secret_key, - cohort_request_delay_millis=100) + secret_key=secret_key) cls._local_evaluation_client = ( LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, server_zone=ServerZone.EU, cohort_sync_config=cohort_sync_config))) diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 87085c6..b6c50eb 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -22,8 +22,7 @@ def setUpClass(cls) -> None: api_key = os.environ['API_KEY'] secret_key = os.environ['SECRET_KEY'] cohort_sync_config = CohortSyncConfig(api_key=api_key, - secret_key=secret_key, - cohort_request_delay_millis=100) + secret_key=secret_key) cls._local_evaluation_client = ( LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, cohort_sync_config=cohort_sync_config))) From a9006cf3c6dde16758966169e8cbb336d6d298b9 Mon Sep 17 00:00:00 2001 From: tyiuhc Date: Wed, 7 Aug 2024 13:59:37 -0700 Subject: [PATCH 44/44] add SDK+version to cohort request header --- src/amplitude_experiment/cohort/cohort_download_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py index 65da244..e1a41a3 100644 --- a/src/amplitude_experiment/cohort/cohort_download_api.py +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -4,6 +4,7 @@ import json from http.client import HTTPResponse from typing import Optional +from ..version import __version__ from .cohort import Cohort from ..connection_pool import HTTPConnectionPool @@ -67,6 +68,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse: headers = { 'Authorization': f'Basic {self._get_basic_auth()}', + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" } conn = self._connection_pool.acquire() try: