-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from baoliay2008/dev
✨ feat: implement FFT acceleration for Elo algorithm
- Loading branch information
Showing
3 changed files
with
127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import numpy as np | ||
from scipy.signal import fftconvolve | ||
|
||
from app.core.elo import delta_coefficients | ||
|
||
EXPAND_SIZE = 100 | ||
MAX_RATING = 4000 * EXPAND_SIZE | ||
|
||
|
||
def pre_calc_convolution(old_rating: np.ndarray) -> np.ndarray: | ||
""" | ||
Pre-calculate convolution values for the Elo rating update. | ||
:param old_rating: | ||
:return: | ||
""" | ||
f = 1 / ( | ||
1 + np.power(10, np.arange(-MAX_RATING, MAX_RATING + 1) / (400 * EXPAND_SIZE)) | ||
) | ||
g = np.bincount(np.round(old_rating * EXPAND_SIZE).astype(int)) | ||
convolution = fftconvolve(f, g, mode="full") | ||
convolution = convolution[: 2 * MAX_RATING + 1] | ||
return convolution | ||
|
||
|
||
def get_expected_rank(convolution: np.ndarray, x: int) -> float: | ||
""" | ||
Get the expected rank based on pre-calculated convolution values. | ||
:param convolution: | ||
:param x: | ||
:return: | ||
""" | ||
return convolution[x + MAX_RATING] + 0.5 | ||
|
||
|
||
def get_equation_left(convolution: np.ndarray, x: int) -> float: | ||
""" | ||
Get the left side of equation for expected rating based on pre-calculated convolution values | ||
:param convolution: | ||
:param x: | ||
:return: | ||
""" | ||
return convolution[x + MAX_RATING] + 1 | ||
|
||
|
||
def binary_search_expected_rating(convolution: np.ndarray, mean_rank: float) -> int: | ||
""" | ||
Perform binary search to find the expected rating for a given mean rank. | ||
:param convolution: | ||
:param mean_rank: | ||
:return: | ||
""" | ||
lo, hi = 0, MAX_RATING | ||
while lo < hi: | ||
mid = (lo + hi) // 2 | ||
if get_equation_left(convolution, mid) < mean_rank: | ||
hi = mid | ||
else: | ||
lo = mid + 1 | ||
return mid | ||
|
||
|
||
def get_expected_rating(rank: int, rating: float, convolution: np.ndarray) -> float: | ||
""" | ||
Calculate the expected rating based on current rank, rating, and pre-calculated convolution. | ||
:param rank: | ||
:param rating: | ||
:param convolution: | ||
:return: | ||
""" | ||
expected_rank = get_expected_rank(convolution, round(rating * EXPAND_SIZE)) | ||
mean_rank = np.sqrt(expected_rank * rank) | ||
return binary_search_expected_rating(convolution, mean_rank) / EXPAND_SIZE | ||
|
||
|
||
def fft_delta(ranks: np.ndarray, ratings: np.ndarray, ks: np.ndarray) -> np.ndarray: | ||
""" | ||
Calculate Elo rating changes using Fast Fourier Transform (FFT) | ||
:param ranks: | ||
:param ratings: | ||
:param ks: | ||
:return: | ||
""" | ||
convolution = pre_calc_convolution(ratings) | ||
expected_ratings = list() | ||
for i in range(len(ranks)): | ||
rank = ranks[i] | ||
rating = ratings[i] | ||
expected_ratings.append(get_expected_rating(rank, rating, convolution)) | ||
delta_ratings = (np.array(expected_ratings) - ratings) * delta_coefficients(ks) | ||
return delta_ratings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Final | ||
|
||
import numpy as np | ||
|
||
from app.core.fft import fft_delta | ||
|
||
RATING_DELTA_PRECISION: Final[float] = 0.05 | ||
|
||
|
||
def test_fft_delta(): | ||
""" | ||
Test function for the fft_delta function. | ||
Loads test data from a NumPy file containing columns: ks, ranks, old_ratings, new_ratings. | ||
Calculates delta_ratings using fft_delta function and checks if the resulting new_ratings | ||
match the expected values within a specified precision. | ||
Raises: | ||
AssertionError: If the calculated ratings deviate from the expected ratings | ||
by more than RATING_DELTA_PRECISION. | ||
""" | ||
|
||
with open("tests/test_data/contest_k_rating_test.npy", "rb") as f: | ||
data = np.load(f) | ||
ks = data[:, 0] | ||
ranks = data[:, 1] | ||
old_ratings = data[:, 2] | ||
new_ratings = data[:, 3] | ||
|
||
delta_ratings = fft_delta(ranks, old_ratings, ks) | ||
testing_new_ratings = old_ratings + delta_ratings | ||
|
||
errors = np.abs(new_ratings - testing_new_ratings) | ||
assert np.all( | ||
errors < RATING_DELTA_PRECISION | ||
), f"Elo delta test failed. Some errors are not within {RATING_DELTA_PRECISION=}." |