Skip to content

Commit

Permalink
Defined epsilon for nn_test. Removed unused args.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Jul 9, 2024
1 parent 64b5d13 commit b16b1a3
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/osqp/tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

ATOL = 1e-2
RTOL = 1e-4
EPS = 1e-5

cuda = False
verbose = True
Expand Down Expand Up @@ -44,7 +45,6 @@ def get_grads(
grads = get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type)
return [P, q, A, l, u, true_x], grads


def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type):
P_idx = P.nonzero()
P_shape = P.shape
Expand Down Expand Up @@ -77,7 +77,7 @@ def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type):
grads = [x.grad.data.squeeze(0).cpu().numpy() for x in [P_torch, q_torch, A_torch, l_torch, u_torch]]
return grads

def test_dl_dq(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_dq(algebra, solver_type):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand All @@ -100,13 +100,13 @@ def f(q):

return 0.5 * np.sum(np.square(x_hat - true_x))

dq_fd = approx_fprime(q, f)
dq_fd = approx_fprime(q, f, epsilon=EPS)
if verbose:
print('dq_fd: ', np.round(dq_fd, decimals=4))
print('dq: ', np.round(dq, decimals=4))
npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)

def test_dl_dP(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_dP(algebra, solver_type):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand All @@ -129,13 +129,13 @@ def f(P):

return 0.5 * np.sum(np.square(x_hat - true_x))

dP_fd = approx_fprime(P, f)
dP_fd = approx_fprime(P, f, epsilon=EPS)
if verbose:
print('dP_fd: ', np.round(dP_fd, decimals=4))
print('dP: ', np.round(dP, decimals=4))
npt.assert_allclose(dP_fd, dP, rtol=RTOL, atol=ATOL)

def test_dl_dA(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_dA(algebra, solver_type):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand All @@ -158,13 +158,13 @@ def f(A):

return 0.5 * np.sum(np.square(x_hat - true_x))

dA_fd = approx_fprime(A, f)
dA_fd = approx_fprime(A, f, epsilon=EPS)
if verbose:
print('dA_fd: ', np.round(dA_fd, decimals=4))
print('dA: ', np.round(dA, decimals=4))
npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL)

def test_dl_dl(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_dl(algebra, solver_type):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand All @@ -187,13 +187,13 @@ def f(l):

return 0.5 * np.sum(np.square(x_hat - true_x))

dl_fd = approx_fprime(l, f)
dl_fd = approx_fprime(l, f, epsilon=EPS)
if verbose:
print('dl_fd: ', np.round(dl_fd, decimals=4))
print('dl: ', np.round(dl, decimals=4))
npt.assert_allclose(dl_fd, dl, rtol=RTOL, atol=ATOL)

def test_dl_du(algebra, solver_type, atol, rtol, decimal_tol):
def test_dl_du(algebra, solver_type):
n, m = 5, 5

model = osqp.OSQP(algebra=algebra)
Expand All @@ -216,7 +216,7 @@ def f(u):

return 0.5 * np.sum(np.square(x_hat - true_x))

du_fd = approx_fprime(u, f)
du_fd = approx_fprime(u, f, epsilon=EPS)
if verbose:
print('du_fd: ', np.round(du_fd, decimals=4))
print('du: ', np.round(du, decimals=4))
Expand Down

0 comments on commit b16b1a3

Please sign in to comment.