-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'benchmarks' of https://github.com/lab-cosmo/mops into b…
…enchmarks
- Loading branch information
Showing
10 changed files
with
176 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,49 @@ | ||
import numpy as np | ||
import math | ||
import time | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def benchmark(function, repeats=1000, plot=True): | ||
def benchmark(function, repeats=1000, warmup=10, plot=True): | ||
for _ in range(warmup): | ||
function() | ||
|
||
timings = [] | ||
for _ in range(repeats): | ||
start = time.time() | ||
function() | ||
end = time.time() | ||
timings.append(end-start) | ||
timings.append(end - start) | ||
|
||
times_array = np.array(timings) | ||
mean = np.mean(times_array) | ||
std = np.std(times_array) | ||
if std > 0.1 * mean: print("warning: inconsistent timings") | ||
if std > 0.1 * mean: | ||
print("warning: inconsistent timings") | ||
|
||
if plot: | ||
plt.plot(np.arange(times_array.shape[0]), times_array, ".") | ||
plt.savefig("benchmark_plot.pdf") | ||
|
||
return mean, std | ||
|
||
|
||
def format_mean_std(mean, std_dev, decimals=2): | ||
# find the exponent | ||
if mean != 0: | ||
exponent = math.floor(math.log10(abs(mean))) | ||
else: | ||
exponent = 0 | ||
|
||
# scale the mean and standard deviation by the exponent | ||
scaled_mean = mean / (10**exponent) | ||
scaled_std_dev = std_dev / (10**exponent) | ||
|
||
# format the scaled mean and standard deviation | ||
format_string = f"{{:.{decimals}f}}" | ||
formatted_mean = format_string.format(scaled_mean) | ||
formatted_std_dev = format_string.format(scaled_std_dev) | ||
final_string = f"({formatted_mean}±{formatted_std_dev})e{exponent}" | ||
|
||
return final_string |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import numpy as np | ||
from benchmark import benchmark, format_mean_std | ||
from mops.reference_implementations import outer_product_scatter_add as ref_opsa | ||
|
||
from mops import outer_product_scatter_add as opsa | ||
|
||
np.random.seed(0xDEADBEEF) | ||
|
||
A = np.random.rand(1000, 20) | ||
B = np.random.rand(1000, 5) | ||
|
||
indices = np.sort(np.random.randint(10, size=(1000,))) | ||
|
||
ref_mean, ref_std = benchmark(lambda: ref_opsa(A, B, indices, np.max(indices) + 1)) | ||
mean, std = benchmark(lambda: opsa(A, B, indices, np.max(indices) + 1)) | ||
|
||
print("Reference implementation:", format_mean_std(ref_mean, ref_std)) | ||
print("Optimized implementation:", format_mean_std(mean, std)) | ||
|
||
print("Speed-up:", ref_mean / mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import numpy as np | ||
from benchmark import benchmark, format_mean_std | ||
from mops.reference_implementations import ( | ||
outer_product_scatter_add_with_weights as ref_opsax, | ||
) | ||
|
||
from mops import outer_product_scatter_add_with_weights as opsax | ||
|
||
np.random.seed(0xDEADBEEF) | ||
|
||
|
||
A = np.random.rand(100, 10) | ||
R = np.random.rand(100, 5) | ||
n_O = 20 | ||
X = np.random.rand(n_O, 5) | ||
|
||
I = np.random.randint(20, size=(100,)) | ||
J = np.random.randint(20, size=(100,)) | ||
|
||
ref_mean, ref_std = benchmark(lambda: ref_opsax(A, R, X, I, J, 20)) | ||
mean, std = benchmark(lambda: opsax(A, R, X, I, J, 20)) | ||
|
||
print("Reference implementation:", format_mean_std(ref_mean, ref_std)) | ||
print("Optimized implementation:", format_mean_std(mean, std)) | ||
|
||
print("Speed-up:", ref_mean / mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
from benchmark import benchmark, format_mean_std | ||
from mops.reference_implementations import sparse_accumulation_of_products as ref_sap | ||
|
||
from mops import sparse_accumulation_of_products as sap | ||
|
||
np.random.seed(0xDEADBEEF) | ||
|
||
A = np.random.rand(1000, 20) | ||
B = np.random.rand(1000, 6) | ||
C = np.random.rand(100) | ||
|
||
P_A = np.random.randint(20, size=(100,)) | ||
P_B = np.random.randint(6, size=(100,)) | ||
n_O = 50 | ||
P_O = np.random.randint(n_O, size=(100,)) | ||
|
||
ref_mean, ref_std = benchmark(lambda: ref_sap(A, B, C, P_A, P_B, P_O, n_O)) | ||
mean, std = benchmark(lambda: sap(A, B, C, P_A, P_B, P_O, n_O)) | ||
|
||
print("Reference implementation:", format_mean_std(ref_mean, ref_std)) | ||
print("Optimized implementation:", format_mean_std(mean, std)) | ||
|
||
print("Speed-up:", ref_mean / mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
from benchmark import benchmark, format_mean_std | ||
from mops.reference_implementations import ( | ||
sparse_accumulation_scatter_add_with_weights as ref_sasax, | ||
) | ||
|
||
from mops import sparse_accumulation_scatter_add_with_weights as sasax | ||
|
||
np.random.seed(0xDEADBEEF) | ||
|
||
A = np.random.rand(100, 20) | ||
R = np.random.rand(100, 200) | ||
X = np.random.rand(25, 13, 200) | ||
C = np.random.rand(50) | ||
n_O1 = 25 | ||
I = np.random.randint(25, size=(100,)) | ||
J = np.random.randint(25, size=(100,)) | ||
n_O2 = 15 | ||
M_1 = np.random.randint(20, size=(50,)) | ||
M_2 = np.random.randint(13, size=(50,)) | ||
M_3 = np.random.randint(n_O2, size=(50,)) | ||
|
||
ref_mean, ref_std = benchmark( | ||
lambda: ref_sasax(A, R, X, C, I, J, M_1, M_2, M_3, n_O1, n_O2) | ||
) | ||
mean, std = benchmark(lambda: sasax(A, R, X, C, I, J, M_1, M_2, M_3, n_O1, n_O2)) | ||
|
||
print("Reference implementation:", format_mean_std(ref_mean, ref_std)) | ||
print("Optimized implementation:", format_mean_std(mean, std)) | ||
|
||
print("Speed-up:", ref_mean / mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters