-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
176 lines (127 loc) · 5.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import csv
import random
import re
from pprint import pprint
from matcher import JMPreference, MatcherConfig, Slot, User, run_matcher
# Input CSV files: JM_PREFERENCE_FILE is read by parse_jm_preferences(),
# SM_PREFERENCE_FILE by parse_sm_slots().
JM_PREFERENCE_FILE = "jm_preferences.csv"
SM_PREFERENCE_FILE = "sm_preferences.csv"
# Output CSV files written by main(): SHORT_OUT gets one row per slot,
# LONG_OUT gets one row per matched JM.
SHORT_OUT = "output_short.csv"
LONG_OUT = "output_long.csv"

# JM PREFERENCES
# Column headers looked up in each row of JM_PREFERENCE_FILE.
JM_FIRST_NAME_COLUMN = "First name"
JM_LAST_NAME_COLUMN = "Last name"
JM_ROLE_COLUMN = "For which position are you accepting/rejecting?"
# Rows whose JM_ROLE_COLUMN value differs from this string are skipped.
JM_ROLE_CHECK = "Junior Mentor"
# The value in this column is parsed with int() and stored as the user's
# sociability.
JM_SOCIAL_COLUMN = "How much time would you like to spend within your family focusing on teaching improvement vs being more social?"

# SM PREFERENCES
# Column of SM_PREFERENCE_FILE holding each slot's time.
SM_TIME_COLUMN = "Time"

# SETTINGS
# Passed to MatcherConfig as min_family_size / max_family_size.
MIN_FAMILY_SIZE = 3
MAX_FAMILY_SIZE = 4
def parse_sm_slots() -> list[Slot]:
    """Read SM_PREFERENCE_FILE and build one Slot per CSV row.

    Each row becomes a slot whose ``time`` is the value of the
    SM_TIME_COLUMN column and whose ``sm_list`` holds the stripped values of
    every column whose header contains the standalone word "sm"
    (case-insensitive, split on whitespace). A slot's id is the row's
    zero-based position in the file, as a string.

    Returns:
        The slots, in file order.

    Raises:
        KeyError: if the file has no SM_TIME_COLUMN column.
    """
    slots = []
    # newline="" lets the csv module do its own line-ending handling, which
    # is what its documentation requires for files passed to a reader.
    with open(SM_PREFERENCE_FILE, "r", encoding="utf-8", newline="") as f:
        for index, row in enumerate(csv.DictReader(f)):
            time = row[SM_TIME_COLUMN]
            sms = [
                # DictReader fills cells missing from a short row with None.
                (val or "").strip()
                for key, val in row.items()
                # DictReader stores surplus cells of a long row under the
                # key None; that entry is never an SM column.
                if key is not None and "sm" in key.lower().split()
            ]
            slots.append(Slot(id=str(index), time=time, sm_list=sms))
    return slots
def parse_jm_preferences(slots: list[Slot]) -> tuple[list[User], list[JMPreference]]:
    """Read JM_PREFERENCE_FILE into users and their per-slot preferences.

    Only rows whose JM_ROLE_COLUMN equals JM_ROLE_CHECK are used. Each such
    row yields one User (id and name are "<first name> <last name>",
    sociability is the integer in JM_SOCIAL_COLUMN). Every column whose
    header contains both "[" and "]" is treated as a preference column: the
    text inside the brackets is matched against ``slot.time`` and the cell,
    parsed as an integer, becomes one JMPreference for each slot at that
    time. Preference columns whose time matches no slot are ignored.

    Args:
        slots: The slots preferences are resolved against.

    Returns:
        ``(users, preferences)`` in file order.

    Raises:
        ValueError: if a preference column header does not contain exactly
            one bracketed time, if a slot time has no preference column in
            any accepted row, or if a sociability/preference cell is not an
            integer.
        KeyError: if a required column is missing from the file.
    """
    slots_by_time = {}
    for slot in slots:
        slots_by_time.setdefault(slot.time, []).append(slot)

    bracketed = re.compile(r"\[([^\]]+)\]")
    users = []
    preferences = []
    # Slot times for which at least one accepted row had a preference column.
    covered_times = set()

    with open(JM_PREFERENCE_FILE, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            if row[JM_ROLE_COLUMN] != JM_ROLE_CHECK:
                continue
            first_name = row[JM_FIRST_NAME_COLUMN].strip()
            last_name = row[JM_LAST_NAME_COLUMN].strip()
            name = f"{first_name} {last_name}"
            cur_user = User(id=name, name=name, sociability=int(row[JM_SOCIAL_COLUMN]))
            for key, val in row.items():
                # DictReader stores surplus cells under the key None.
                if key is None or "[" not in key or "]" not in key:
                    continue
                # detected as a preference column
                times = bracketed.findall(key)
                if len(times) != 1:
                    # Raised explicitly: an assert would vanish under -O.
                    raise ValueError(
                        f"Preference column {key!r} must contain exactly one "
                        f"bracketed time, found {len(times)}"
                    )
                time = times[0]
                if time in slots_by_time:
                    covered_times.add(time)
                for slot in slots_by_time.get(time, []):
                    preferences.append(
                        JMPreference(
                            user_id=cur_user.id, slot_id=slot.id, value=int(val)
                        )
                    )
            users.append(cur_user)

    missing_times = [time for time in slots_by_time if time not in covered_times]
    if missing_times:
        raise ValueError(
            f"No JM preference column found for slot time(s): {missing_times}"
        )
    return users, preferences
def _write_csv(path, header, rows):
    """Write ``header`` followed by ``rows`` to ``path`` as a UTF-8 CSV file."""
    # newline="" keeps the writer's own "\r\n" terminators from being
    # translated a second time, which would put blank lines between rows on
    # platforms that write "\n" as "\r\n".
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def main():
    """Run the matcher on the parsed CSV inputs and write both reports.

    Parses slots and JM preferences, shuffles users and preferences with the
    module-level ``random`` state, runs ``run_matcher``, then writes:

    * SHORT_OUT: one row per matched slot with its SMs followed by its JMs,
      each group padded with empty cells to a common width.
    * LONG_OUT: one row per matched JM with the slot's SMs, the JM's name,
      the JM's preference value for that slot, and the JM's sociability.

    Slots appear in the order they have in the parsed slot list; JMs within
    a slot are sorted by id.
    """
    slots = parse_sm_slots()
    users, preferences = parse_jm_preferences(slots)

    random.shuffle(users)
    random.shuffle(preferences)

    config = MatcherConfig(min_family_size=MIN_FAMILY_SIZE, max_family_size=MAX_FAMILY_SIZE, sociability_bias=0)
    matching = run_matcher(users, preferences, slots, config)

    users_by_id = {user.id: user for user in users}
    slots_by_id = {slot.id: slot for slot in slots}
    preference_map = {
        (preference.user_id, preference.slot_id): preference.value
        for preference in preferences
    }

    # The loop below reads matching as {user_id: slot}.
    matching_by_slot: dict[str, list[User]] = {}
    for user_id, slot in matching.items():
        matching_by_slot.setdefault(slot.id, []).append(users_by_id[user_id])

    # default=0 keeps an empty slot list or an empty matching from raising;
    # the reports then contain only their header row.
    max_sms = max((len(slot.sm_list) for slot in slots), default=0)
    max_jms = max((len(members) for members in matching_by_slot.values()), default=0)

    # Order by position in `slots` rather than sorting the ids themselves:
    # ids are strings, so a plain sort would put "10" before "2".
    slot_position = {slot.id: index for index, slot in enumerate(slots)}
    ordered_slot_ids = sorted(matching_by_slot, key=slot_position.__getitem__)

    sm_header = [f"SM {i+1}" for i in range(max_sms)]
    short_rows = []
    long_rows = []
    for slot_id in ordered_slot_ids:
        members = sorted(matching_by_slot[slot_id], key=lambda u: u.id)
        slot = slots_by_id[slot_id]
        padded_sm_list = slot.sm_list + [""] * (max_sms - len(slot.sm_list))

        member_names = [member.name for member in members]
        padded_jm_list = member_names + [""] * (max_jms - len(member_names))
        short_rows.append([*padded_sm_list, *padded_jm_list])

        for member in members:
            long_rows.append(
                [
                    *padded_sm_list,
                    member.name,
                    preference_map[member.id, slot_id],
                    member.sociability,
                ]
            )

    _write_csv(SHORT_OUT, sm_header + [f"JM {i+1}" for i in range(max_jms)], short_rows)
    _write_csv(LONG_OUT, sm_header + ["JM", "Preference", "Sociability"], long_rows)
if __name__ == "__main__":
    import argparse

    # Script entry point: accept an optional --seed so a run can be repeated;
    # when none is given, draw one and print it so it can be reused later.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=None, help="Random seed to use")
    args = parser.parse_args()

    seed = args.seed if args.seed is not None else random.randint(0, int(1e6))
    print(f"Using seed: {seed}\n")
    random.seed(seed)

    main()