Skip to content

Commit

Permalink
Merge pull request #43 from baoliay2008/dev
Browse files Browse the repository at this point in the history
✨ feat: implement FFT acceleration for Elo algorithm
  • Loading branch information
baoliay2008 authored Dec 30, 2023
2 parents 91180a3 + f7d7032 commit 2967fbb
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
90 changes: 90 additions & 0 deletions app/core/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np
from scipy.signal import fftconvolve

from app.core.elo import delta_coefficients

EXPAND_SIZE = 100
MAX_RATING = 4000 * EXPAND_SIZE


def pre_calc_convolution(old_rating: np.ndarray) -> np.ndarray:
"""
Pre-calculate convolution values for the Elo rating update.
:param old_rating:
:return:
"""
f = 1 / (
1 + np.power(10, np.arange(-MAX_RATING, MAX_RATING + 1) / (400 * EXPAND_SIZE))
)
g = np.bincount(np.round(old_rating * EXPAND_SIZE).astype(int))
convolution = fftconvolve(f, g, mode="full")
convolution = convolution[: 2 * MAX_RATING + 1]
return convolution


def get_expected_rank(convolution: np.ndarray, x: int) -> float:
"""
Get the expected rank based on pre-calculated convolution values.
:param convolution:
:param x:
:return:
"""
return convolution[x + MAX_RATING] + 0.5


def get_equation_left(convolution: np.ndarray, x: int) -> float:
"""
Get the left side of equation for expected rating based on pre-calculated convolution values
:param convolution:
:param x:
:return:
"""
return convolution[x + MAX_RATING] + 1


def binary_search_expected_rating(convolution: np.ndarray, mean_rank: float) -> int:
"""
Perform binary search to find the expected rating for a given mean rank.
:param convolution:
:param mean_rank:
:return:
"""
lo, hi = 0, MAX_RATING
while lo < hi:
mid = (lo + hi) // 2
if get_equation_left(convolution, mid) < mean_rank:
hi = mid
else:
lo = mid + 1
return mid


def get_expected_rating(rank: int, rating: float, convolution: np.ndarray) -> float:
"""
Calculate the expected rating based on current rank, rating, and pre-calculated convolution.
:param rank:
:param rating:
:param convolution:
:return:
"""
expected_rank = get_expected_rank(convolution, round(rating * EXPAND_SIZE))
mean_rank = np.sqrt(expected_rank * rank)
return binary_search_expected_rating(convolution, mean_rank) / EXPAND_SIZE


def fft_delta(ranks: np.ndarray, ratings: np.ndarray, ks: np.ndarray) -> np.ndarray:
"""
Calculate Elo rating changes using Fast Fourier Transform (FFT)
:param ranks:
:param ratings:
:param ks:
:return:
"""
convolution = pre_calc_convolution(ratings)
expected_ratings = list()
for i in range(len(ranks)):
rank = ranks[i]
rating = ratings[i]
expected_ratings.append(get_expected_rating(rank, rating, convolution))
delta_ratings = (np.array(expected_ratings) - ratings) * delta_coefficients(ks)
return delta_ratings
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pymongo==4.6.1
pytest==7.4.3
pytz==2023.3.post1
PyYAML==6.0.1
scipy==1.11.4
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
Expand Down
36 changes: 36 additions & 0 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Final

import numpy as np

from app.core.fft import fft_delta

RATING_DELTA_PRECISION: Final[float] = 0.05


def test_fft_delta():
"""
Test function for the fft_delta function.
Loads test data from a NumPy file containing columns: ks, ranks, old_ratings, new_ratings.
Calculates delta_ratings using fft_delta function and checks if the resulting new_ratings
match the expected values within a specified precision.
Raises:
AssertionError: If the calculated ratings deviate from the expected ratings
by more than RATING_DELTA_PRECISION.
"""

with open("tests/test_data/contest_k_rating_test.npy", "rb") as f:
data = np.load(f)
ks = data[:, 0]
ranks = data[:, 1]
old_ratings = data[:, 2]
new_ratings = data[:, 3]

delta_ratings = fft_delta(ranks, old_ratings, ks)
testing_new_ratings = old_ratings + delta_ratings

errors = np.abs(new_ratings - testing_new_ratings)
assert np.all(
errors < RATING_DELTA_PRECISION
), f"Elo delta test failed. Some errors are not within {RATING_DELTA_PRECISION=}."

0 comments on commit 2967fbb

Please sign in to comment.