Skip to content

Commit

Permalink
feat: support cohort targeting for local evaluation (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
tyiuhc authored Aug 27, 2024
1 parent 4503a60 commit d8c62c4
Show file tree
Hide file tree
Showing 25 changed files with 1,255 additions and 70 deletions.
16 changes: 12 additions & 4 deletions .github/workflows/test-arm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@ on: [pull_request]
jobs:
aarch_job:
runs-on: ubuntu-latest
name: Test on ubuntu aarch64
environment: Unit Test
name: Test on Ubuntu aarch64
steps:
- uses: actions/checkout@v3
- uses: uraimo/run-on-arch-action@v2
name: Run Unit Test
- name: Checkout source code
uses: actions/checkout@v3

- name: Set up and run unit test on aarch64
uses: uraimo/run-on-arch-action@v2
id: runcmd
with:
env: |
API_KEY: ${{ secrets.API_KEY }}
SECRET_KEY: ${{ secrets.SECRET_KEY }}
EU_API_KEY: ${{ secrets.EU_API_KEY }}
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
arch: aarch64
distro: ubuntu20.04
githubToken: ${{ github.token }}
Expand Down
9 changes: 7 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on: [pull_request]
jobs:
test:
runs-on: ubuntu-latest
environment: Unit Test
strategy:
matrix:
python-version: [ "3.7" ]
Expand All @@ -19,8 +20,12 @@ jobs:
cache: 'pip'

- name: Install requirements
run: pip install -r requirements.txt
pip install -r requirements-dev.txt
run: pip install -r requirements.txt && pip install -r requirements-dev.txt

- name: Unit Test
env:
API_KEY: ${{ secrets.API_KEY }}
SECRET_KEY: ${{ secrets.SECRET_KEY }}
EU_API_KEY: ${{ secrets.EU_API_KEY }}
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
run: python -m unittest discover -s ./tests -p '*_test.py'
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ user = User(
variants = experiment.evaluate(user)
```

# Running unit tests suite
To setup for running test on local, create a `.env` file with following
contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test:

```
API_KEY={API_KEY}
SECRET_KEY={SECRET_KEY}
```

## More Information
Please visit our :100:[Developer Center](https://www.docs.developers.amplitude.com/experiment/sdks/python-sdk/) for more instructions on using our the SDK.

Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
parameterized~=0.9.0
python-dotenv~=0.21.1
2 changes: 2 additions & 0 deletions src/amplitude_experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
from .cookie import AmplitudeCookie
from .local.client import LocalEvaluationClient
from .local.config import LocalEvaluationConfig
from .local.config import ServerZone
from .assignment import AssignmentConfig
from .cohort.cohort_sync_config import CohortSyncConfig
13 changes: 13 additions & 0 deletions src/amplitude_experiment/cohort/cohort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass, field
from typing import ClassVar, Set

USER_GROUP_TYPE: ClassVar[str] = "User"


@dataclass
class Cohort:
id: str
last_modified: int
size: int
member_ids: Set[str]
group_type: str = field(default=USER_GROUP_TYPE)
91 changes: 91 additions & 0 deletions src/amplitude_experiment/cohort/cohort_download_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import time
import logging
import base64
import json
from http.client import HTTPResponse
from typing import Optional
from ..version import __version__

from .cohort import Cohort
from ..connection_pool import HTTPConnectionPool
from ..exception import HTTPErrorResponseException, CohortTooLargeException

COHORT_REQUEST_RETRY_DELAY_MILLIS = 100


class CohortDownloadApi:

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
raise NotImplementedError


class DirectCohortDownloadApi(CohortDownloadApi):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
super().__init__()
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.server_url = server_url
self.logger = logger
self.__setup_connection_pool()

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None:
self.logger.debug(f"getCohortMembers({cohort_id}): start")
errors = 0
while True:
response = None
try:
last_modified = None if cohort is None else cohort.last_modified
response = self._get_cohort_members_request(cohort_id, last_modified)
self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}")
if response.status == 200:
cohort_info = json.loads(response.read().decode("utf8"))
self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
return Cohort(
id=cohort_info['cohortId'],
last_modified=cohort_info['lastModified'],
size=cohort_info['size'],
member_ids=set(cohort_info['memberIds']),
group_type=cohort_info['groupType'],
)
elif response.status == 204:
self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
return
elif response.status == 413:
raise CohortTooLargeException(
f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
elif response.status != 202:
raise HTTPErrorResponseException(response.status,
f"Unexpected response code: {response.status}")
except Exception as e:
if response and not (isinstance(e, HTTPErrorResponseException) and response.status == 429):
errors += 1
self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
if errors >= 3 or isinstance(e, CohortTooLargeException):
raise e
time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse:
headers = {
'Authorization': f'Basic {self._get_basic_auth()}',
'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
}
conn = self._connection_pool.acquire()
try:
url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
if last_modified is not None:
url += f'&lastModified={last_modified}'
response = conn.request('GET', url, headers=headers)
return response
finally:
self._connection_pool.release(conn)

def _get_basic_auth(self) -> str:
credentials = f'{self.api_key}:{self.secret_key}'
return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

def __setup_connection_pool(self):
scheme, _, host = self.server_url.split('/', 3)
timeout = 10
self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
scheme=scheme)
67 changes: 67 additions & 0 deletions src/amplitude_experiment/cohort/cohort_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
from typing import Dict, Set
from concurrent.futures import ThreadPoolExecutor, Future, as_completed
import threading

from .cohort import Cohort
from .cohort_download_api import CohortDownloadApi
from .cohort_storage import CohortStorage
from ..exception import CohortsDownloadException


class CohortLoader:
def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage):
self.cohort_download_api = cohort_download_api
self.cohort_storage = cohort_storage
self.jobs: Dict[str, Future] = {}
self.lock_jobs = threading.Lock()
self.executor = ThreadPoolExecutor(
max_workers=32,
thread_name_prefix='CohortLoaderExecutor'
)

def load_cohort(self, cohort_id: str) -> Future:
with self.lock_jobs:
if cohort_id not in self.jobs:
future = self.executor.submit(self.__load_cohort_internal, cohort_id)
future.add_done_callback(lambda f: self._remove_job(cohort_id))
self.jobs[cohort_id] = future
return self.jobs[cohort_id]

def _remove_job(self, cohort_id: str):
if cohort_id in self.jobs:
with self.lock_jobs:
self.jobs.pop(cohort_id, None)

def download_cohort(self, cohort_id: str) -> Cohort:
cohort = self.cohort_storage.get_cohort(cohort_id)
return self.cohort_download_api.get_cohort(cohort_id, cohort)

def download_cohorts(self, cohort_ids: Set[str]) -> Future:
def update_task(task_cohort_ids):
errors = []
futures = []
for cohort_id in task_cohort_ids:
future = self.load_cohort(cohort_id)
futures.append(future)

for future in as_completed(futures):
try:
future.result()
except Exception as e:
cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None)
if cohort_id:
errors.append((cohort_id, e))

if errors:
raise CohortsDownloadException(errors)

return self.executor.submit(update_task, cohort_ids)

def __load_cohort_internal(self, cohort_id):
try:
cohort = self.download_cohort(cohort_id)
if cohort is not None:
self.cohort_storage.put_cohort(cohort)
except Exception as e:
raise e
73 changes: 73 additions & 0 deletions src/amplitude_experiment/cohort/cohort_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Set, Optional
from threading import RLock

from .cohort import Cohort, USER_GROUP_TYPE


class CohortStorage:
def get_cohort(self, cohort_id: str):
raise NotImplementedError

def get_cohorts(self):
raise NotImplementedError

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def put_cohort(self, cohort_description: Cohort):
raise NotImplementedError

def delete_cohort(self, group_type: str, cohort_id: str):
raise NotImplementedError

def get_cohort_ids(self) -> Set[str]:
raise NotImplementedError


class InMemoryCohortStorage(CohortStorage):
def __init__(self):
self.lock = RLock()
self.group_to_cohort_store: Dict[str, Set[str]] = {}
self.cohort_store: Dict[str, Cohort] = {}

def get_cohort(self, cohort_id: str):
with self.lock:
return self.cohort_store.get(cohort_id)

def get_cohorts(self):
return self.cohort_store.copy()

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids)

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
result = set()
with self.lock:
group_type_cohorts = self.group_to_cohort_store.get(group_type, {})
for cohort_id in group_type_cohorts:
members = self.cohort_store.get(cohort_id).member_ids
if cohort_id in cohort_ids and group_name in members:
result.add(cohort_id)
return result

def put_cohort(self, cohort: Cohort):
with self.lock:
if cohort.group_type not in self.group_to_cohort_store:
self.group_to_cohort_store[cohort.group_type] = set()
self.group_to_cohort_store[cohort.group_type].add(cohort.id)
self.cohort_store[cohort.id] = cohort

def delete_cohort(self, group_type: str, cohort_id: str):
with self.lock:
group_cohorts = self.group_to_cohort_store.get(group_type, {})
if cohort_id in group_cohorts:
group_cohorts.remove(cohort_id)
if cohort_id in self.cohort_store:
del self.cohort_store[cohort_id]

def get_cohort_ids(self):
with self.lock:
return set(self.cohort_store.keys())
24 changes: 24 additions & 0 deletions src/amplitude_experiment/cohort/cohort_sync_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'


class CohortSyncConfig:
"""Experiment Cohort Sync Configuration
This configuration is used to set up the cohort loader. The cohort loader is responsible for
downloading cohorts from the server and storing them locally.
Parameters:
api_key (str): The project API Key
secret_key (str): The project Secret Key
max_cohort_size (int): The maximum cohort size that can be downloaded
cohort_polling_interval_millis (int): The interval, in milliseconds, at which to poll for
cohort updates, minimum 60000
cohort_server_url (str): The server endpoint from which to request cohorts
"""

def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647,
cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_polling_interval_millis = max(cohort_polling_interval_millis, 60000)
self.cohort_server_url = cohort_server_url
Loading

0 comments on commit d8c62c4

Please sign in to comment.