diff --git a/src/osqp/tests/nn_test.py b/src/osqp/tests/nn_test.py index ff45e019..41900803 100644 --- a/src/osqp/tests/nn_test.py +++ b/src/osqp/tests/nn_test.py @@ -11,6 +11,7 @@ ATOL = 1e-2 RTOL = 1e-4 +EPS = 1e-5 cuda = False verbose = True @@ -44,7 +45,6 @@ def get_grads( grads = get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type) return [P, q, A, l, u, true_x], grads - def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type): P_idx = P.nonzero() P_shape = P.shape @@ -77,7 +77,7 @@ def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type): grads = [x.grad.data.squeeze(0).cpu().numpy() for x in [P_torch, q_torch, A_torch, l_torch, u_torch]] return grads -def test_dl_dq(algebra, solver_type, atol, rtol, decimal_tol): +def test_dl_dq(algebra, solver_type): n, m = 5, 5 model = osqp.OSQP(algebra=algebra) @@ -100,13 +100,13 @@ def f(q): return 0.5 * np.sum(np.square(x_hat - true_x)) - dq_fd = approx_fprime(q, f) + dq_fd = approx_fprime(q, f, epsilon=EPS) if verbose: print('dq_fd: ', np.round(dq_fd, decimals=4)) print('dq: ', np.round(dq, decimals=4)) npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL) -def test_dl_dP(algebra, solver_type, atol, rtol, decimal_tol): +def test_dl_dP(algebra, solver_type): n, m = 5, 5 model = osqp.OSQP(algebra=algebra) @@ -129,13 +129,13 @@ def f(P): return 0.5 * np.sum(np.square(x_hat - true_x)) - dP_fd = approx_fprime(P, f) + dP_fd = approx_fprime(P, f, epsilon=EPS) if verbose: print('dP_fd: ', np.round(dP_fd, decimals=4)) print('dP: ', np.round(dP, decimals=4)) npt.assert_allclose(dP_fd, dP, rtol=RTOL, atol=ATOL) -def test_dl_dA(algebra, solver_type, atol, rtol, decimal_tol): +def test_dl_dA(algebra, solver_type): n, m = 5, 5 model = osqp.OSQP(algebra=algebra) @@ -158,13 +158,13 @@ def f(A): return 0.5 * np.sum(np.square(x_hat - true_x)) - dA_fd = approx_fprime(A, f) + dA_fd = approx_fprime(A, f, epsilon=EPS) if verbose: print('dA_fd: ', np.round(dA_fd, decimals=4)) print('dA: ', np.round(dA, decimals=4)) npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL) -def test_dl_dl(algebra, solver_type, atol, rtol, decimal_tol): +def test_dl_dl(algebra, solver_type): n, m = 5, 5 model = osqp.OSQP(algebra=algebra) @@ -187,13 +187,13 @@ def f(l): return 0.5 * np.sum(np.square(x_hat - true_x)) - dl_fd = approx_fprime(l, f) + dl_fd = approx_fprime(l, f, epsilon=EPS) if verbose: print('dl_fd: ', np.round(dl_fd, decimals=4)) print('dl: ', np.round(dl, decimals=4)) npt.assert_allclose(dl_fd, dl, rtol=RTOL, atol=ATOL) -def test_dl_du(algebra, solver_type, atol, rtol, decimal_tol): +def test_dl_du(algebra, solver_type): n, m = 5, 5 model = osqp.OSQP(algebra=algebra) @@ -216,7 +216,7 @@ def f(u): return 0.5 * np.sum(np.square(x_hat - true_x)) - du_fd = approx_fprime(u, f) + du_fd = approx_fprime(u, f, epsilon=EPS) if verbose: print('du_fd: ', np.round(du_fd, decimals=4)) print('du: ', np.round(du, decimals=4))