Skip to content

Commit

Permalink
Update test_engine.py
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored Feb 12, 2021
1 parent fad0b69 commit a3c42bb
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import os
import pickle
import platform
import psutil
import random

Expand Down Expand Up @@ -1044,15 +1045,21 @@ def test_contribs_sparse():
# convert data to dense and get back same contribs
contribs_dense = gbm.predict(X_test.toarray(), pred_contrib=True)
# validate the values are the same
np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense)
if platform.machine() == 'aarch64':
np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense, rtol=1, atol=1e-12)
else:
np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense)
assert (np.linalg.norm(gbm.predict(X_test, raw_score=True)
- np.sum(contribs_dense, axis=1)) < 1e-4)
# validate using CSC matrix
X_test_csc = X_test.tocsc()
contribs_csc = gbm.predict(X_test_csc, pred_contrib=True)
assert isspmatrix_csc(contribs_csc)
# validate the values are the same
np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense)
if platform.machine() == 'aarch64':
np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense, rtol=1, atol=1e-12)
else:
np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense)


def test_contribs_sparse_multiclass():
Expand Down Expand Up @@ -1084,7 +1091,10 @@ def test_contribs_sparse_multiclass():
contribs_csr_array = np.swapaxes(np.array([sparse_array.todense() for sparse_array in contribs_csr]), 0, 1)
contribs_csr_arr_re = contribs_csr_array.reshape((contribs_csr_array.shape[0],
contribs_csr_array.shape[1] * contribs_csr_array.shape[2]))
np.testing.assert_allclose(contribs_csr_arr_re, contribs_dense)
if platform.machine() == 'aarch64':
np.testing.assert_allclose(contribs_csr_arr_re, contribs_dense, rtol=1, atol=1e-12)
else:
np.testing.assert_allclose(contribs_csr_arr_re, contribs_dense)
contribs_dense_re = contribs_dense.reshape(contribs_csr_array.shape)
assert np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(contribs_dense_re, axis=2)) < 1e-4
# validate using CSC matrix
Expand All @@ -1097,7 +1107,10 @@ def test_contribs_sparse_multiclass():
contribs_csc_array = np.swapaxes(np.array([sparse_array.todense() for sparse_array in contribs_csc]), 0, 1)
contribs_csc_array = contribs_csc_array.reshape((contribs_csc_array.shape[0],
contribs_csc_array.shape[1] * contribs_csc_array.shape[2]))
np.testing.assert_allclose(contribs_csc_array, contribs_dense)
if platform.machine() == 'aarch64':
np.testing.assert_allclose(contribs_csc_array, contribs_dense, rtol=1, atol=1e-12)
else:
np.testing.assert_allclose(contribs_csc_array, contribs_dense)


@pytest.mark.skipif(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, reason='not enough RAM')
Expand Down

0 comments on commit a3c42bb

Please sign in to comment.