Skip to content

Commit

Permalink
[Feature] ActionDiscretizer custom sampling (#2609)
Browse files Browse the repository at this point in the history
Co-authored-by: Oliver Slumbers <[email protected]>
  • Loading branch information
oslumbers and oslumbersh authored Dec 3, 2024
1 parent 607ebc5 commit 3da76f0
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8583,6 +8583,26 @@ def _indent(s):
f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
)

def _custom_arange(self, nint, device):
result = torch.arange(
start=0.0,
end=1.0,
step=1 / nint,
dtype=self.dtype,
device=device,
)
result_ = result
if self.sampling in (
self.SamplingStrategy.HIGH,
self.SamplingStrategy.MEDIAN,
):
result_ = (1 - result).flip(0)
if self.sampling == self.SamplingStrategy.MEDIAN:
result = (result + result_) / 2
else:
result = result_
return result

def transform_input_spec(self, input_spec):
try:
action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
Expand Down Expand Up @@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec):
num_intervals = int(num_intervals.squeeze())
self.num_intervals = torch.as_tensor(num_intervals)

def custom_arange(nint):
result = torch.arange(
start=0.0,
end=1.0,
step=1 / nint,
dtype=self.dtype,
device=action_spec.device,
)
result_ = result
if self.sampling in (
self.SamplingStrategy.HIGH,
self.SamplingStrategy.MEDIAN,
):
result_ = (1 - result).flip(0)
if self.sampling == self.SamplingStrategy.MEDIAN:
result = (result + result_) / 2
else:
result = result_
return result

if isinstance(num_intervals, int):
arange = (
custom_arange(num_intervals).expand((*n_act, num_intervals))
self._custom_arange(num_intervals, action_spec.device).expand(
(*n_act, num_intervals)
)
* interval
)
low = action_spec.low
Expand All @@ -8642,7 +8644,7 @@ def custom_arange(nint):
self.register_buffer("intervals", low + arange)
else:
arange = [
custom_arange(_num_intervals) * interval
self._custom_arange(_num_intervals, action_spec.device) * interval
for _num_intervals, interval in zip(
num_intervals.tolist(), interval.unbind(-2)
)
Expand Down

1 comment on commit 3da76f0

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 3da76f0 Previous: 607ebc5 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 31.775268382321023 iter/sec (stddev: 0.191116422475419) 229.6351932644977 iter/sec (stddev: 0.0007991904573999483) 7.23

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.