Skip to content

Commit

Permalink
Fix seeding issues with correlated Thompson sampling (#29)
Browse files Browse the repository at this point in the history
* Add regression test for #28
* Fix seeding problem in correlated Thompson sampling

Closes #28
  • Loading branch information
michaelosthege authored May 26, 2022
1 parent 78acd67 commit 69c36ce
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyrff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from . thompson import sample_batch, sampling_probabilities
from . utils import multi_start_fmin

__version__ = '2.0.1'
__version__ = '2.0.2'
24 changes: 24 additions & 0 deletions pyrff/test_thompson.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,30 @@ def test_correlated_sampling(self):
assert batch.count('C') == 0
pass

@pytest.mark.parametrize("correlated", [False, True])
def test_seeding(self, correlated):
"""This is a regression test for https://github.com/michaelosthege/pyrff/issues/28"""
rng = numpy.random.RandomState(123)
samples = [
rng.uniform(size=200),
rng.uniform(size=200),
rng.uniform(size=200),
]
batches = []
for _ in range(10):
batch = thompson.sample_batch(
candidate_samples=samples,
ids=["A", "B", "C"],
correlated=correlated,
batch_size=10,
seed=123,
)
batches.append("".join(batch))

# Assert that all batches are identical.
assert len(set(batches)) == 1
pass


class TestExceptions:
def test_id_count(self):
Expand Down
2 changes: 1 addition & 1 deletion pyrff/thompson.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def sample_batch(
# to prevent always selecting lower-numbered candidates when >=2 samples are equal
col_order = random.permutation(n_candidates)
if correlated:
idx = numpy.repeat(numpy.random.randint(n_samples[0]), n_candidates)
idx = numpy.repeat(random.randint(n_samples[0]), n_candidates)
else:
idx = random.randint(n_samples, size=n_candidates)
selected_samples = samples[:, col_order][idx, numpy.arange(n_candidates)]
Expand Down

0 comments on commit 69c36ce

Please sign in to comment.