Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support cohort targeting for local evaluation #47

Merged
merged 44 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
def9924
feat: Enable use of project API key for default deployments
tyiuhc Feb 1, 2024
405dcb6
initial commit
tyiuhc Jun 6, 2024
eac1cd3
update local eval client
tyiuhc Jun 6, 2024
7f864d5
fix imports
tyiuhc Jun 6, 2024
3ef92da
refactor
tyiuhc Jun 10, 2024
473fe7b
fix tests, add logging config
tyiuhc Jun 10, 2024
3d5dfe9
add CohortNotModifiedException
tyiuhc Jun 12, 2024
93eb012
update user transformation to evaluation context
tyiuhc Jun 13, 2024
981de9f
refactor and simplify to not use cohort_description
tyiuhc Jun 14, 2024
080e8f0
nit: fix formatting
tyiuhc Jun 14, 2024
23c8b25
handle flag fetch fail
tyiuhc Jun 14, 2024
f249b44
Use lastModified instead of lastComputed
tyiuhc Jun 17, 2024
04d35f6
add cohort_request_delay_millis to config
tyiuhc Jun 17, 2024
b4402eb
fix DirectCohortDownloadApi constructor
tyiuhc Jun 17, 2024
a7ffbb6
Simplify deployment_runner, clean up comments
tyiuhc Jun 25, 2024
f88d56e
revert default deployment changes
tyiuhc Jun 25, 2024
f98f09f
Update cohort sync config with comments and server_url config
tyiuhc Jun 26, 2024
7243d9b
fix EU flag url
tyiuhc Jun 26, 2024
4fe1fac
export CohortSyncConfig and ServerZone
tyiuhc Jun 26, 2024
a51067b
nit: simplify logic
tyiuhc Jun 27, 2024
f0e899b
Handle 204 errors
tyiuhc Jun 28, 2024
5130cdd
update deployment_runner flag/cohort update logic, update tests, fix …
tyiuhc Jul 2, 2024
5ed0a98
Update logger requirement for classes
tyiuhc Jul 3, 2024
916e5c1
Refactor cohort_loader update_storage_cohorts
tyiuhc Jul 3, 2024
013ffc9
fix lint
tyiuhc Jul 3, 2024
57e1cc2
remove unnecessary import
tyiuhc Jul 22, 2024
93d2f15
update test.yml
tyiuhc Jul 23, 2024
5b30cb7
add client cohort ci tests
tyiuhc Jul 24, 2024
12267a0
update requirements-dev dotenv version
tyiuhc Jul 24, 2024
03b9081
debug env vars
tyiuhc Jul 24, 2024
832a00c
test yml set env vars
tyiuhc Jul 24, 2024
135a286
test cases use os.environ for secrets
tyiuhc Jul 24, 2024
e1ff4a2
test-arm.yml env syntax fix
tyiuhc Jul 25, 2024
9d6d62f
update client tests
tyiuhc Jul 25, 2024
7fddc96
cohort not modified should not throw exception
tyiuhc Jul 30, 2024
8cfb128
nit: update test name
tyiuhc Jul 31, 2024
c71e0c7
do not throw exception upon start() if cohort download fails, log war…
tyiuhc Aug 1, 2024
85c6cf3
fix deployment runner logging
tyiuhc Aug 1, 2024
646dd5a
nit: fix test name
tyiuhc Aug 1, 2024
9864e46
update error log and test
tyiuhc Aug 1, 2024
7043732
update_stored_cohorts using load_cohort
tyiuhc Aug 5, 2024
1d974f1
refresh cohorts based on flag configs in storage
tyiuhc Aug 6, 2024
06e693e
update cohort_sync_config fields: include polling and remove request …
tyiuhc Aug 6, 2024
a9006cf
add SDK+version to cohort request header
tyiuhc Aug 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions .github/workflows/test-arm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@ on: [pull_request]
jobs:
aarch_job:
runs-on: ubuntu-latest
name: Test on ubuntu aarch64
environment: Unit Test
name: Test on Ubuntu aarch64
steps:
- uses: actions/checkout@v3
- uses: uraimo/run-on-arch-action@v2
name: Run Unit Test
- name: Checkout source code
uses: actions/checkout@v3

- name: Set up and run unit test on aarch64
uses: uraimo/run-on-arch-action@v2
id: runcmd
with:
env: |
API_KEY: ${{ secrets.API_KEY }}
SECRET_KEY: ${{ secrets.SECRET_KEY }}
EU_API_KEY: ${{ secrets.EU_API_KEY }}
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
arch: aarch64
distro: ubuntu20.04
githubToken: ${{ github.token }}
Expand Down
9 changes: 7 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on: [pull_request]
jobs:
test:
runs-on: ubuntu-latest
environment: Unit Test
strategy:
matrix:
python-version: [ "3.7" ]
Expand All @@ -19,8 +20,12 @@ jobs:
cache: 'pip'

- name: Install requirements
run: pip install -r requirements.txt
pip install -r requirements-dev.txt
run: pip install -r requirements.txt && pip install -r requirements-dev.txt

- name: Unit Test
env:
API_KEY: ${{ secrets.API_KEY }}
SECRET_KEY: ${{ secrets.SECRET_KEY }}
EU_API_KEY: ${{ secrets.EU_API_KEY }}
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
run: python -m unittest discover -s ./tests -p '*_test.py'
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ user = User(
variants = experiment.evaluate(user)
```

# Running unit tests suite
To setup for running test on local, create a `.env` file with following
contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for EU data center) for the project in test:

```
API_KEY={API_KEY}
SECRET_KEY={SECRET_KEY}
```

## More Information
Please visit our :100:[Developer Center](https://www.docs.developers.amplitude.com/experiment/sdks/python-sdk/) for more instructions on using our the SDK.

Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
parameterized~=0.9.0
python-dotenv~=0.21.1
2 changes: 2 additions & 0 deletions src/amplitude_experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
from .cookie import AmplitudeCookie
from .local.client import LocalEvaluationClient
from .local.config import LocalEvaluationConfig
from .local.config import ServerZone
from .assignment import AssignmentConfig
from .cohort.cohort_sync_config import CohortSyncConfig
13 changes: 13 additions & 0 deletions src/amplitude_experiment/cohort/cohort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass, field
from typing import ClassVar, Set

USER_GROUP_TYPE: ClassVar[str] = "User"


@dataclass
class Cohort:
id: str
last_modified: int
size: int
member_ids: Set[str]
group_type: str = field(default=USER_GROUP_TYPE)
91 changes: 91 additions & 0 deletions src/amplitude_experiment/cohort/cohort_download_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import time
import logging
import base64
import json
from http.client import HTTPResponse
from typing import Optional
from ..version import __version__

from .cohort import Cohort
from ..connection_pool import HTTPConnectionPool
from ..exception import HTTPErrorResponseException, CohortTooLargeException

COHORT_REQUEST_RETRY_DELAY_MILLIS = 100


class CohortDownloadApi:

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
raise NotImplementedError


class DirectCohortDownloadApi(CohortDownloadApi):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
super().__init__()
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.server_url = server_url
self.logger = logger
self.__setup_connection_pool()

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None:
self.logger.debug(f"getCohortMembers({cohort_id}): start")
errors = 0
while True:
response = None
try:
last_modified = None if cohort is None else cohort.last_modified
response = self._get_cohort_members_request(cohort_id, last_modified)
self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}")
if response.status == 200:
cohort_info = json.loads(response.read().decode("utf8"))
self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
return Cohort(
id=cohort_info['cohortId'],
last_modified=cohort_info['lastModified'],
size=cohort_info['size'],
member_ids=set(cohort_info['memberIds']),
group_type=cohort_info['groupType'],
)
elif response.status == 204:
self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
return
elif response.status == 413:
raise CohortTooLargeException(
f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
elif response.status != 202:
raise HTTPErrorResponseException(response.status,
f"Unexpected response code: {response.status}")
except Exception as e:
if response and not (isinstance(e, HTTPErrorResponseException) and response.status == 429):
errors += 1
self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
if errors >= 3 or isinstance(e, CohortTooLargeException):
raise e
time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse:
headers = {
'Authorization': f'Basic {self._get_basic_auth()}',
'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
}
conn = self._connection_pool.acquire()
try:
url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
if last_modified is not None:
url += f'&lastModified={last_modified}'
response = conn.request('GET', url, headers=headers)
return response
finally:
self._connection_pool.release(conn)

def _get_basic_auth(self) -> str:
credentials = f'{self.api_key}:{self.secret_key}'
return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

def __setup_connection_pool(self):
scheme, _, host = self.server_url.split('/', 3)
timeout = 10
self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
scheme=scheme)
67 changes: 67 additions & 0 deletions src/amplitude_experiment/cohort/cohort_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
from typing import Dict, Set
from concurrent.futures import ThreadPoolExecutor, Future, as_completed
import threading

from .cohort import Cohort
from .cohort_download_api import CohortDownloadApi
from .cohort_storage import CohortStorage
from ..exception import CohortsDownloadException


class CohortLoader:
def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage):
self.cohort_download_api = cohort_download_api
self.cohort_storage = cohort_storage
self.jobs: Dict[str, Future] = {}
self.lock_jobs = threading.Lock()
self.executor = ThreadPoolExecutor(
max_workers=32,
thread_name_prefix='CohortLoaderExecutor'
)

def load_cohort(self, cohort_id: str) -> Future:
with self.lock_jobs:
if cohort_id not in self.jobs:
future = self.executor.submit(self.__load_cohort_internal, cohort_id)
future.add_done_callback(lambda f: self._remove_job(cohort_id))
self.jobs[cohort_id] = future
return self.jobs[cohort_id]

def _remove_job(self, cohort_id: str):
if cohort_id in self.jobs:
with self.lock_jobs:
self.jobs.pop(cohort_id, None)

def download_cohort(self, cohort_id: str) -> Cohort:
cohort = self.cohort_storage.get_cohort(cohort_id)
return self.cohort_download_api.get_cohort(cohort_id, cohort)

def download_cohorts(self, cohort_ids: Set[str]) -> Future:
def update_task(task_cohort_ids):
errors = []
futures = []
for cohort_id in task_cohort_ids:
future = self.load_cohort(cohort_id)
futures.append(future)

for future in as_completed(futures):
try:
future.result()
except Exception as e:
cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None)
if cohort_id:
errors.append((cohort_id, e))

if errors:
raise CohortsDownloadException(errors)

return self.executor.submit(update_task, cohort_ids)

def __load_cohort_internal(self, cohort_id):
try:
cohort = self.download_cohort(cohort_id)
if cohort is not None:
self.cohort_storage.put_cohort(cohort)
except Exception as e:
raise e
73 changes: 73 additions & 0 deletions src/amplitude_experiment/cohort/cohort_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Set, Optional
from threading import RLock

from .cohort import Cohort, USER_GROUP_TYPE


class CohortStorage:
def get_cohort(self, cohort_id: str):
raise NotImplementedError

def get_cohorts(self):
raise NotImplementedError

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def put_cohort(self, cohort_description: Cohort):
raise NotImplementedError

def delete_cohort(self, group_type: str, cohort_id: str):
raise NotImplementedError

def get_cohort_ids(self) -> Set[str]:
raise NotImplementedError


class InMemoryCohortStorage(CohortStorage):
def __init__(self):
self.lock = RLock()
self.group_to_cohort_store: Dict[str, Set[str]] = {}
self.cohort_store: Dict[str, Cohort] = {}

def get_cohort(self, cohort_id: str):
with self.lock:
return self.cohort_store.get(cohort_id)

def get_cohorts(self):
return self.cohort_store.copy()

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids)

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
result = set()
with self.lock:
group_type_cohorts = self.group_to_cohort_store.get(group_type, {})
for cohort_id in group_type_cohorts:
members = self.cohort_store.get(cohort_id).member_ids
if cohort_id in cohort_ids and group_name in members:
result.add(cohort_id)
return result

def put_cohort(self, cohort: Cohort):
with self.lock:
if cohort.group_type not in self.group_to_cohort_store:
self.group_to_cohort_store[cohort.group_type] = set()
self.group_to_cohort_store[cohort.group_type].add(cohort.id)
self.cohort_store[cohort.id] = cohort

def delete_cohort(self, group_type: str, cohort_id: str):
with self.lock:
group_cohorts = self.group_to_cohort_store.get(group_type, {})
if cohort_id in group_cohorts:
group_cohorts.remove(cohort_id)
if cohort_id in self.cohort_store:
del self.cohort_store[cohort_id]

def get_cohort_ids(self):
with self.lock:
return set(self.cohort_store.keys())
24 changes: 24 additions & 0 deletions src/amplitude_experiment/cohort/cohort_sync_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'


class CohortSyncConfig:
"""Experiment Cohort Sync Configuration
This configuration is used to set up the cohort loader. The cohort loader is responsible for
downloading cohorts from the server and storing them locally.
Parameters:
api_key (str): The project API Key
secret_key (str): The project Secret Key
max_cohort_size (int): The maximum cohort size that can be downloaded
cohort_polling_interval_millis (int): The interval, in milliseconds, at which to poll for
cohort updates, minimum 60000
cohort_server_url (str): The server endpoint from which to request cohorts
"""

def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647,
cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_polling_interval_millis = max(cohort_polling_interval_millis, 60000)
self.cohort_server_url = cohort_server_url
Loading
Loading