Skip to content

Commit 8960e89

Browse files
committed
small fixes
1 parent b0c9757 commit 8960e89

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

romtools/rom/projections.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,11 @@ def optimal_l2_projection(input_tensor : np.ndarray , vector_space : romtools.Ve
8989
if return_full_state == False:
9090
return reduced_state
9191

92+
if len(input_tensor.shape) == 2:
93+
full_state = np.einsum('nik,k...->ni...',basis,reduced_state) + shift_vector
9294
else:
93-
if len(input_tensor.shape) == 2:
94-
full_state = np.einsum('nik,k...->ni...',basis,reduced_state) + shift_vector
95-
elif len(input_tensor.shape) == 3:
96-
full_state = np.einsum('nik,k...->ni...',basis,reduced_state) + shift_vector[...,None]
97-
return reduced_state,full_state
95+
full_state = np.einsum('nik,k...->ni...',basis,reduced_state) + shift_vector[...,None]
96+
return reduced_state,full_state
9897

9998

10099

0 commit comments

Comments
 (0)