Skip to content

Commit e2fdfda

Browse files
author
John Tencer
committed
get all the tests working again
1 parent ebfdc83 commit e2fdfda

File tree

4 files changed

+45
-26
lines changed

4 files changed

+45
-26
lines changed

romtools/workflows/greedy/run_greedy.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888

8989
from romtools.workflows.greedy.\
9090
greedy_coupler_base import GreedyCouplerBase
91+
from romtools.workflows.parameter_spaces import MonteCarloSample
9192

9293

9394
def run_fom_sample(coupler: GreedyCouplerBase,
@@ -128,7 +129,7 @@ def run_greedy(greedy_coupler: GreedyCouplerBase,
128129

129130
# create parameter domain
130131
parameter_space = greedy_coupler.get_parameter_space()
131-
parameter_samples = parameter_space.generate_samples(testing_sample_size)
132+
parameter_samples = MonteCarloSample(parameter_space, testing_sample_size)
132133

133134
# Make FOM/ROM directories
134135
greedy_coupler.create_fom_and_rom_cases(starting_sample_index,
@@ -239,7 +240,7 @@ def run_greedy(greedy_coupler: GreedyCouplerBase,
239240
basis_time += time.time() - t0
240241

241242
# Add a new sample
242-
new_parameter_sample = parameter_space.generate_samples(1)
243+
new_parameter_sample = MonteCarloSample(parameter_space, 1)
243244
parameter_samples = np.append(parameter_samples,
244245
new_parameter_sample, axis=0)
245246
new_sample_number = testing_sample_size + outer_loop_counter - 1

romtools/workflows/parameter_spaces.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def generate_samples(self, uniform_dist_samples: np.array) -> np.array:
118118
for param in self.get_parameter_list():
119119
param_samples = param.generate_samples(uniform_dist_samples[:, param_idx:param_idx+param.get_dimensionality()])
120120
samples.append(param_samples)
121-
return samples
121+
print(param_samples.shape)
122+
return np.concatenate(samples, axis=1)
122123

123124

124125
def MonteCarloSample(param_space: ParameterSpace, number_of_samples: int):
@@ -158,6 +159,9 @@ def get_dimensionality(self) -> int:
158159

159160
def generate_samples(self, uniform_dist_samples: np.array) -> np.array:
160161
assert uniform_dist_samples.shape[1] == self.get_dimensionality()
162+
print(self._lower_bound)
163+
print(self._upper_bound)
164+
print(uniform_dist_samples)
161165
return qmc.scale(uniform_dist_samples,
162166
self._lower_bound,
163167
self._upper_bound)

romtools/workflows/sampling/sampling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
from romtools.workflows.sampling.\
5151
sampling_coupler_base import SamplingCouplerBase
52+
from romtools.workflows.parameter_spaces import MonteCarloSample
5253

5354

5455
def run_sampling(sampling_coupler: SamplingCouplerBase,
@@ -61,7 +62,7 @@ def run_sampling(sampling_coupler: SamplingCouplerBase,
6162

6263
# create parameter domain
6364
parameter_space = sampling_coupler.get_parameter_space()
64-
parameter_samples = parameter_space.generate_samples(testing_sample_size)
65+
parameter_samples = MonteCarloSample(parameter_space, testing_sample_size)
6566

6667
# Make FOM/ROM directories
6768
starting_sample_index = 0

tests/romtools/workflows/test_parameter_spaces.py

+35-22
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,27 @@
77

88

99
def test_uniform_parameter():
10-
np.random.seed(12)
1110
param = UniformParameter('p1', -1, 1)
1211
assert param.get_name() == 'p1'
1312
assert param.get_dimensionality() == 1
14-
s = param.generate_samples(3)
15-
assert s.shape == (3, 1)
16-
gold = [[-0.69167432],
17-
[ 0.48009939],
18-
[-0.47336997]]
13+
14+
germ = np.array([[0.1], [0.5], [0.7]])
15+
s = param.generate_samples(germ)
16+
assert s.shape == germ.shape
17+
gold = [[-0.8],
18+
[ 0.0],
19+
[ 0.4]]
1920
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
2021

2122

2223
def test_string_parameter():
2324
param = StringParameter('p1', 'p1val')
2425
assert param.get_name() == 'p1'
2526
assert param.get_dimensionality() == 1
26-
s = param.generate_samples(3)
27-
assert s.shape == (3, 1)
27+
28+
germ = np.array([[0.1], [0.5], [0.7]])
29+
s = param.generate_samples(germ)
30+
assert s.shape == germ.shape
2831
assert (s == [['p1val', 'p1val', 'p1val']]).all()
2932

3033

@@ -33,20 +36,25 @@ def test_uniform_param_space():
3336
param_space = UniformParameterSpace(['p1', 'p2'], [-1, 0], [1, 3])
3437
assert param_space.get_names() == ['p1', 'p2']
3538
assert param_space.get_dimensionality() == 2
36-
s = param_space.generate_samples(4)
37-
assert s.shape == (4, 2)
38-
gold = [[-0.69167432, 0.04372489],
39-
[ 0.48009939, 2.75624102],
40-
[-0.47336997, 2.70214456],
41-
[ 0.06747879, 0.10026428]]
39+
40+
germ = np.array([[0.1, 0.2], [0.5, 0.6], [0.7, 0.8]])
41+
s = param_space.generate_samples(germ)
42+
assert s.shape == (3, 2)
43+
gold = [[-0.8, 0.3],
44+
[ 0.0, 1.5],
45+
[ 0.4, 2.1]]
4246
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
4347

4448

4549
def test_const_param_space():
4650
param_space = ConstParameterSpace(['p1', 'p2', 'p3'], [1, 3, 'p3val'])
4751
assert param_space.get_names() == ['p1', 'p2', 'p3']
4852
assert param_space.get_dimensionality() == 3
49-
s = param_space.generate_samples(4)
53+
germ = np.array([[0.1, 0.2, 0.3],
54+
[0.4, 0.5, 0.6],
55+
[0.7, 0.8, 0.9],
56+
[0.0, 1.0, 0.5]])
57+
s = param_space.generate_samples(germ)
5058
assert s.shape == (4, 3)
5159
assert (s == [['1', '3', 'p3val'],
5260
['1', '3', 'p3val'],
@@ -57,16 +65,21 @@ def test_const_param_space():
5765
def test_hetero_param_space():
5866
np.random.seed(12)
5967
param1 = UniformParameter('p1', -1, 1)
60-
param2 = UniformParameter('p2', 0, 0)
68+
param2 = UniformParameter('p2', 0, 1)
6169
param3 = StringParameter('p3', 'p3val')
6270
param_space = HeterogeneousParameterSpace((param1, param2, param3))
6371

6472
assert param_space.get_names() == ['p1', 'p2', 'p3']
6573
assert param_space.get_dimensionality() == 3
66-
s = param_space.generate_samples(4)
74+
75+
germ = np.array([[0.1, 0.2, 0.3],
76+
[0.4, 0.5, 0.6],
77+
[0.7, 0.8, 0.9],
78+
[0.0, 1.0, 0.5]])
79+
s = param_space.generate_samples(germ)
6780
assert s.shape == (4, 3)
68-
print(s)
69-
assert (s == [['-0.6916743152406553', '0.0', 'p3val'],
70-
['0.4800993930308095', '0.0', 'p3val'],
71-
['-0.47336996962973066', '0.0', 'p3val'],
72-
['0.06747878676059549', '0.0', 'p3val']]).all()
81+
np.testing.assert_allclose(s[:, 0].astype(float), [-0.8, -0.2, 0.4, -1.0],
82+
rtol=1e-5, atol=1e-8)
83+
np.testing.assert_allclose(s[:, 1].astype(float), [0.1, 0.4, 0.7, 0.0],
84+
rtol=1e-5, atol=1e-8)
85+
assert (s[:, 2] == ['p3val', 'p3val', 'p3val', 'p3val']).all()

0 commit comments

Comments
 (0)