Skip to content

Commit 9e0f36e

Browse files
committed
fix
1 parent 2ffd69f commit 9e0f36e

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

romtools/trial_space/__init__.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@
9090

9191
import abc
9292
import numpy as np
93-
from romtools.trial_space.utils import *
93+
from romtools.trial_space.utils import tensor_to_matrix, matrix_to_tensor
94+
from romtools.trial_space.utils.truncater import *
95+
from romtools.trial_space.utils.shifter import *
96+
from romtools.trial_space.utils.scaler import *
97+
from romtools.trial_space.utils.splitter import *
98+
from romtools.trial_space.utils.orthogonalizer import *
99+
94100

95101
class TrialSpace(abc.ABC):
96102
'''
@@ -242,7 +248,7 @@ def __init__(self,
242248

243249
n_var = snapshots.shape[0]
244250
shifted_snapshots, self.__shift_vector = shifter(snapshots)
245-
snapshot_matrix = utils.tensor_to_matrix(shifted_snapshots)
251+
snapshot_matrix = tensor_to_matrix(shifted_snapshots)
246252
shifted_split_snapshots = splitter(snapshot_matrix)
247253

248254
svd_picked = np.linalg.svd if svdFnc is None else svdFnc
@@ -318,14 +324,14 @@ def __init__(self, snapshots,
318324
n_var = snapshots.shape[0]
319325
shifted_snapshots, self.__shift_vector = shifter(snapshots)
320326
scaled_shifted_snapshots = scaler.pre_scaling(shifted_snapshots)
321-
snapshot_matrix = utils.tensor_to_matrix(scaled_shifted_snapshots)
327+
snapshot_matrix = tensor_to_matrix(scaled_shifted_snapshots)
322328
snapshot_matrix = splitter(snapshot_matrix)
323329

324330
lsv, svals, _ = np.linalg.svd(snapshot_matrix, full_matrices=False)
325331
self.__basis = truncater(lsv, svals)
326332
self.__basis = matrix_to_tensor(n_var, self.__basis)
327333
self.__basis = scaler.post_scaling(self.__basis)
328-
self.__basis = utils.tensor_to_matrix(self.__basis)
334+
self.__basis = tensor_to_matrix(self.__basis)
329335
self.__basis = orthogonalizer(self.__basis)
330336
self.__basis = matrix_to_tensor(n_var, self.__basis)
331337
self.__dimension = self.__basis.shape[2]

0 commit comments

Comments
 (0)