diff --git a/.github/workflows/test-arm.yml b/.github/workflows/test-arm.yml index 30eac88..7ac803a 100644 --- a/.github/workflows/test-arm.yml +++ b/.github/workflows/test-arm.yml @@ -4,13 +4,21 @@ on: [pull_request] jobs: aarch_job: runs-on: ubuntu-latest - name: Test on ubuntu aarch64 + environment: Unit Test + name: Test on Ubuntu aarch64 steps: - - uses: actions/checkout@v3 - - uses: uraimo/run-on-arch-action@v2 - name: Run Unit Test + - name: Checkout source code + uses: actions/checkout@v3 + + - name: Set up and run unit test on aarch64 + uses: uraimo/run-on-arch-action@v2 id: runcmd with: + env: | + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} arch: aarch64 distro: ubuntu20.04 githubToken: ${{ github.token }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3072f54..d78efe6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,7 @@ on: [pull_request] jobs: test: runs-on: ubuntu-latest + environment: Unit Test strategy: matrix: python-version: [ "3.7" ] @@ -19,8 +20,12 @@ jobs: cache: 'pip' - name: Install requirements - run: pip install -r requirements.txt - pip install -r requirements-dev.txt + run: pip install -r requirements.txt && pip install -r requirements-dev.txt - name: Unit Test + env: + API_KEY: ${{ secrets.API_KEY }} + SECRET_KEY: ${{ secrets.SECRET_KEY }} + EU_API_KEY: ${{ secrets.EU_API_KEY }} + EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }} run: python -m unittest discover -s ./tests -p '*_test.py' diff --git a/README.md b/README.md index cc9a36e..5d8bb3d 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,15 @@ user = User( variants = experiment.evaluate(user) ``` +# Running unit tests suite +To setup for running test on local, create a `.env` file with following +contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test: + +``` +API_KEY={API_KEY} +SECRET_KEY={SECRET_KEY} +``` + ## More Information Please visit our :100:[Developer Center](https://www.docs.developers.amplitude.com/experiment/sdks/python-sdk/) for more instructions on using our the SDK. diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f5c19d..8e4ee4b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1 +1,2 @@ parameterized~=0.9.0 +python-dotenv~=0.21.1 diff --git a/src/amplitude_experiment/__init__.py b/src/amplitude_experiment/__init__.py index 482e490..aede23b 100644 --- a/src/amplitude_experiment/__init__.py +++ b/src/amplitude_experiment/__init__.py @@ -12,4 +12,6 @@ from .cookie import AmplitudeCookie from .local.client import LocalEvaluationClient from .local.config import LocalEvaluationConfig +from .local.config import ServerZone from .assignment import AssignmentConfig +from .cohort.cohort_sync_config import CohortSyncConfig diff --git a/src/amplitude_experiment/cohort/cohort.py b/src/amplitude_experiment/cohort/cohort.py new file mode 100644 index 0000000..ccc5013 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass, field +from typing import ClassVar, Set + +USER_GROUP_TYPE: ClassVar[str] = "User" + + +@dataclass +class Cohort: + id: str + last_modified: int + size: int + member_ids: Set[str] + group_type: str = field(default=USER_GROUP_TYPE) diff --git a/src/amplitude_experiment/cohort/cohort_download_api.py b/src/amplitude_experiment/cohort/cohort_download_api.py new file mode 100644 index 0000000..e1a41a3 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_download_api.py @@ -0,0 +1,91 @@ +import time +import logging +import base64 +import json +from http.client import HTTPResponse +from typing import Optional +from ..version import __version__ + +from .cohort import Cohort +from ..connection_pool import HTTPConnectionPool +from ..exception import HTTPErrorResponseException, CohortTooLargeException + +COHORT_REQUEST_RETRY_DELAY_MILLIS = 100 + + +class CohortDownloadApi: + + def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort: + raise NotImplementedError + + +class DirectCohortDownloadApi(CohortDownloadApi): + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger): + super().__init__() + self.api_key = api_key + self.secret_key = secret_key + self.max_cohort_size = max_cohort_size + self.server_url = server_url + self.logger = logger + self.__setup_connection_pool() + + def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None: + self.logger.debug(f"getCohortMembers({cohort_id}): start") + errors = 0 + while True: + response = None + try: + last_modified = None if cohort is None else cohort.last_modified + response = self._get_cohort_members_request(cohort_id, last_modified) + self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}") + if response.status == 200: + cohort_info = json.loads(response.read().decode("utf8")) + self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}") + return Cohort( + id=cohort_info['cohortId'], + last_modified=cohort_info['lastModified'], + size=cohort_info['size'], + member_ids=set(cohort_info['memberIds']), + group_type=cohort_info['groupType'], + ) + elif response.status == 204: + self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified") + return + elif response.status == 413: + raise CohortTooLargeException( + f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}") + elif response.status != 202: + raise HTTPErrorResponseException(response.status, + f"Unexpected response code: {response.status}") + except Exception as e: + if response and not (isinstance(e, HTTPErrorResponseException) and response.status == 429): + errors += 1 + self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}") + if errors >= 3 or isinstance(e, CohortTooLargeException): + raise e + time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000) + + def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse: + headers = { + 'Authorization': f'Basic {self._get_basic_auth()}', + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" + } + conn = self._connection_pool.acquire() + try: + url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}' + if last_modified is not None: + url += f'&lastModified={last_modified}' + response = conn.request('GET', url, headers=headers) + return response + finally: + self._connection_pool.release(conn) + + def _get_basic_auth(self) -> str: + credentials = f'{self.api_key}:{self.secret_key}' + return base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + + def __setup_connection_pool(self): + scheme, _, host = self.server_url.split('/', 3) + timeout = 10 + self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout, + scheme=scheme) diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py new file mode 100644 index 0000000..14f639e --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -0,0 +1,67 @@ +import logging +from typing import Dict, Set +from concurrent.futures import ThreadPoolExecutor, Future, as_completed +import threading + +from .cohort import Cohort +from .cohort_download_api import CohortDownloadApi +from .cohort_storage import CohortStorage +from ..exception import CohortsDownloadException + + +class CohortLoader: + def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage): + self.cohort_download_api = cohort_download_api + self.cohort_storage = cohort_storage + self.jobs: Dict[str, Future] = {} + self.lock_jobs = threading.Lock() + self.executor = ThreadPoolExecutor( + max_workers=32, + thread_name_prefix='CohortLoaderExecutor' + ) + + def load_cohort(self, cohort_id: str) -> Future: + with self.lock_jobs: + if cohort_id not in self.jobs: + future = self.executor.submit(self.__load_cohort_internal, cohort_id) + future.add_done_callback(lambda f: self._remove_job(cohort_id)) + self.jobs[cohort_id] = future + return self.jobs[cohort_id] + + def _remove_job(self, cohort_id: str): + if cohort_id in self.jobs: + with self.lock_jobs: + self.jobs.pop(cohort_id, None) + + def download_cohort(self, cohort_id: str) -> Cohort: + cohort = self.cohort_storage.get_cohort(cohort_id) + return self.cohort_download_api.get_cohort(cohort_id, cohort) + + def download_cohorts(self, cohort_ids: Set[str]) -> Future: + def update_task(task_cohort_ids): + errors = [] + futures = [] + for cohort_id in task_cohort_ids: + future = self.load_cohort(cohort_id) + futures.append(future) + + for future in as_completed(futures): + try: + future.result() + except Exception as e: + cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None) + if cohort_id: + errors.append((cohort_id, e)) + + if errors: + raise CohortsDownloadException(errors) + + return self.executor.submit(update_task, cohort_ids) + + def __load_cohort_internal(self, cohort_id): + try: + cohort = self.download_cohort(cohort_id) + if cohort is not None: + self.cohort_storage.put_cohort(cohort) + except Exception as e: + raise e diff --git a/src/amplitude_experiment/cohort/cohort_storage.py b/src/amplitude_experiment/cohort/cohort_storage.py new file mode 100644 index 0000000..9c6d4c5 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_storage.py @@ -0,0 +1,73 @@ +from typing import Dict, Set, Optional +from threading import RLock + +from .cohort import Cohort, USER_GROUP_TYPE + + +class CohortStorage: + def get_cohort(self, cohort_id: str): + raise NotImplementedError + + def get_cohorts(self): + raise NotImplementedError + + def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: + raise NotImplementedError + + def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: + raise NotImplementedError + + def put_cohort(self, cohort_description: Cohort): + raise NotImplementedError + + def delete_cohort(self, group_type: str, cohort_id: str): + raise NotImplementedError + + def get_cohort_ids(self) -> Set[str]: + raise NotImplementedError + + +class InMemoryCohortStorage(CohortStorage): + def __init__(self): + self.lock = RLock() + self.group_to_cohort_store: Dict[str, Set[str]] = {} + self.cohort_store: Dict[str, Cohort] = {} + + def get_cohort(self, cohort_id: str): + with self.lock: + return self.cohort_store.get(cohort_id) + + def get_cohorts(self): + return self.cohort_store.copy() + + def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]: + return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids) + + def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]: + result = set() + with self.lock: + group_type_cohorts = self.group_to_cohort_store.get(group_type, {}) + for cohort_id in group_type_cohorts: + members = self.cohort_store.get(cohort_id).member_ids + if cohort_id in cohort_ids and group_name in members: + result.add(cohort_id) + return result + + def put_cohort(self, cohort: Cohort): + with self.lock: + if cohort.group_type not in self.group_to_cohort_store: + self.group_to_cohort_store[cohort.group_type] = set() + self.group_to_cohort_store[cohort.group_type].add(cohort.id) + self.cohort_store[cohort.id] = cohort + + def delete_cohort(self, group_type: str, cohort_id: str): + with self.lock: + group_cohorts = self.group_to_cohort_store.get(group_type, {}) + if cohort_id in group_cohorts: + group_cohorts.remove(cohort_id) + if cohort_id in self.cohort_store: + del self.cohort_store[cohort_id] + + def get_cohort_ids(self): + with self.lock: + return set(self.cohort_store.keys()) diff --git a/src/amplitude_experiment/cohort/cohort_sync_config.py b/src/amplitude_experiment/cohort/cohort_sync_config.py new file mode 100644 index 0000000..e3b0090 --- /dev/null +++ b/src/amplitude_experiment/cohort/cohort_sync_config.py @@ -0,0 +1,24 @@ +DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com' +EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com' + + +class CohortSyncConfig: + """Experiment Cohort Sync Configuration + This configuration is used to set up the cohort loader. The cohort loader is responsible for + downloading cohorts from the server and storing them locally. + Parameters: + api_key (str): The project API Key + secret_key (str): The project Secret Key + max_cohort_size (int): The maximum cohort size that can be downloaded + cohort_polling_interval_millis (int): The interval, in milliseconds, at which to poll for + cohort updates, minimum 60000 + cohort_server_url (str): The server endpoint from which to request cohorts + """ + + def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647, + cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL): + self.api_key = api_key + self.secret_key = secret_key + self.max_cohort_size = max_cohort_size + self.cohort_polling_interval_millis = max(cohort_polling_interval_millis, 60000) + self.cohort_server_url = cohort_server_url diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py new file mode 100644 index 0000000..aa8aa64 --- /dev/null +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -0,0 +1,114 @@ +import logging +from typing import Optional +import threading + +from ..local.config import LocalEvaluationConfig +from ..cohort.cohort_loader import CohortLoader +from ..cohort.cohort_storage import CohortStorage +from ..flag.flag_config_api import FlagConfigApi +from ..flag.flag_config_storage import FlagConfigStorage +from ..local.poller import Poller +from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags + + +class DeploymentRunner: + def __init__( + self, + config: LocalEvaluationConfig, + flag_config_api: FlagConfigApi, + flag_config_storage: FlagConfigStorage, + cohort_storage: CohortStorage, + logger: logging.Logger, + cohort_loader: Optional[CohortLoader] = None, + ): + self.config = config + self.flag_config_api = flag_config_api + self.flag_config_storage = flag_config_storage + self.cohort_storage = cohort_storage + self.cohort_loader = cohort_loader + self.lock = threading.Lock() + self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) + if self.cohort_loader: + self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000, + self.__update_cohorts) + self.logger = logger + + def start(self): + with self.lock: + self.__update_flag_configs() + self.flag_poller.start() + if self.cohort_loader: + self.cohort_poller.start() + + def stop(self): + self.flag_poller.stop() + + def __periodic_flag_update(self): + try: + self.__update_flag_configs() + except Exception as e: + self.logger.warning(f"Error while updating flags: {e}") + + def __update_flag_configs(self): + try: + flag_configs = self.flag_config_api.get_flag_configs() + except Exception as e: + self.logger.warning(f'Failed to fetch flag configs: {e}') + raise e + + flag_keys = {flag['key'] for flag in flag_configs} + self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys) + + if not self.cohort_loader: + for flag_config in flag_configs: + self.logger.debug(f"Putting non-cohort flag {flag_config['key']}") + self.flag_config_storage.put_flag_config(flag_config) + return + + new_cohort_ids = set() + for flag_config in flag_configs: + new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config)) + + existing_cohort_ids = self.cohort_storage.get_cohort_ids() + cohort_ids_to_download = new_cohort_ids - existing_cohort_ids + + # download all new cohorts + try: + self.cohort_loader.download_cohorts(cohort_ids_to_download).result() + except Exception as e: + self.logger.warning(f"Error while downloading cohorts: {e}") + + # get updated set of cohort ids + updated_cohort_ids = self.cohort_storage.get_cohort_ids() + # iterate through new flag configs and check if their required cohorts exist + for flag_config in flag_configs: + cohort_ids = get_all_cohort_ids_from_flag(flag_config) + self.logger.debug(f"Storing flag {flag_config['key']}") + self.flag_config_storage.put_flag_config(flag_config) + missing_cohorts = cohort_ids - updated_cohort_ids + if missing_cohorts: + self.logger.warning(f"Flag {flag_config['key']} - failed to load cohorts: {missing_cohorts}") + + # delete unused cohorts + self._delete_unused_cohorts() + self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") + + def __update_cohorts(self): + cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values())) + try: + self.cohort_loader.download_cohorts(cohort_ids).result() + except Exception as e: + self.logger.warning(f"Error while updating cohorts: {e}") + + def _delete_unused_cohorts(self): + flag_cohort_ids = set() + for flag in self.flag_config_storage.get_flag_configs().values(): + flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag)) + + storage_cohorts = self.cohort_storage.get_cohorts() + deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids + + for deleted_cohort_id in deleted_cohort_ids: + deleted_cohort = storage_cohorts.get(deleted_cohort_id) + if deleted_cohort is not None: + self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id) diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 58dd305..92ed0e9 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -2,3 +2,30 @@ class FetchException(Exception): def __init__(self, status_code, message): super().__init__(message) self.status_code = status_code + + +class CohortTooLargeException(Exception): + def __init__(self, message): + super().__init__(message) + + +class HTTPErrorResponseException(Exception): + def __init__(self, status_code, message): + super().__init__(message) + self.status_code = status_code + + +class CohortsDownloadException(Exception): + def __init__(self, errors): + self.errors = errors + super().__init__(self.__str__()) + + def __str__(self): + error_messages = [] + for item in self.errors: + if isinstance(item, tuple) and len(item) == 2: + cohort_id, error = item + error_messages.append(f"Cohort {cohort_id}: {error}") + else: + error_messages.append(str(item)) + return f"One or more cohorts failed to update:\n" + "\n".join(error_messages) diff --git a/src/amplitude_experiment/flag/flag_config_api.py b/src/amplitude_experiment/flag/flag_config_api.py new file mode 100644 index 0000000..15db645 --- /dev/null +++ b/src/amplitude_experiment/flag/flag_config_api.py @@ -0,0 +1,47 @@ +import json +from typing import List + +from ..version import __version__ + +from ..connection_pool import HTTPConnectionPool + + +class FlagConfigApi: + def get_flag_configs(self) -> List: + pass + + +class FlagConfigApiV2(FlagConfigApi): + def __init__(self, deployment_key: str, server_url: str, flag_config_poller_request_timeout_millis: int): + self.deployment_key = deployment_key + self.server_url = server_url + self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis + self.__setup_connection_pool() + + def get_flag_configs(self) -> List: + return self._get_flag_configs() + + def _get_flag_configs(self) -> List: + conn = self._connection_pool.acquire() + headers = { + 'Authorization': f"Api-Key {self.deployment_key}", + 'Content-Type': 'application/json;charset=utf-8', + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" + } + body = None + try: + response = conn.request('GET', '/sdk/v2/flags?v=0', body, headers) + response_body = response.read().decode("utf8") + if response.status != 200: + raise Exception( + f"[Experiment] Get flagConfigs - received error response: ${response.status}: ${response_body}") + flags = json.loads(response_body) + return flags + finally: + self._connection_pool.release(conn) + + def __setup_connection_pool(self): + scheme, _, host = self.server_url.split('/', 3) + timeout = self.flag_config_poller_request_timeout_millis / 1000 + self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, + read_timeout=timeout, scheme=scheme) diff --git a/src/amplitude_experiment/flag/flag_config_storage.py b/src/amplitude_experiment/flag/flag_config_storage.py new file mode 100644 index 0000000..68b1a73 --- /dev/null +++ b/src/amplitude_experiment/flag/flag_config_storage.py @@ -0,0 +1,38 @@ +from typing import Dict, Callable +from threading import Lock + + +class FlagConfigStorage: + def get_flag_config(self, key: str) -> Dict: + raise NotImplementedError + + def get_flag_configs(self) -> Dict: + raise NotImplementedError + + def put_flag_config(self, flag_config: Dict): + raise NotImplementedError + + def remove_if(self, condition: Callable[[Dict], bool]): + raise NotImplementedError + + +class InMemoryFlagConfigStorage(FlagConfigStorage): + def __init__(self): + self.flag_configs = {} + self.flag_configs_lock = Lock() + + def get_flag_config(self, key: str) -> Dict: + with self.flag_configs_lock: + return self.flag_configs.get(key) + + def get_flag_configs(self) -> Dict[str, Dict]: + with self.flag_configs_lock: + return self.flag_configs.copy() + + def put_flag_config(self, flag_config: Dict): + with self.flag_configs_lock: + self.flag_configs[flag_config['key']] = flag_config + + def remove_if(self, condition: Callable[[Dict], bool]): + with self.flag_configs_lock: + self.flag_configs = {key: value for key, value in self.flag_configs.items() if not condition(value)} diff --git a/src/amplitude_experiment/local/client.py b/src/amplitude_experiment/local/client.py index 4db9e5b..ba917d4 100644 --- a/src/amplitude_experiment/local/client.py +++ b/src/amplitude_experiment/local/client.py @@ -8,15 +8,21 @@ from .config import LocalEvaluationConfig from .topological_sort import topological_sort from ..assignment import Assignment, AssignmentFilter, AssignmentService +from ..cohort.cohort import USER_GROUP_TYPE +from ..cohort.cohort_download_api import DirectCohortDownloadApi +from ..cohort.cohort_loader import CohortLoader +from ..cohort.cohort_storage import InMemoryCohortStorage +from ..deployment.deployment_runner import DeploymentRunner +from ..flag.flag_config_api import FlagConfigApiV2 +from ..flag.flag_config_storage import InMemoryFlagConfigStorage from ..user import User from ..connection_pool import HTTPConnectionPool -from .poller import Poller from .evaluation.evaluation import evaluate from ..util import deprecated +from ..util.flag_config import get_grouped_cohort_ids_from_flags, get_all_cohort_ids_from_flag from ..util.user import user_to_evaluation_context from ..util.variant import evaluation_variants_json_to_variants from ..variant import Variant -from ..version import __version__ class LocalEvaluationClient: @@ -47,17 +53,29 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None): if self.config.debug: self.logger.setLevel(logging.DEBUG) self.__setup_connection_pool() - self.flags = None - self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__do_flags) self.lock = Lock() + self.cohort_storage = InMemoryCohortStorage() + self.flag_config_storage = InMemoryFlagConfigStorage() + cohort_loader = None + if self.config.cohort_sync_config: + cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key, + self.config.cohort_sync_config.secret_key, + self.config.cohort_sync_config.max_cohort_size, + self.config.cohort_sync_config.cohort_server_url, + self.logger) + + cohort_loader = CohortLoader(cohort_download_api, self.cohort_storage) + flag_config_api = FlagConfigApiV2(api_key, self.config.server_url, + self.config.flag_config_poller_request_timeout_millis) + self.deployment_runner = DeploymentRunner(self.config, flag_config_api, self.flag_config_storage, + self.cohort_storage, self.logger, cohort_loader) def start(self): """ Fetch initial flag configurations and start polling for updates. You must call this function to begin polling for flag config updates. """ - self.__do_flags() - self.poller.start() + self.deployment_runner.start() def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Variant]: """ @@ -74,11 +92,20 @@ def evaluate_v2(self, user: User, flag_keys: Set[str] = None) -> Dict[str, Varia Returns: The evaluated variants. """ - if self.flags is None or len(self.flags) == 0: + flag_configs = self.flag_config_storage.get_flag_configs() + if flag_configs is None or len(flag_configs) == 0: return {} - self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {self.flags}") + self.logger.debug(f"[Experiment] Evaluate: user={user} - Flags: {flag_configs}") + sorted_flags = topological_sort(flag_configs, flag_keys) + if not sorted_flags: + return {} + + # Check if all required cohorts are in storage, if not log a warning + self._required_cohorts_in_storage(sorted_flags) + if self.config.cohort_sync_config: + user = self._enrich_user_with_cohorts(user, flag_configs) + context = user_to_evaluation_context(user) - sorted_flags = topological_sort(self.flags, flag_keys) flags_json = json.dumps(sorted_flags) context_json = json.dumps(context) result_json = evaluate(flags_json, context_json) @@ -115,30 +142,6 @@ def evaluate(self, user: User, flag_keys: List[str] = None) -> Dict[str, Variant variants = self.evaluate_v2(user, flag_keys) return self.__filter_default_variants(variants) - def __do_flags(self): - conn = self._connection_pool.acquire() - headers = { - 'Authorization': f"Api-Key {self.api_key}", - 'Content-Type': 'application/json;charset=utf-8', - 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" - } - body = None - self.logger.debug('[Experiment] Get flag configs') - try: - response = conn.request('GET', '/sdk/v2/flags?v=0', body, headers) - response_body = response.read().decode("utf8") - if response.status != 200: - raise Exception( - f"[Experiment] Get flagConfigs - received error response: ${response.status}: ${response_body}") - flags = json.loads(response_body) - flags_dict = {flag['key']: flag for flag in flags} - self.logger.debug(f"[Experiment] Got flag configs: {flags}") - self.lock.acquire() - self.flags = flags_dict - self.lock.release() - finally: - self._connection_pool.release(conn) - def __setup_connection_pool(self): scheme, _, host = self.config.server_url.split('/', 3) timeout = self.config.flag_config_poller_request_timeout_millis / 1000 @@ -149,7 +152,7 @@ def stop(self) -> None: """ Stop polling for flag configurations. Close resource like connection pool with client """ - self.poller.stop() + self.deployment_runner.stop() self._connection_pool.close() def __enter__(self) -> 'LocalEvaluationClient': @@ -167,5 +170,40 @@ def is_default_variant(variant: Variant) -> bool: return {key: variant for key, variant in variants.items() if not is_default_variant(variant)} - - + def _required_cohorts_in_storage(self, flag_configs: List) -> None: + stored_cohort_ids = self.cohort_storage.get_cohort_ids() + for flag in flag_configs: + flag_cohort_ids = get_all_cohort_ids_from_flag(flag) + missing_cohorts = flag_cohort_ids - stored_cohort_ids + if missing_cohorts: + message = ( + f"Evaluating flag {flag['key']} dependent on cohorts {flag_cohort_ids} " + f"without {missing_cohorts} in storage" + if self.config.cohort_sync_config + else f"Evaluating flag {flag['key']} dependent on cohorts {flag_cohort_ids} without " + f"cohort syncing configured" + ) + self.logger.warning(message) + + def _enrich_user_with_cohorts(self, user: User, flag_configs: Dict) -> User: + grouped_cohort_ids = get_grouped_cohort_ids_from_flags(list(flag_configs.values())) + + if USER_GROUP_TYPE in grouped_cohort_ids: + user_cohort_ids = grouped_cohort_ids[USER_GROUP_TYPE] + if user_cohort_ids and user.user_id: + user.cohort_ids = list(self.cohort_storage.get_cohorts_for_user(user.user_id, user_cohort_ids)) + + if user.groups: + for group_type, group_names in user.groups.items(): + group_name = group_names[0] if group_names else None + if not group_name: + continue + cohort_ids = grouped_cohort_ids.get(group_type, []) + if not cohort_ids: + continue + user.add_group_cohort_ids( + group_type, + group_name, + list(self.cohort_storage.get_cohorts_for_group(group_type, group_name, cohort_ids)) + ) + return user diff --git a/src/amplitude_experiment/local/config.py b/src/amplitude_experiment/local/config.py index 027467d..c729e36 100644 --- a/src/amplitude_experiment/local/config.py +++ b/src/amplitude_experiment/local/config.py @@ -1,32 +1,53 @@ +from enum import Enum + from ..assignment import AssignmentConfig +from ..cohort.cohort_sync_config import CohortSyncConfig, DEFAULT_COHORT_SYNC_URL, EU_COHORT_SYNC_URL + +DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' +EU_SERVER_URL = 'https://flag.lab.eu.amplitude.com' + + +class ServerZone(Enum): + US = "US" + EU = "EU" class LocalEvaluationConfig: """Experiment Local Client Configuration""" - DEFAULT_SERVER_URL = 'https://api.lab.amplitude.com' - def __init__(self, debug: bool = False, server_url: str = DEFAULT_SERVER_URL, + server_zone: ServerZone = ServerZone.US, flag_config_polling_interval_millis: int = 30000, flag_config_poller_request_timeout_millis: int = 10000, - assignment_config: AssignmentConfig = None): + assignment_config: AssignmentConfig = None, + cohort_sync_config: CohortSyncConfig = None): """ Initialize a config Parameters: - debug (bool): Set to true to log some extra information to the console. - server_url (str): The server endpoint from which to request variants. - flag_config_polling_interval_millis (int): The interval in milliseconds to poll the amplitude server for - flag config updates. These rules are stored in memory and used when calling evaluate() - to perform local evaluation. - flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, - used when fetching variants. + debug (bool): Set to true to log some extra information to the console. + server_url (str): The server endpoint from which to request flag configs. + server_zone (ServerZone): Location of the Amplitude data center to get flags and cohorts from, US or EU + flag_config_polling_interval_millis (int): The interval, in milliseconds, at which to poll for flag + configurations. + flag_config_poller_request_timeout_millis (int): The request timeout, in milliseconds, used when + fetching flag configurations. + assignment_config (AssignmentConfig): The assignment configuration. + cohort_sync_config (CohortSyncConfig): The cohort sync configuration. Returns: The config object """ self.debug = debug self.server_url = server_url + self.server_zone = server_zone + self.cohort_sync_config = cohort_sync_config + if server_url == DEFAULT_SERVER_URL and server_zone == ServerZone.EU: + self.server_url = EU_SERVER_URL + if (cohort_sync_config is not None and + cohort_sync_config.cohort_server_url == DEFAULT_COHORT_SYNC_URL): + self.cohort_sync_config.cohort_server_url = EU_COHORT_SYNC_URL + self.flag_config_polling_interval_millis = flag_config_polling_interval_millis self.flag_config_poller_request_timeout_millis = flag_config_poller_request_timeout_millis self.assignment_config = assignment_config diff --git a/src/amplitude_experiment/user.py b/src/amplitude_experiment/user.py index eed02c7..8cbdcea 100644 --- a/src/amplitude_experiment/user.py +++ b/src/amplitude_experiment/user.py @@ -1,6 +1,6 @@ import json -from typing import Dict, Any +from typing import Dict, Any, Set, List class User: @@ -27,8 +27,10 @@ def __init__( carrier: str = None, library: str = None, user_properties: Dict[str, Any] = None, - groups: Dict[str, str] = None, - group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None + groups: Dict[str, List[str]] = None, + group_properties: Dict[str, Dict[str, Dict[str, Any]]] = None, + group_cohort_ids: Dict[str, Dict[str, List[str]]] = None, + cohort_ids: List[str] = None ): """ Initialize User instance @@ -73,6 +75,8 @@ def __init__( self.user_properties = user_properties self.groups = groups self.group_properties = group_properties + self.group_cohort_ids = group_cohort_ids + self.cohort_ids = cohort_ids def to_json(self): """Return user information as JSON string.""" @@ -81,3 +85,17 @@ def to_json(self): def __str__(self): """Return user as string""" return self.to_json() + + def add_group_cohort_ids(self, group_type: str, group_name: str, cohort_ids: List[str]) -> None: + """ + Add cohort IDs for a group. + Parameters: + group_type (str): The type of the group + group_name (str): The name of the group + cohort_ids (Set[str]): Set of cohort IDs associated with the group + """ + if self.group_cohort_ids is None: + self.group_cohort_ids = {} + + group_names = self.group_cohort_ids.setdefault(group_type, {}) + group_names[group_name] = cohort_ids diff --git a/src/amplitude_experiment/util/flag_config.py b/src/amplitude_experiment/util/flag_config.py new file mode 100644 index 0000000..c64ac2c --- /dev/null +++ b/src/amplitude_experiment/util/flag_config.py @@ -0,0 +1,54 @@ +from typing import List, Dict, Set, Any + +from ..cohort.cohort import USER_GROUP_TYPE + + +def is_cohort_filter(condition: Dict[str, Any]) -> bool: + return ( + condition['op'] in {"set contains any", "set does not contain any"} + and condition['selector'] + and condition['selector'][-1] == "cohort_ids" + ) + + +def get_grouped_cohort_condition_ids(segment: Dict[str, Any]) -> Dict[str, Set[str]]: + cohort_ids = {} + conditions = segment.get('conditions', []) + for outer in conditions: + for condition in outer: + if is_cohort_filter(condition): + if len(condition['selector']) > 2: + context_subtype = condition['selector'][1] + if context_subtype == "user": + group_type = USER_GROUP_TYPE + elif "groups" in condition['selector']: + group_type = condition['selector'][2] + else: + continue + cohort_ids.setdefault(group_type, set()).update(condition['values']) + return cohort_ids + + +def get_grouped_cohort_ids_from_flag(flag: Dict[str, Any]) -> Dict[str, Set[str]]: + cohort_ids = {} + segments = flag.get('segments', []) + for segment in segments: + for key, values in get_grouped_cohort_condition_ids(segment).items(): + cohort_ids.setdefault(key, set()).update(values) + return cohort_ids + + +def get_all_cohort_ids_from_flag(flag: Dict[str, Any]) -> Set[str]: + return {cohort_id for values in get_grouped_cohort_ids_from_flag(flag).values() for cohort_id in values} + + +def get_grouped_cohort_ids_from_flags(flags: List[Dict[str, Any]]) -> Dict[str, Set[str]]: + cohort_ids = {} + for flag in flags: + for key, values in get_grouped_cohort_ids_from_flag(flag).items(): + cohort_ids.setdefault(key, set()).update(values) + return cohort_ids + + +def get_all_cohort_ids_from_flags(flags: List[Dict[str, Any]]) -> Set[str]: + return {cohort_id for values in get_grouped_cohort_ids_from_flags(flags).values() for cohort_id in values} diff --git a/src/amplitude_experiment/util/user.py b/src/amplitude_experiment/util/user.py index 93e3fdd..2b0b48f 100644 --- a/src/amplitude_experiment/util/user.py +++ b/src/amplitude_experiment/util/user.py @@ -6,26 +6,40 @@ def user_to_evaluation_context(user: User) -> Dict[str, Any]: user_groups = user.groups user_group_properties = user.group_properties + user_group_cohort_ids = user.group_cohort_ids user_dict = {key: value for key, value in user.__dict__.copy().items() if value is not None} user_dict.pop('groups', None) user_dict.pop('group_properties', None) + user_dict.pop('group_cohort_ids', None) context = {'user': user_dict} if len(user_dict) > 0 else {} + if user_groups is None: return context + groups: Dict[str, Dict[str, Any]] = {} - for group_type in user_groups: - group_name = user_groups[group_type] - if type(group_name) == list and len(group_name) > 0: - group_name = group_name[0] - groups[group_type] = {'group_name': group_name} - if user_group_properties is None: - continue - group_properties_type = user_group_properties[group_type] - if group_properties_type is None or type(group_properties_type) != dict: + for group_type, group_names in user_groups.items(): + if isinstance(group_names, list) and len(group_names) > 0: + group_name = group_names[0] + else: continue - group_properties_name = group_properties_type[group_name] - if group_properties_name is None or type(group_properties_name) != dict: - continue - groups[group_type]['group_properties'] = group_properties_name + + group_name_map = {'group_name': group_name} + + if user_group_properties: + group_properties_type = user_group_properties.get(group_type) + if group_properties_type: + group_properties_name = group_properties_type.get(group_name) + if group_properties_name: + group_name_map['group_properties'] = group_properties_name + + if user_group_cohort_ids: + group_cohort_ids_type = user_group_cohort_ids.get(group_type) + if group_cohort_ids_type: + group_cohort_ids_name = group_cohort_ids_type.get(group_name) + if group_cohort_ids_name: + group_name_map['cohort_ids'] = group_cohort_ids_name + + groups[group_type] = group_name_map + context['groups'] = groups return context diff --git a/tests/cohort/cohort_download_api_test.py b/tests/cohort/cohort_download_api_test.py new file mode 100644 index 0000000..2a42394 --- /dev/null +++ b/tests/cohort/cohort_download_api_test.py @@ -0,0 +1,104 @@ +import json +import logging +import unittest +from unittest import mock +from unittest.mock import MagicMock, patch +from src.amplitude_experiment.cohort.cohort import Cohort +from src.amplitude_experiment.cohort.cohort_download_api import DirectCohortDownloadApi +from src.amplitude_experiment.exception import CohortTooLargeException + + +def response(code: int, body: dict = None): + mock_response = MagicMock() + mock_response.status = code + if body is not None: + mock_response.read.return_value = json.dumps(body).encode() + return mock_response + + +class CohortDownloadApiTest(unittest.TestCase): + + def setUp(self): + self.api = DirectCohortDownloadApi('api', 'secret', 15000, "https://example.amplitude.com", mock.create_autospec(logging.Logger)) + + def test_cohort_download_success(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + cohort_info_response = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + + with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_cohort_download_many_202s_success(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + async_responses = [response(202)] * 9 + [response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']})] + + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_cohort_request_status_with_two_failures_succeeds(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + error_response = response(503) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + async_responses = [error_response, error_response, success_response] + + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_cohort_request_status_429s_keep_retrying(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'}) + error_response = response(429) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'User', 'memberIds': ['user']}) + async_responses = [error_response] * 9 + [success_response] + + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_group_cohort_download_success(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") + cohort_info_response = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") + + with patch.object(self.api, 'get_cohort', return_value=cohort_info_response): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_group_cohort_request_status_429s_keep_retrying(self): + cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'group'}, group_type="org name") + error_response = response(429) + success_response = response(200, {'cohortId': '1234', 'lastModified': 0, 'size': 1, 'groupType': 'org name', 'memberIds': ['group']}) + async_responses = [error_response] * 9 + [success_response] + + with patch.object(self.api, '_get_cohort_members_request', side_effect=async_responses): + + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(cohort, result_cohort) + + def test_cohort_size_too_large(self): + cohort = Cohort(id="1234", last_modified=0, size=16000, member_ids=set()) + too_large_response = response(413) + + with patch.object(self.api, '_get_cohort_members_request', return_value=too_large_response): + + with self.assertRaises(CohortTooLargeException): + self.api.get_cohort("1234", cohort) + + def test_cohort_not_modified(self): + cohort = Cohort(id="1234", last_modified=1000, size=1, member_ids=set()) + not_modified_response = response(204) + + with patch.object(self.api, '_get_cohort_members_request', return_value=not_modified_response): + result_cohort = self.api.get_cohort("1234", cohort) + self.assertEqual(None, result_cohort) + + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/cohort/cohort_loader_test.py b/tests/cohort/cohort_loader_test.py new file mode 100644 index 0000000..309a8ff --- /dev/null +++ b/tests/cohort/cohort_loader_test.py @@ -0,0 +1,72 @@ +import unittest +from unittest.mock import MagicMock + +from src.amplitude_experiment.cohort.cohort import Cohort +from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_storage import InMemoryCohortStorage + +class CohortLoaderTest(unittest.TestCase): + def setUp(self): + self.api = MagicMock() + self.storage = InMemoryCohortStorage() + self.loader = CohortLoader(self.api, self.storage) + + def test_load_success(self): + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), + Cohort(id="b", last_modified=0, size=2, member_ids={"1", "2"}) + ] + + future_a = self.loader.load_cohort("a") + future_b = self.loader.load_cohort("b") + + future_a.result() + future_b.result() + + storage_description_a = self.storage.get_cohort("a") + storage_description_b = self.storage.get_cohort("b") + self.assertEqual(Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), storage_description_a) + self.assertEqual(Cohort(id="b", last_modified=0, size=2, member_ids={"1", "2"}), storage_description_b) + + storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) + storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) + self.assertEqual({"a", "b"}, storage_user1_cohorts) + self.assertEqual({"b"}, storage_user2_cohorts) + + def test_filter_cohorts_already_computed_equivalent_cohorts_are_filtered(self): + self.storage.put_cohort(Cohort("a", last_modified=0, size=0, member_ids=set())) + self.storage.put_cohort(Cohort("b", last_modified=0, size=0, member_ids=set())) + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_modified=0, size=0, member_ids=set()), + Cohort(id="b", last_modified=1, size=2, member_ids={"1", "2"}) + ] + + self.loader.load_cohort("a").result() + self.loader.load_cohort("b").result() + + storage_description_a = self.storage.get_cohort("a") + storage_description_b = self.storage.get_cohort("b") + self.assertEqual(Cohort(id="a", last_modified=0, size=0, member_ids=set()), storage_description_a) + self.assertEqual(Cohort(id="b", last_modified=1, size=2, member_ids={"1", "2"}), storage_description_b) + + storage_user1_cohorts = self.storage.get_cohorts_for_user("1", {"a", "b"}) + storage_user2_cohorts = self.storage.get_cohorts_for_user("2", {"a", "b"}) + self.assertEqual({"b"}, storage_user1_cohorts) + self.assertEqual({"b"}, storage_user2_cohorts) + + def test_load_download_failure_throws(self): + self.api.get_cohort.side_effect = [ + Cohort(id="a", last_modified=0, size=1, member_ids={"1"}), + Exception("Connection timed out"), + Cohort(id="c", last_modified=0, size=1, member_ids={"1"}) + ] + + self.loader.load_cohort("a").result() + with self.assertRaises(Exception): + self.loader.load_cohort("b").result() + self.loader.load_cohort("c").result() + + self.assertEqual({"a", "c"}, self.storage.get_cohorts_for_user("1", {"a", "b", "c"})) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/deployment/deployment_runner_test.py b/tests/deployment/deployment_runner_test.py new file mode 100644 index 0000000..f64a332 --- /dev/null +++ b/tests/deployment/deployment_runner_test.py @@ -0,0 +1,85 @@ +import unittest +from unittest import mock +from unittest.mock import patch +import logging + +from src.amplitude_experiment import LocalEvaluationConfig +from src.amplitude_experiment.cohort.cohort_loader import CohortLoader +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig +from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi +from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner + +COHORT_ID = '1234' + + +class DeploymentRunnerTest(unittest.TestCase): + + def setUp(self): + self.flag = { + "key": "flag", + "variants": {}, + "segments": [ + { + "conditions": [ + [ + { + "selector": ["context", "user", "cohort_ids"], + "op": "set contains any", + "values": [COHORT_ID], + } + ] + ], + } + ] + } + + def test_start_throws_if_first_flag_config_load_fails(self): + flag_api = mock.create_autospec(FlagConfigApi) + cohort_download_api = mock.Mock() + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() + logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) + runner = DeploymentRunner( + LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), + flag_api, + flag_config_storage, + cohort_storage, + logger, + cohort_loader, + ) + flag_api.get_flag_configs.side_effect = RuntimeError("test") + with self.assertRaises(RuntimeError): + runner.start() + + def test_start_does_not_throw_if_cohort_load_fails(self): + flag_api = mock.create_autospec(FlagConfigApi) + cohort_download_api = mock.Mock() + flag_config_storage = mock.Mock() + cohort_storage = mock.Mock() + cohort_storage.get_cohort_ids.return_value = set() + logger = mock.create_autospec(logging.Logger) + cohort_loader = CohortLoader(cohort_download_api, cohort_storage) + runner = DeploymentRunner( + LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')), + flag_api, flag_config_storage, + cohort_storage, + logger, + cohort_loader, + ) + + # Mock methods as needed + with patch.object(runner, '_delete_unused_cohorts'): + flag_api.get_flag_configs.return_value = [self.flag] + cohort_download_api.get_cohort.side_effect = RuntimeError("test") + + # Simply call the method and let the test pass if no exception is raised + try: + runner.start() + except Exception as e: + self.fail(f"runner.start() raised an exception unexpectedly: {e}") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/local/client_eu_test.py b/tests/local/client_eu_test.py new file mode 100644 index 0000000..91e792d --- /dev/null +++ b/tests/local/client_eu_test.py @@ -0,0 +1,44 @@ +import unittest +from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig +from src.amplitude_experiment.local.config import ServerZone +from dotenv import load_dotenv +import os + +SERVER_API_KEY = 'server-Qlp7XiSu6JtP2S3JzA95PnP27duZgQCF' + + +class LocalEvaluationClientTestCase(unittest.TestCase): + _local_evaluation_client: LocalEvaluationClient = None + + @classmethod + def setUpClass(cls) -> None: + load_dotenv() + api_key = os.environ['EU_API_KEY'] + secret_key = os.environ['EU_SECRET_KEY'] + cohort_sync_config = CohortSyncConfig(api_key=api_key, + secret_key=secret_key) + cls._local_evaluation_client = ( + LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, server_zone=ServerZone.EU, + cohort_sync_config=cohort_sync_config))) + cls._local_evaluation_client.start() + + @classmethod + def tearDownClass(cls) -> None: + cls._local_evaluation_client.stop() + + def test_evaluate_with_cohort_eu(self): + targeted_user = User(user_id='1', device_id='0') + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user) + .get('sdk-local-evaluation-user-cohort')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='not_targeted') + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-user-cohort')) + expected_off_variant = Variant(key='off') + self.assertEqual(non_targeted_variant, expected_off_variant) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/local/client_test.py b/tests/local/client_test.py index 00293db..b6c50eb 100644 --- a/tests/local/client_test.py +++ b/tests/local/client_test.py @@ -1,7 +1,14 @@ +import re import unittest +from unittest import mock + from src.amplitude_experiment import LocalEvaluationClient, LocalEvaluationConfig, User, Variant +from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig +from dotenv import load_dotenv +import os + -API_KEY = 'server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz' +SERVER_API_KEY = 'server-qz35UwzJ5akieoAdIgzM4m9MIiOLXLoz' test_user = User(user_id='test_user') test_user_2 = User(user_id='user_id', device_id='device_id') @@ -11,7 +18,14 @@ class LocalEvaluationClientTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls._local_evaluation_client = LocalEvaluationClient(API_KEY, LocalEvaluationConfig(debug=False)) + load_dotenv() + api_key = os.environ['API_KEY'] + secret_key = os.environ['SECRET_KEY'] + cohort_sync_config = CohortSyncConfig(api_key=api_key, + secret_key=secret_key) + cls._local_evaluation_client = ( + LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, + cohort_sync_config=cohort_sync_config))) cls._local_evaluation_client.start() @classmethod @@ -56,6 +70,45 @@ def test_evaluate_with_dependencies_variant_holdout(self): expected_variant = None self.assertEqual(expected_variant, variants.get('sdk-local-evaluation-ci-test-holdout')) + def test_evaluate_with_cohort(self): + targeted_user = User(user_id='12345', device_id='device_id') + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user, + {'sdk-local-evaluation-user-cohort-ci-test'}) + .get('sdk-local-evaluation-user-cohort-ci-test')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='not_targeted') + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-user-cohort-ci-test')) + expected_off_variant = Variant(key='off') + self.assertEqual(expected_off_variant, non_targeted_variant) + + def test_evaluate_with_group_cohort(self): + targeted_user = User(user_id='12345', device_id='device_id', groups={'org id': ['1']}) + targeted_variant = (self._local_evaluation_client.evaluate_v2(targeted_user, + {'sdk-local-evaluation-group-cohort-ci-test'}) + .get('sdk-local-evaluation-group-cohort-ci-test')) + expected_on_variant = Variant(key='on', value='on') + self.assertEqual(expected_on_variant, targeted_variant) + non_targeted_user = User(user_id='12345', device_id='device_id', groups={'org id': ['not_targeted']}) + non_targeted_variant = (self._local_evaluation_client.evaluate_v2(non_targeted_user) + .get('sdk-local-evaluation-group-cohort-ci-test')) + expected_off_variant = Variant(key='off') + self.assertEqual(expected_off_variant, non_targeted_variant) + + def test_evaluation_cohorts_not_in_storage_with_sync_config(self): + with mock.patch.object(self._local_evaluation_client.cohort_storage, 'put_cohort', return_value=None): + self._local_evaluation_client.cohort_storage.get_cohort_ids = mock.Mock(return_value=set()) + targeted_user = User(user_id='12345') + + with self.assertLogs(self._local_evaluation_client.logger, level='WARNING') as log: + self._local_evaluation_client.evaluate_v2(targeted_user, {'sdk-local-evaluation-user-cohort-ci-test'}) + log_message = ( + "WARNING:Amplitude:Evaluating flag sdk-local-evaluation-user-cohort-ci-test dependent on cohorts " + "{.*} without {.*} in storage" + ) + self.assertTrue(any(re.match(log_message, message) for message in log.output)) + if __name__ == '__main__': unittest.main() diff --git a/tests/util/flag_config_test.py b/tests/util/flag_config_test.py new file mode 100644 index 0000000..0b16ab9 --- /dev/null +++ b/tests/util/flag_config_test.py @@ -0,0 +1,163 @@ +import unittest + +from src.amplitude_experiment.util.flag_config import ( + get_all_cohort_ids_from_flags, + get_grouped_cohort_ids_from_flags, + get_all_cohort_ids_from_flag, + get_grouped_cohort_ids_from_flag, +) + + +class CohortUtilsTestCase(unittest.TestCase): + + def setUp(self): + self.flags = [ + { + 'key': 'flag-1', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 1 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'user', 'cohort_ids'], + 'values': ['cohort1', 'cohort2'] + } + ] + ], + 'metadata': {'segmentName': 'Segment A'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Users'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } + }, + { + 'key': 'flag-2', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 2 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'user', 'cohort_ids'], + 'values': ['cohort3', 'cohort4', 'cohort5', 'cohort6'] + } + ] + ], + 'metadata': {'segmentName': 'Segment B'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Users'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } + }, + { + 'key': 'flag-3', + 'metadata': { + 'deployed': True, + 'evaluationMode': 'local', + 'flagType': 'release', + 'flagVersion': 3 + }, + 'segments': [ + { + 'conditions': [ + [ + { + 'op': 'set contains any', + 'selector': ['context', 'groups', 'group_name', 'cohort_ids'], + 'values': ['cohort7', 'cohort8'] + } + ] + ], + 'metadata': {'segmentName': 'Segment C'}, + 'variant': 'on' + }, + { + 'metadata': {'segmentName': 'All Other Groups'}, + 'variant': 'off' + } + ], + 'variants': { + 'off': { + 'key': 'off', + 'metadata': {'default': True} + }, + 'on': { + 'key': 'on', + 'value': 'on' + } + } + } + ] + + def test_get_all_cohort_ids(self): + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6', 'cohort7', 'cohort8'} + for flag in self.flags: + cohort_ids = get_all_cohort_ids_from_flag(flag) + self.assertTrue(cohort_ids.issubset(expected_cohort_ids)) + + def test_get_grouped_cohort_ids_for_flag(self): + expected_grouped_cohort_ids = { + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'}, + 'group_name': {'cohort7', 'cohort8'} + } + for flag in self.flags: + grouped_cohort_ids = get_grouped_cohort_ids_from_flag(flag) + for key, values in grouped_cohort_ids.items(): + self.assertTrue(key in expected_grouped_cohort_ids) + self.assertTrue(values.issubset(expected_grouped_cohort_ids[key])) + + def test_get_all_cohort_ids_from_flags(self): + expected_cohort_ids = {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6', 'cohort7', 'cohort8'} + cohort_ids = get_all_cohort_ids_from_flags(self.flags) + self.assertEqual(cohort_ids, expected_cohort_ids) + + def test_get_grouped_cohort_ids_for_flag_from_flags(self): + expected_grouped_cohort_ids = { + 'User': {'cohort1', 'cohort2', 'cohort3', 'cohort4', 'cohort5', 'cohort6'}, + 'group_name': {'cohort7', 'cohort8'} + } + grouped_cohort_ids = get_grouped_cohort_ids_from_flags(self.flags) + self.assertEqual(grouped_cohort_ids, expected_grouped_cohort_ids) + + +if __name__ == '__main__': + unittest.main()