diff --git a/bqskit/ir/gates/parameterized/diagonal.py b/bqskit/ir/gates/parameterized/diagonal.py index 63bb40b9..8518edc9 100644 --- a/bqskit/ir/gates/parameterized/diagonal.py +++ b/bqskit/ir/gates/parameterized/diagonal.py @@ -51,16 +51,17 @@ def get_grad(self, params: RealVector = []) -> npt.NDArray[np.complex128]: """ self.check_parameters(params) - mat = np.eye(2 ** self.num_qudits, dtype=np.complex128) + grad = np.zeros( + ( + len(params), 2 ** self.num_qudits, + 2 ** self.num_qudits, + ), dtype=np.complex128, + ) - for i in range(1, 2 ** self.num_qudits): - mat[i][i] = 1j * np.exp(1j * params[i - 1]) + for i, ind in enumerate(range(1, 2 ** self.num_qudits)): + grad[ind][i][i] = 1j * np.exp(1j * params[ind]) - return np.array( - [ - mat, - ], dtype=np.complex128, - ) + return grad def optimize(self, env_matrix: npt.NDArray[np.complex128]) -> list[float]: """