Skip to content

Commit 54d9988

Browse files
authored
Merge pull request #108 from Pressio/const_param_space
add ConstParamSpace class and testing
2 parents 1214c40 + 9b72ff7 commit 54d9988

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

romtools/workflows/parameter_spaces.py

+23
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,26 @@ def generate_samples(self, number_of_samples):
9696
samples = np.random.uniform(self.__lower_bounds, self.__upper_bounds,
9797
size=(number_of_samples, self.__n_params))
9898
return samples
99+
100+
101+
class ConstParamSpace(ParameterSpace):
102+
'''
103+
Constant parameter space which converts all constant values to str-type
104+
105+
Useful if you need to execute workflows in a non-stochastic setting
106+
'''
107+
def __init__(self, parameter_names, parameter_values):
108+
self._parameter_names = parameter_names
109+
self._n_params = len(parameter_names)
110+
self._parameter_values = np.array(parameter_values, dtype=str)
111+
self._parameter_values = self._parameter_values.reshape(1,
112+
self._n_params)
113+
114+
def get_names(self):
115+
return self._parameter_names
116+
117+
def get_dimensionality(self):
118+
return self._n_params
119+
120+
def generate_samples(self, number_of_samples):
121+
return np.repeat(self._parameter_values, number_of_samples, axis=0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
from romtools.workflows.parameter_spaces import UniformParameterSpace
3+
from romtools.workflows.parameter_spaces import ConstParamSpace
4+
5+
6+
def test_uniform_param_space():
7+
np.random.seed(12)
8+
param_space = UniformParameterSpace(['p1', 'p2'], [-1, 0], [1, 3])
9+
assert param_space.get_names() == ['p1', 'p2']
10+
assert param_space.get_dimensionality() == 2
11+
s = param_space.generate_samples(4)
12+
assert s.shape == (4, 2)
13+
gold = [[-0.69167432, 2.22014909],
14+
[-0.47336997, 1.60121818],
15+
[-0.97085008, 2.75624102],
16+
[0.80142971, 0.10026428]]
17+
np.testing.assert_allclose(s, gold, rtol=1e-5, atol=1e-8)
18+
19+
20+
def test_const_param_space():
21+
param_space = ConstParamSpace(['p1', 'p2', 'p3'], [1, 3, 'p3val'])
22+
assert param_space.get_names() == ['p1', 'p2', 'p3']
23+
assert param_space.get_dimensionality() == 3
24+
s = param_space.generate_samples(4)
25+
assert s.shape == (4, 3)
26+
assert (s == [['1', '3', 'p3val'],
27+
['1', '3', 'p3val'],
28+
['1', '3', 'p3val'],
29+
['1', '3', 'p3val']]).all()

0 commit comments

Comments
 (0)