Skip to content

Commit 36707b9

Browse files
authored
Merge pull request #143 from Pressio/139-extents-in-vector-space
addressed extents
2 parents bebff38 + a22cffc commit 36707b9

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

romtools/vector_space/__init__.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ class VectorSpace(abc.ABC):
104104
'''
105105

106106
@abc.abstractmethod
107-
def get_dimension(self) -> int:
107+
def extents(self) -> np.ndarray:
108108
'''
109109
Retrieves the dimension of the vector space
110110
111111
Returns:
112-
`int`: The dimension of the vector space.
112+
`np.ndarray`: The dimension of the vector space (n_var,nx,K).
113113
114114
Concrete subclasses should implement this method to return the
115115
appropriate dimension for their specific vector space implementation.
@@ -183,14 +183,14 @@ def __init__(self,
183183
self.__basis = splitter.split(snapshot_matrix)
184184
self.__basis = orthogonalizer.orthogonalize(self.__basis)
185185
self.__basis = _matrix_to_tensor(n_var, self.__basis)
186-
self.__dimension = self.__basis.shape[2]
186+
self.__extents = self.__basis.shape
187187
self.__shift_vector = shifter.get_shift_vector()
188188

189-
def get_dimension(self) -> int:
189+
def extents(self) -> np.ndarray:
190190
'''
191-
Concrete implementation of `VectorSpace.get_dimension()`
191+
Concrete implementation of `VectorSpace.extents()`
192192
'''
193-
return self.__dimension
193+
return self.__extents
194194

195195
def get_shift_vector(self) -> np.ndarray:
196196
'''
@@ -271,14 +271,14 @@ def __init__(self,
271271
self.__basis = _tensor_to_matrix(self.__basis)
272272
self.__basis = orthogonalizer.orthogonalize(self.__basis)
273273
self.__basis = _matrix_to_tensor(n_var, self.__basis)
274-
self.__dimension = self.__basis.shape[2]
274+
self.__extents = self.__basis.shape
275275
self.__shift_vector = shifter.get_shift_vector()
276276

277-
def get_dimension(self) -> int:
277+
def extents(self) -> np.ndarray:
278278
'''
279-
Concrete implementation of `VectorSpace.get_dimension()`
279+
Concrete implementation of `VectorSpace.extents()`
280280
'''
281-
return self.__dimension
281+
return self.__extents
282282

283283
def get_shift_vector(self) -> np.ndarray:
284284
'''

tests/romtools/test_vector_space.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_dictionary_vector_space():
3131
assert np.allclose(my_vector_space.get_basis().flatten(),
3232
snapshots.flatten())
3333
assert np.allclose(my_vector_space.get_shift_vector(), 0)
34-
assert np.allclose(my_vector_space.get_dimension(), 6)
34+
assert np.allclose(my_vector_space.extents()[-1], 6)
3535

3636
# test with a shift
3737
my_shifter = utils.create_average_shifter(snapshots)
@@ -43,7 +43,7 @@ def test_dictionary_vector_space():
4343
(original_snapshots - np.mean(original_snapshots, axis=2)[:, :, None]).flatten())
4444
assert np.allclose(my_vector_space.get_shift_vector(),
4545
np.mean(original_snapshots, axis=2))
46-
assert np.allclose(my_vector_space.get_dimension(), 6)
46+
assert np.allclose(my_vector_space.extents()[-1], 6)
4747

4848
# test with a shift and splitting
4949
my_shifter = utils.create_average_shifter(snapshots)
@@ -54,7 +54,7 @@ def test_dictionary_vector_space():
5454
my_splitter)
5555
assert np.allclose(my_vector_space.get_shift_vector(),
5656
np.mean(snapshots, axis=2))
57-
assert np.allclose(my_vector_space.get_dimension(), 12)
57+
assert np.allclose(my_vector_space.extents()[-1], 12)
5858

5959
# test with a shift, splitting, and orthogonalization
6060
my_shifter = utils.create_average_shifter(snapshots)
@@ -66,7 +66,7 @@ def test_dictionary_vector_space():
6666
my_orthogonalizer)
6767
assert np.allclose(my_vector_space.get_shift_vector(),
6868
np.mean(snapshots, axis=2))
69-
assert np.allclose(my_vector_space.get_dimension(), 12)
69+
assert np.allclose(my_vector_space.extents()[-1], 12)
7070
basis = my_vector_space.get_basis()
7171
basis = _tensor_to_matrix(basis)
7272
assert np.allclose(basis.transpose() @ basis, np.eye(12))
@@ -82,7 +82,7 @@ def test_vector_space_from_pod():
8282
u, s, v = np.linalg.svd(snapshotMatrix, full_matrices=False)
8383
basis_tensor = my_vector_space.get_basis()
8484
assert np.allclose(u.reshape(basis_tensor.shape), basis_tensor)
85-
assert np.allclose(6, my_vector_space.get_dimension())
85+
assert np.allclose(6, my_vector_space.extents()[-1])
8686
assert np.allclose(0, my_vector_space.get_shift_vector())
8787

8888
# test with a shift
@@ -93,7 +93,7 @@ def test_vector_space_from_pod():
9393
assert np.allclose(u.reshape(basis_tensor.shape), basis_tensor) # FAILS
9494
assert np.allclose(my_vector_space.get_shift_vector(),
9595
np.mean(original_snapshots, axis=2))
96-
assert np.allclose(my_vector_space.get_dimension(), 6)
96+
assert np.allclose(my_vector_space.extents()[-1], 6)
9797

9898
# test with a shift and splitting
9999
snapshots = np.random.normal(size=(3, 8, 6))
@@ -107,7 +107,7 @@ def test_vector_space_from_pod():
107107
assert np.allclose(u.reshape(basis_tensor.shape), basis_tensor)
108108
assert np.allclose(my_vector_space.get_shift_vector(),
109109
np.mean(original_snapshots, axis=2))
110-
assert np.allclose(my_vector_space.get_dimension(), 12)
110+
assert np.allclose(my_vector_space.extents()[-1], 12)
111111

112112
# test with a shift, splitting, and orthogonalization
113113
snapshots = np.random.normal(size=(3, 8, 6))
@@ -124,7 +124,7 @@ def test_vector_space_from_pod():
124124
assert np.allclose(u.reshape(basis_tensor.shape), basis_tensor)
125125
assert np.allclose(my_vector_space.get_shift_vector(),
126126
np.mean(original_snapshots, axis=2))
127-
assert np.allclose(my_vector_space.get_dimension(), 12)
127+
assert np.allclose(my_vector_space.extents()[-1], 12)
128128

129129

130130
@pytest.mark.mpi_skip
@@ -139,7 +139,7 @@ def test_trial_space_from_scaled_pod():
139139
u = u.reshape(basis_tensor.shape)
140140
u = my_scaler.post_scale(u)
141141
assert np.allclose(u, basis_tensor), print(u, my_vector_space.get_basis())
142-
assert np.allclose(6, my_vector_space.get_dimension())
142+
assert np.allclose(6, my_vector_space.extents()[-1])
143143
assert np.allclose(0, my_vector_space.get_shift_vector())
144144

145145
# test with a shift
@@ -159,7 +159,7 @@ def test_trial_space_from_scaled_pod():
159159
assert np.allclose(basis_tensor, u) # FAILS
160160
assert np.allclose(my_vector_space.get_shift_vector(),
161161
np.mean(original_snapshots, axis=2))
162-
assert np.allclose(my_vector_space.get_dimension(), 6)
162+
assert np.allclose(my_vector_space.extents()[-1], 6)
163163

164164
# test with a shift and splitting
165165
snapshots = np.random.normal(size=(3, 8, 6))
@@ -180,7 +180,7 @@ def test_trial_space_from_scaled_pod():
180180
assert np.allclose(basis_tensor, u)
181181
assert np.allclose(my_vector_space.get_shift_vector(),
182182
np.mean(original_snapshots, axis=2))
183-
assert np.allclose(my_vector_space.get_dimension(), 12)
183+
assert np.allclose(my_vector_space.extents()[-1], 12)
184184

185185
# test with a shift, splitting, and orthogonalization
186186
snapshots = np.random.normal(size=(3, 8, 6))
@@ -207,7 +207,7 @@ def test_trial_space_from_scaled_pod():
207207
assert np.allclose(basis_tensor, u)
208208
assert np.allclose(my_vector_space.get_shift_vector(),
209209
np.mean(original_snapshots, axis=2))
210-
assert np.allclose(my_vector_space.get_dimension(), 12)
210+
assert np.allclose(my_vector_space.extents()[-1], 12)
211211

212212

213213
if __name__ == "__main__":

tests/romtools/test_vector_space_svd_mpi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_vector_space_from_pod_mpi():
4242
snaps = construct_snapshots(comm)
4343
myVectorSpace = rt.VectorSpaceFromPOD(snaps, svdFnc=MyFakeSvd(comm))
4444
U = myVectorSpace.get_basis()
45-
k = myVectorSpace.get_dimension()
45+
k = myVectorSpace.extents()[-1]
4646
if rank == 0:
4747
assert np.allclose(U, np.zeros((3, 3, 2)))
4848
assert np.allclose(2, k)

0 commit comments

Comments
 (0)