Skip to content

Commit

Permalink
Linter
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Jul 9, 2024
1 parent b16b1a3 commit 907781e
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/osqp/tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

ATOL = 1e-2
RTOL = 1e-4
EPS = 1e-5
EPS = 1e-5

cuda = False
verbose = True
Expand Down Expand Up @@ -45,6 +45,7 @@ def get_grads(
grads = get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type)
return [P, q, A, l, u, true_x], grads


def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type):
P_idx = P.nonzero()
P_shape = P.shape
Expand Down Expand Up @@ -77,6 +78,7 @@ def get_grads_torch(P, q, A, l, u, true_x, algebra, solver_type):
grads = [x.grad.data.squeeze(0).cpu().numpy() for x in [P_torch, q_torch, A_torch, l_torch, u_torch]]
return grads


def test_dl_dq(algebra, solver_type):
n, m = 5, 5

Expand Down Expand Up @@ -106,6 +108,7 @@ def f(q):
print('dq: ', np.round(dq, decimals=4))
npt.assert_allclose(dq_fd, dq, rtol=RTOL, atol=ATOL)


def test_dl_dP(algebra, solver_type):
n, m = 5, 5

Expand Down Expand Up @@ -135,6 +138,7 @@ def f(P):
print('dP: ', np.round(dP, decimals=4))
npt.assert_allclose(dP_fd, dP, rtol=RTOL, atol=ATOL)


def test_dl_dA(algebra, solver_type):
n, m = 5, 5

Expand Down Expand Up @@ -164,6 +168,7 @@ def f(A):
print('dA: ', np.round(dA, decimals=4))
npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL)


def test_dl_dl(algebra, solver_type):
n, m = 5, 5

Expand Down Expand Up @@ -193,6 +198,7 @@ def f(l):
print('dl: ', np.round(dl, decimals=4))
npt.assert_allclose(dl_fd, dl, rtol=RTOL, atol=ATOL)


def test_dl_du(algebra, solver_type):
n, m = 5, 5

Expand Down

0 comments on commit 907781e

Please sign in to comment.