-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support cohort targeting for local evaluation (#68)
- Loading branch information
Showing
30 changed files
with
1,311 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module AmplitudeExperiment | ||
USER_GROUP_TYPE = 'User'.freeze | ||
# Cohort | ||
class Cohort | ||
attr_accessor :id, :last_modified, :size, :member_ids, :group_type | ||
|
||
def initialize(id, last_modified, size, member_ids, group_type = USER_GROUP_TYPE) | ||
@id = id | ||
@last_modified = last_modified | ||
@size = size | ||
@member_ids = member_ids.to_set | ||
@group_type = group_type | ||
end | ||
|
||
def ==(other) | ||
return false unless other.is_a?(Cohort) | ||
|
||
@id == other.id && | ||
@last_modified == other.last_modified && | ||
@size == other.size && | ||
@member_ids == other.member_ids && | ||
@group_type == other.group_type | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
require 'base64' | ||
require 'json' | ||
require 'net/http' | ||
require 'uri' | ||
require 'set' | ||
|
||
module AmplitudeExperiment | ||
# CohortDownloadApi | ||
class CohortDownloadApi | ||
COHORT_REQUEST_TIMEOUT_MILLIS = 5000 | ||
COHORT_REQUEST_RETRY_DELAY_MILLIS = 100 | ||
|
||
def get_cohort(cohort_id, cohort = nil) | ||
raise NotImplementedError | ||
end | ||
end | ||
|
||
# DirectCohortDownloadApi | ||
class DirectCohortDownloadApi < CohortDownloadApi | ||
def initialize(api_key, secret_key, max_cohort_size, server_url, logger) | ||
super() | ||
@api_key = api_key | ||
@secret_key = secret_key | ||
@max_cohort_size = max_cohort_size | ||
@server_url = server_url | ||
@logger = logger | ||
end | ||
|
||
def get_cohort(cohort_id, cohort = nil) | ||
@logger.debug("getCohortMembers(#{cohort_id}): start") | ||
errors = 0 | ||
|
||
loop do | ||
begin | ||
last_modified = cohort.nil? ? nil : cohort.last_modified | ||
response = get_cohort_members_request(cohort_id, last_modified) | ||
@logger.debug("getCohortMembers(#{cohort_id}): status=#{response.code}") | ||
|
||
case response.code.to_i | ||
when 200 | ||
cohort_info = JSON.parse(response.body) | ||
@logger.debug("getCohortMembers(#{cohort_id}): end - resultSize=#{cohort_info['size']}") | ||
return Cohort.new( | ||
cohort_info['cohortId'], | ||
cohort_info['lastModified'], | ||
cohort_info['size'], | ||
cohort_info['memberIds'].to_set, | ||
cohort_info['groupType'] | ||
) | ||
when 204 | ||
@logger.debug("getCohortMembers(#{cohort_id}): Cohort not modified") | ||
return nil | ||
when 413 | ||
raise CohortTooLargeError.new(cohort_id, "Cohort exceeds max cohort size: #{response.code}") | ||
else | ||
raise HTTPErrorResponseError.new(response.code, cohort_id, "Unexpected response code: #{response.code}") if response.code.to_i != 202 | ||
|
||
end | ||
rescue StandardError => e | ||
errors += 1 unless response && e.is_a?(HTTPErrorResponseError) && response.code.to_i == 429 | ||
@logger.debug("getCohortMembers(#{cohort_id}): request-status error #{errors} - #{e}") | ||
raise e if errors >= 3 || e.is_a?(CohortTooLargeError) | ||
end | ||
|
||
sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000.0) | ||
end | ||
end | ||
|
||
private | ||
|
||
def get_cohort_members_request(cohort_id, last_modified) | ||
headers = { | ||
'Authorization' => "Basic #{basic_auth}", | ||
'Content-Type' => 'application/json;charset=utf-8', | ||
'X-Amp-Exp-Library' => "experiment-ruby-server/#{VERSION}" | ||
} | ||
url = "#{@server_url}/sdk/v1/cohort/#{cohort_id}?maxCohortSize=#{@max_cohort_size}" | ||
url += "&lastModified=#{last_modified}" if last_modified | ||
|
||
request = Net::HTTP::Get.new(URI(url), headers) | ||
http = PersistentHttpClient.get(@server_url, { read_timeout: COHORT_REQUEST_TIMEOUT_MILLIS }, basic_auth) | ||
http.request(request) | ||
end | ||
|
||
def basic_auth | ||
credentials = "#{@api_key}:#{@secret_key}" | ||
Base64.strict_encode64(credentials) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
module AmplitudeExperiment | ||
# CohortLoader | ||
class CohortLoader | ||
def initialize(cohort_download_api, cohort_storage) | ||
@cohort_download_api = cohort_download_api | ||
@cohort_storage = cohort_storage | ||
@jobs = {} | ||
@lock_jobs = Mutex.new | ||
end | ||
|
||
def load_cohort(cohort_id) | ||
@lock_jobs.synchronize do | ||
unless @jobs.key?(cohort_id) | ||
future = Concurrent::Promises.future do | ||
load_cohort_internal(cohort_id) | ||
ensure | ||
remove_job(cohort_id) | ||
end | ||
@jobs[cohort_id] = future | ||
end | ||
@jobs[cohort_id] | ||
end | ||
end | ||
|
||
private | ||
|
||
def load_cohort_internal(cohort_id) | ||
stored_cohort = @cohort_storage.cohort(cohort_id) | ||
updated_cohort = @cohort_download_api.get_cohort(cohort_id, stored_cohort) | ||
@cohort_storage.put_cohort(updated_cohort) unless updated_cohort.nil? | ||
end | ||
|
||
def remove_job(cohort_id) | ||
@lock_jobs.synchronize do | ||
@jobs.delete(cohort_id) | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
module AmplitudeExperiment | ||
# CohortStorage | ||
class CohortStorage | ||
def cohort(cohort_id) | ||
raise NotImplementedError | ||
end | ||
|
||
def cohorts | ||
raise NotImplementedError | ||
end | ||
|
||
def get_cohorts_for_user(user_id, cohort_ids) | ||
raise NotImplementedError | ||
end | ||
|
||
def get_cohorts_for_group(group_type, group_name, cohort_ids) | ||
raise NotImplementedError | ||
end | ||
|
||
def put_cohort(cohort_description) | ||
raise NotImplementedError | ||
end | ||
|
||
def delete_cohort(group_type, cohort_id) | ||
raise NotImplementedError | ||
end | ||
|
||
def cohort_ids | ||
raise NotImplementedError | ||
end | ||
end | ||
|
||
class InMemoryCohortStorage < CohortStorage | ||
def initialize | ||
super | ||
@lock = Mutex.new | ||
@group_to_cohort_store = {} | ||
@cohort_store = {} | ||
end | ||
|
||
def cohort(cohort_id) | ||
@lock.synchronize do | ||
@cohort_store[cohort_id] | ||
end | ||
end | ||
|
||
def cohorts | ||
@lock.synchronize do | ||
@cohort_store.dup | ||
end | ||
end | ||
|
||
def get_cohorts_for_user(user_id, cohort_ids) | ||
get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids) | ||
end | ||
|
||
def get_cohorts_for_group(group_type, group_name, cohort_ids) | ||
result = Set.new | ||
@lock.synchronize do | ||
group_type_cohorts = @group_to_cohort_store[group_type] || Set.new | ||
group_type_cohorts.each do |cohort_id| | ||
members = @cohort_store[cohort_id]&.member_ids || Set.new | ||
result.add(cohort_id) if cohort_ids.include?(cohort_id) && members.include?(group_name) | ||
end | ||
end | ||
result | ||
end | ||
|
||
def put_cohort(cohort) | ||
@lock.synchronize do | ||
@group_to_cohort_store[cohort.group_type] ||= Set.new | ||
@group_to_cohort_store[cohort.group_type].add(cohort.id) | ||
@cohort_store[cohort.id] = cohort | ||
end | ||
end | ||
|
||
def delete_cohort(group_type, cohort_id) | ||
@lock.synchronize do | ||
group_cohorts = @group_to_cohort_store[group_type] || Set.new | ||
group_cohorts.delete(cohort_id) | ||
@cohort_store.delete(cohort_id) | ||
end | ||
end | ||
|
||
def cohort_ids | ||
@lock.synchronize do | ||
@cohort_store.keys.to_set | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
module AmplitudeExperiment | ||
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'.freeze | ||
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'.freeze | ||
|
||
# Experiment Cohort Sync Configuration | ||
class CohortSyncConfig | ||
# This configuration is used to set up the cohort loader. The cohort loader is responsible for | ||
# downloading cohorts from the server and storing them locally. | ||
# Parameters: | ||
# api_key (str): The project API Key | ||
# secret_key (str): The project Secret Key | ||
# max_cohort_size (int): The maximum cohort size that can be downloaded | ||
# cohort_polling_interval_millis (int): The interval in milliseconds to poll for cohorts, the minimum value is 60000 | ||
# cohort_server_url (str): The server endpoint from which to request cohorts | ||
|
||
attr_accessor :api_key, :secret_key, :max_cohort_size, :cohort_polling_interval_millis, :cohort_server_url | ||
|
||
def initialize(api_key, secret_key, max_cohort_size: 2_147_483_647, cohort_polling_interval_millis: 60_000, | ||
cohort_server_url: DEFAULT_COHORT_SYNC_URL) | ||
@api_key = api_key | ||
@secret_key = secret_key | ||
@max_cohort_size = max_cohort_size | ||
@cohort_polling_interval_millis = [cohort_polling_interval_millis, 60_000].max | ||
@cohort_server_url = cohort_server_url | ||
end | ||
end | ||
end |
Oops, something went wrong.