Skip to content

Commit

Permalink
Variation broadcaster that allows a variation to be broadcasted to mu…
Browse files Browse the repository at this point in the history
…ltiple callers.

PiperOrigin-RevId: 683243432
Change-Id: I6ed10e3cb14a39fa5d82886d6dab9faa0ce634d9
  • Loading branch information
Leonard Hasenclever authored and copybara-github committed Oct 7, 2024
1 parent 2456cfa commit a2ad4e3
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
65 changes: 65 additions & 0 deletions dm_control/composer/variation/variation_broadcaster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2024 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A broadcaster that allows sharing of variation values across many callers."""

import collections
import weakref

from dm_control.composer import variation


class VariationBroadcaster:
"""Allows a variation to be broadcasted to multiple callers.
This class wraps a `Variation` object and generates multiple proxies that
can be used in place of the wrapped `Variation`. The broadcaster updates its
value in rounds. At the beginning of each round, the broadcaster re-evaluates
the wrapped `Variation` and caches the new value internally. When a proxy
is called, the broadcaster will return this cached value, thus ensuring that
all proxied values are the same. The round ends when all of the proxies have
been called exactly once. It is an error to call any particular proxy more
than once per round.
"""

def __init__(self, wrapped_variation: variation.Variation):
self._wrapped_variation = wrapped_variation
self._cached_values = weakref.WeakKeyDictionary()

def get_proxy(self) -> variation.Variation:
"""Returns a `Variation` to be used in place of the wrapped `Variation`."""
new_proxy = _BroadcastedValueProxy(self)
self._cached_values[new_proxy] = collections.deque()
return new_proxy

def _get_value(self, proxy, random_state):
"""Returns the variation value for a proxy owned by this broadcaster."""
cached_values = self._cached_values[proxy]
if not cached_values:
new_value = variation.evaluate(
self._wrapped_variation, None, None, random_state)
for values in self._cached_values.values():
values.append(new_value)
return cached_values.popleft()


class _BroadcastedValueProxy(variation.Variation):

def __init__(self, broadcaster):
self._broadcaster = broadcaster

def __call__(self, initial_value=None, current_value=None, random_state=None):
value = self._broadcaster._get_value(self, random_state) # pylint: disable=protected-access
return value
104 changes: 104 additions & 0 deletions dm_control/composer/variation/variation_broadcaster_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2024 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from absl.testing import absltest
from dm_control.composer import variation
from dm_control.composer.variation import distributions
from dm_control.composer.variation import variation_broadcaster
import numpy as np


class VariationBroadcasterTest(absltest.TestCase):

def test_can_generate_values(self):
random_state = np.random.RandomState(2348)
expected_values = [random_state.uniform(0, 1) for _ in range(5)]

random_state = np.random.RandomState(2348)
broadcaster = variation_broadcaster.VariationBroadcaster(
distributions.Uniform(0, 1)
)
proxy_1 = broadcaster.get_proxy()
proxy_2 = broadcaster.get_proxy()
proxy_3 = broadcaster.get_proxy()

self.assertEqual(
variation.evaluate(proxy_1, random_state=random_state),
expected_values[0],
)
self.assertEqual(
variation.evaluate(proxy_2, random_state=random_state),
expected_values[0],
)
self.assertEqual(
variation.evaluate(proxy_3, random_state=random_state),
expected_values[0],
)

self.assertEqual(
variation.evaluate(proxy_1, random_state=random_state),
expected_values[1],
)
self.assertEqual(
variation.evaluate(proxy_1, random_state=random_state),
expected_values[2],
)

self.assertEqual(
variation.evaluate(proxy_2, random_state=random_state),
expected_values[1],
)
self.assertEqual(
variation.evaluate(proxy_3, random_state=random_state),
expected_values[1],
)
self.assertEqual(
variation.evaluate(proxy_3, random_state=random_state),
expected_values[2],
)

self.assertEqual(
variation.evaluate(proxy_3, random_state=random_state),
expected_values[3],
)
self.assertEqual(
variation.evaluate(proxy_1, random_state=random_state),
expected_values[3],
)
self.assertEqual(
variation.evaluate(proxy_2, random_state=random_state),
expected_values[2],
)

self.assertEqual(
variation.evaluate(proxy_1, random_state=random_state),
expected_values[4],
)
self.assertEqual(
variation.evaluate(proxy_2, random_state=random_state),
expected_values[3],
)
self.assertEqual(
variation.evaluate(proxy_2, random_state=random_state),
expected_values[4],
)
self.assertEqual(
variation.evaluate(proxy_3, random_state=random_state),
expected_values[4],
)


if __name__ == '__main__':
absltest.main()

0 comments on commit a2ad4e3

Please sign in to comment.