Skip to content

Commit 302b8d8

Browse files
author
John Tencer
committed
add test for vector parameter and fix bug
1 parent 9606c26 commit 302b8d8

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

romtools/workflows/parameter_spaces.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def generate_samples(self, uniform_dist_samples: np.array) -> np.array:
118118
for param in self.get_parameter_list():
119119
param_samples = param.generate_samples(uniform_dist_samples[:, param_idx:param_idx+param.get_dimensionality()])
120120
samples.append(param_samples)
121+
param_idx += param.get_dimensionality()
121122
return np.concatenate(samples, axis=1)
122123

123124

@@ -227,7 +228,8 @@ def get_parameter_list(self) -> Iterable[Parameter]:
227228

228229
class HeterogeneousParameterSpace(ParameterSpace):
229230
'''
230-
Heterogeneous parameter space consisting of a list of arbitrary Parameter objects
231+
Heterogeneous parameter space consisting of a list of arbitrary Parameter
232+
objects
231233
'''
232234
def __init__(self, parameter_objs: Iterable[Parameter]):
233235
self.parameters = parameter_objs

tests/romtools/workflows/test_parameter_spaces.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ def test_uniform_parameter():
2323
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
2424

2525

26+
def test_vector_parameter():
27+
param = UniformParameter('p1', [-1, 0], [1, 3])
28+
assert param.get_name() == 'p1'
29+
assert param.get_dimensionality() == 2
30+
31+
germ = np.array([[0.1, 0.2], [0.5, 0.6], [0.7, 0.5]])
32+
s = param.generate_samples(germ)
33+
assert s.shape == (3, 2)
34+
gold = [[-0.8, 0.6],
35+
[ 0.0, 1.8],
36+
[ 0.4, 1.5]]
37+
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
38+
39+
2640
def test_string_parameter():
2741
param = StringParameter('p1', 'p1val')
2842
assert param.get_name() == 'p1'
@@ -39,12 +53,12 @@ def test_uniform_param_space():
3953
assert param_space.get_names() == ['p1', 'p2']
4054
assert param_space.get_dimensionality() == 2
4155

42-
germ = np.array([[0.1, 0.2], [0.5, 0.6], [0.7, 0.8]])
56+
germ = np.array([[0.1, 0.2], [0.5, 0.6], [0.7, 0.5]])
4357
s = param_space.generate_samples(germ)
4458
assert s.shape == (3, 2)
45-
gold = [[-0.8, 0.3],
46-
[ 0.0, 1.5],
47-
[ 0.4, 2.1]]
59+
gold = [[-0.8, 0.6],
60+
[ 0.0, 1.8],
61+
[ 0.4, 1.5]]
4862
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
4963

5064

@@ -81,7 +95,7 @@ def test_hetero_param_space():
8195
assert s.shape == (4, 3)
8296
np.testing.assert_allclose(s[:, 0].astype(float), [-0.8, -0.2, 0.4, -1.0],
8397
rtol=1e-5, atol=1e-8)
84-
np.testing.assert_allclose(s[:, 1].astype(float), [0.1, 0.4, 0.7, 0.0],
98+
np.testing.assert_allclose(s[:, 1].astype(float), [0.2, 0.5, 0.8, 1.0],
8599
rtol=1e-5, atol=1e-8)
86100
assert (s[:, 2] == ['p3val', 'p3val', 'p3val', 'p3val']).all()
87101

@@ -91,10 +105,10 @@ def test_monte_carlo_sample():
91105
s = monte_carlo_sample(param_space, 4, seed=12)
92106
assert s.shape == (4, 2)
93107

94-
gold = [[-0.69167432, 0.46248853],
95-
[-0.47336997, 0.78994505],
96-
[-0.97085008, 0.04372489],
97-
[ 0.80142971, 2.70214456]]
108+
gold = [[-0.69167432, 2.22014909],
109+
[-0.47336997, 1.60121818],
110+
[-0.97085008, 2.75624102],
111+
[ 0.80142971, 0.10026428]]
98112
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
99113

100114

@@ -104,8 +118,8 @@ def test_latin_hypercube_sample():
104118
s = latin_hypercube_sample(param_space, 4, seed=12)
105119
assert s.shape == (4, 2)
106120

107-
gold = [[-0.12541223, 1.31188166],
108-
[ 0.40533981, 2.10800971],
109-
[ 0.82505538, 2.73758307],
110-
[-0.83522287, 0.24716569]]
121+
gold = [[-0.12541223, 0.78993529],
122+
[ 0.40533981, 2.86553144],
123+
[ 0.82505538, 2.07709407],
124+
[-0.83522287, 0.66369046]]
111125
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)

0 commit comments

Comments
 (0)