Skip to content

Commit

Permalink
Group matching feature (#102)
Browse files Browse the repository at this point in the history
* group matching

* simplify naming, add tests, add docstrings, enable group size in tasks

* fix linting errors
  • Loading branch information
kentwills authored Aug 30, 2017
1 parent e9def1d commit f0b9bd1
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 10 deletions.
2 changes: 1 addition & 1 deletion js/components/MetricsListItem.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const MetricsListItem = ({ metric }) => (
</li>
</ul>
</div>
);
);

MetricsListItem.propTypes = {
metric: PropTypes.object.isRequired, // eslint-disable-line react/forbid-prop-types
Expand Down
2 changes: 1 addition & 1 deletion js/containers/MetricsList.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MetricsList extends Component {
key={metric.title}
metric={metric}
/>
));
));
}
render() {
return (
Expand Down
2 changes: 1 addition & 1 deletion js/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ ReactDOM.render(
<Provider store={store(reducers)}>
<Router history={browserHistory} routes={routes} />
</Provider>,
document.querySelector('#container'),
document.querySelector('#container'),
);
98 changes: 97 additions & 1 deletion tests/match_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from yelp_beans.logic.subscription import get_specs_from_subscription
from yelp_beans.logic.subscription import store_specs_from_subscription
from yelp_beans.matching.group_match import generate_groups
from yelp_beans.matching.group_match import get_previous_meetings_counts
from yelp_beans.matching.group_match import get_user_weights
from yelp_beans.matching.match import generate_meetings
from yelp_beans.matching.match_utils import get_counts_for_pairs
from yelp_beans.matching.match_utils import get_previous_meetings
Expand Down Expand Up @@ -205,7 +207,7 @@ def test_pair_to_counts():
assert(counts[('user1', 'user2')] == 2)


def test_get_previous_meetings_counts():
def test_get_previous_meetings_counts(minimal_database):
pref_1 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 1)).put()
subscription = MeetingSubscription(title='all engineering weekly', datetime=[pref_1]).put()
user_pref = UserSubscriptionPreferences(preference=pref_1, subscription=subscription).put()
Expand All @@ -217,3 +219,97 @@ def test_get_previous_meetings_counts():
MeetingParticipant(meeting=meeting, user=user1).put()

assert(get_previous_meetings_counts([user1.get(), user2.get()], subscription) == {(user1.id(), user2.id()): 1})


def test_generate_groups():
result = generate_groups([1, 2, 3, 4, 5], 3)
assert [x for x in result] == [[1, 2, 3], [4, 5]]

result = generate_groups([1, 2, 3, 4], 2)
assert [x for x in result] == [[1, 2], [3, 4]]

result = generate_groups([1, 2, 3], 3)
assert [x for x in result] == [[1, 2, 3]]


def test_get_user_weights(minimal_database):
pref_1 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 1)).put()
subscription = MeetingSubscription(title='all engineering weekly', datetime=[pref_1]).put()
user_pref = UserSubscriptionPreferences(preference=pref_1, subscription=subscription).put()
user1 = User(email='[email protected]', metadata={'department': 'dept'}, subscription_preferences=[user_pref]).put()
user2 = User(email='[email protected]', metadata={'department': 'dept2'}, subscription_preferences=[user_pref]).put()
meeting_spec = MeetingSpec(meeting_subscription=subscription, datetime=pref_1.get().datetime).put()
meeting = Meeting(meeting_spec=meeting_spec, cancelled=False).put()
MeetingParticipant(meeting=meeting, user=user2).put()
MeetingParticipant(meeting=meeting, user=user1).put()
previous_meetings_count = get_previous_meetings_counts([user1.get(), user2.get()], subscription)

assert(get_user_weights([user1.get(), user2.get()], previous_meetings_count, 10, 5) == [[0, 5], [5, 0]])


def test_generate_group_meeting(minimal_database):
pref_1 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 1)).put()
subscription = MeetingSubscription(title='all engineering weekly', datetime=[pref_1]).put()
user_pref = UserSubscriptionPreferences(preference=pref_1, subscription=subscription).put()
meeting_spec = MeetingSpec(meeting_subscription=subscription, datetime=pref_1.get().datetime)
meeting_spec.put()

users = []
num_users = 21
for i in range(0, num_users):
user = User(email='{}@yelp.com'.format(i), metadata={
'department': 'dept{}'.format(i)}, subscription_preferences=[user_pref])
user.put()
MeetingRequest(user=user.key, meeting_spec=meeting_spec.key).put()
users.append(user)

matches, unmatched = generate_meetings(users, meeting_spec, prev_meeting_tuples=None, group_size=3)
assert(len(matches) == 7)
assert (len(unmatched) == 0)
matches, unmatched = generate_meetings(users, meeting_spec, prev_meeting_tuples=None, group_size=5)
assert(len(matches) == 4)
assert (len(unmatched) == 1)


def test_previous_meeting_penalty(minimal_database):
pref_1 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 1)).put()
pref_2 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 2)).put()
pref_3 = SubscriptionDateTime(datetime=datetime.now() - timedelta(weeks=MEETING_COOLDOWN_WEEKS - 3)).put()
subscription = MeetingSubscription(title='all engineering weekly', datetime=[pref_1, pref_2, pref_3]).put()
user_pref1 = UserSubscriptionPreferences(preference=pref_1, subscription=subscription).put()
user_pref2 = UserSubscriptionPreferences(preference=pref_2, subscription=subscription).put()
user_pref3 = UserSubscriptionPreferences(preference=pref_3, subscription=subscription).put()
meeting_spec1 = MeetingSpec(meeting_subscription=subscription, datetime=pref_1.get().datetime)
meeting_spec1.put()
meeting_spec2 = MeetingSpec(meeting_subscription=subscription, datetime=pref_2.get().datetime)
meeting_spec2.put()
meeting_spec3 = MeetingSpec(meeting_subscription=subscription, datetime=pref_3.get().datetime)
meeting_spec3.put()

users = []
num_users = 20
for i in range(0, num_users):
user = User(email='{}@yelp.com'.format(i), metadata={
'department': 'dept{}'.format(i)}, subscription_preferences=[user_pref1, user_pref2, user_pref3])
user.put()
MeetingRequest(user=user.key, meeting_spec=meeting_spec1.key).put()
MeetingRequest(user=user.key, meeting_spec=meeting_spec2.key).put()
MeetingRequest(user=user.key, meeting_spec=meeting_spec3.key).put()
users.append(user)

meeting1 = Meeting(meeting_spec=meeting_spec1.key, cancelled=False).put()
MeetingParticipant(meeting=meeting1, user=users[1].key).put()
MeetingParticipant(meeting=meeting1, user=users[0].key).put()
meeting2 = Meeting(meeting_spec=meeting_spec2.key, cancelled=False).put()
MeetingParticipant(meeting=meeting2, user=users[1].key).put()
MeetingParticipant(meeting=meeting2, user=users[0].key).put()
meeting3 = Meeting(meeting_spec=meeting_spec3.key, cancelled=False).put()
MeetingParticipant(meeting=meeting3, user=users[1].key).put()
MeetingParticipant(meeting=meeting3, user=users[0].key).put()

for run in range(10):
matches, unmatched = generate_meetings(users, meeting_spec1, prev_meeting_tuples=None, group_size=3)
assert(len(matches) == 6)
assert (len(unmatched) == 2)
for matched_group in matches:
assert(not (users[0] in matched_group and users[1] in matched_group))
152 changes: 148 additions & 4 deletions yelp_beans/matching/group_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,160 @@
from __future__ import unicode_literals

import itertools
import logging
import random

from yelp_beans.logic.user import user_preference
from yelp_beans.matching.match_utils import get_counts_for_pairs
from yelp_beans.matching.match_utils import get_previous_meetings


def get_previous_meetings_counts(users, subscription):
previous_meetings = get_previous_meetings(subscription)
def get_previous_meetings_counts(users, subscription_key):
"""
Given users for a subscription, return the number of times two people have matched
:param users: id of user
:param subscription_key: Key referencing the subscription model entity
:return: Tuple of user id's matched to count ie. {(4L, 5L): 5}
"""
previous_meetings = get_previous_meetings(subscription_key)
counts_for_pairs = get_counts_for_pairs(previous_meetings)
userids = sorted([user.key.id() for user in users])
all_pairs_counts = {pair: 0 for pair in itertools.combinations(userids, 2)}
user_ids = sorted([user.key.id() for user in users])
all_pairs_counts = {pair: 0 for pair in itertools.combinations(user_ids, 2)}
for pair in counts_for_pairs:
all_pairs_counts[pair] = counts_for_pairs[pair]
return all_pairs_counts


def get_user_weights(users, previous_meetings_counts, starting_weight, negative_weight):
"""
Given users asking for a match and historical information about the previous people they met,
return weights to promote groups where people haven't met each other.
:param users: list of user models
:param previous_meetings_counts: tuple of user id's matched to count
:param starting_weight: initial weight between users
:param negative_weight: amount to subtract from initial weight based on previous meetings
:return: adjacency matrix from user to user
"""
user_user_weights = []
for idx1, user1 in enumerate(users):
user_user_weights.append([])
for idx2, user2 in enumerate(users):
pair_tuple = tuple(sorted((user1.key.id(), user2.key.id())))
if pair_tuple not in previous_meetings_counts:
user_user_weights[idx1].append(0)
continue
weight = starting_weight - (negative_weight * previous_meetings_counts[pair_tuple])
user_user_weights[idx1].append(weight)
return user_user_weights


def generate_groups(group, partition_size):
"""
Given a group, partition into smaller groups of a specific size. Zero
for group size is invalid. Partitions will never exceed inputted value, but may be smaller.
:param group: List of ids
:param partition_size: Intended size for group
:return: list of groups
"""
for i in range(0, len(group), partition_size):
yield group[i:i + partition_size if (i + partition_size) < len(group) else len(group)]


def generate_group_meetings(users, spec, group_size, starting_weight, negative_weight):
population_size = len(users)
previous_meetings_counts = get_previous_meetings_counts(users, spec.meeting_subscription)
adj_matrix = get_user_weights(users, previous_meetings_counts, starting_weight, negative_weight)
annealing = Annealing(population_size, group_size, adj_matrix)
grouped_ids = generate_groups(annealing.simulated_annealing(), group_size)

matches = []
unmatched = []
for group in grouped_ids:
group_users = [users[idx] for idx in group]
if len(group) < group_size:
unmatched.extend(group_users)
continue
time = user_preference(users[0], spec)
group_users.append(time)
users_time_tuple = tuple(group_users)
matches.append(users_time_tuple)
logging.info('{} employees matched'.format(len(matches) * group_size))
for group in matches:
username_tuple = tuple([user.get_username() for user in group[:-1]])
logging.info(username_tuple)

logging.info('{} employees unmatched'.format(len(unmatched)))
logging.info([user.get_username() for user in unmatched])

return matches, unmatched


class Annealing:
def __init__(self, population_size, group_size, adj_matrix, max_iterations=100):
self.population_size = population_size
self.group_size = group_size
self.adj_matrix = adj_matrix
self.max_iterations = max_iterations

def get_initial_state(self):
ids = [i for i in range(self.population_size)]
random.shuffle(ids)
return State(self.population_size, self.group_size, ids)

def get_temp(self, iteration):
return 1.0 - (self.max_iterations - iteration) / (self.max_iterations + iteration)

def simulated_annealing(self):
prev_state = self.get_initial_state()
best_state = prev_state.copy()

best_cost = prev_cost = prev_state.get_cost(self.adj_matrix)

for iteration in range(self.max_iterations):
temp = self.get_temp(iteration)

curr_state = prev_state.get_mutated_state()
curr_cost = curr_state.get_cost(self.adj_matrix)

if curr_cost > best_cost:
best_cost = curr_cost
best_state = curr_state

if curr_cost > prev_cost or 1.0 * curr_cost / (prev_cost + 1) * temp < random.random():
prev_cost = curr_cost
prev_state = curr_state

return best_state.ids


class State:
def __init__(self, population_size, group_size, ids):
self.population_size = population_size
self.group_size = group_size
self.ids = ids

def copy(self):
return State(self.population_size, self.group_size, self.ids[:])

def get_cost(self, adj_matrix):
cost = 0
for i in range(0, len(self.ids), self.group_size):
cost += sum([
adj_matrix[edge[0]][edge[1]]
for edge in itertools.combinations(self.ids[i:i + self.group_size], 2)
])
return cost

def get_mutated_state(self):
x = random.randint(0, len(self.ids) - 1)
y = random.randint(0, len(self.ids) - 2)
if y >= x:
y += 1

ids = self.ids[:]
ids[x], ids[y] = ids[y], ids[x]
return State(
self.population_size,
self.group_size,
ids
)
5 changes: 4 additions & 1 deletion yelp_beans/matching/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from __future__ import print_function
from __future__ import unicode_literals

from yelp_beans.matching.group_match import generate_group_meetings
from yelp_beans.matching.pair_match import generate_pair_meetings


def generate_meetings(users, spec, prev_meeting_tuples=None, group_size=2):
if group_size == 2:
return generate_pair_meetings(users, spec, prev_meeting_tuples)
elif group_size > 2:
return generate_group_meetings(users, spec, group_size, 10, 5)
else:
raise NotImplementedError("Group matching not implemented yet.")
raise ValueError("Group size must be greater than 1.")
4 changes: 3 additions & 1 deletion yelp_beans/routes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def generate_meeting_specs():
@tasks.route('/email_users_for_weekly_opt_in', methods=['GET'])
def weekly_opt_in():
for spec in get_specs_for_current_week():
logging.info(spec)
send_batch_weekly_opt_in_email(spec)
return 'OK'

Expand Down Expand Up @@ -67,7 +68,8 @@ def match_employees():
logging.info('Users: ')
logging.info([user.get_username() for user in users])

matches, unmatched = generate_meetings(users, spec, prev_meeting_tuples=None, group_size=2)
group_size = spec.meeting_subscription.get().size
matches, unmatched = generate_meetings(users, spec, prev_meeting_tuples=None, group_size=group_size)
save_meetings(matches, spec)

send_batch_unmatched_email(unmatched)
Expand Down

0 comments on commit f0b9bd1

Please sign in to comment.