diff --git a/src/osqp/tests/nn_test.py b/src/osqp/tests/nn_test.py index f9a77b3a..6f3f8241 100644 --- a/src/osqp/tests/nn_test.py +++ b/src/osqp/tests/nn_test.py @@ -133,7 +133,7 @@ def f(P): return 0.5 * np.sum(np.square(x_hat - true_x)) - dP_fd = approx_fprime(P.tolil().reshape((1,n*n)), f, epsilon=EPS) + dP_fd = approx_fprime(P.toarray().flatten(), f, epsilon=EPS) if verbose: print('dP_fd: ', np.round(dP_fd, decimals=4)) print('dP: ', np.round(dP, decimals=4)) @@ -164,7 +164,7 @@ def f(A): return 0.5 * np.sum(np.square(x_hat - true_x)) - dA_fd = approx_fprime(A.tolil().reshape((1,m*n)), f, epsilon=EPS) + dA_fd = approx_fprime(A.toarray().flatten(), f, epsilon=EPS) if verbose: print('dA_fd: ', np.round(dA_fd, decimals=4)) print('dA: ', np.round(dA, decimals=4))