[Feature Request] ActionDiscretizer custom sampler #2608

Open · 1 task done
oslumbers (Contributor) opened this issue Nov 25, 2024 · 0 comments
Labels: enhancement (New feature or request)

Motivation

Currently, you cannot implement a custom sampling technique for the ActionDiscretizer transform: the grid-construction logic (custom_arange) is defined inside transform_input_spec, where it cannot be overridden.

Solution

Bring custom_arange out of transform_input_spec and make it a method of ActionDiscretizer (as in the excerpt under Additional context below). Wrappers or subclasses of ActionDiscretizer can then override it; see the sketch below.
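
For instance, with that refactor a subclass could swap in its own grid placement. A minimal, hypothetical sketch (the QuantileActionDiscretizer name and its smoothstep warping are illustrative only, not part of the proposal; it assumes _custom_arange has become an overridable method as in the excerpt below):

    import torch
    from torchrl.envs import ActionDiscretizer


    class QuantileActionDiscretizer(ActionDiscretizer):
        """Hypothetical subclass: non-uniform bin placement via the proposed hook."""

        def _custom_arange(self, nint, device):
            # Uniform interval midpoints in (0, 1). self.dtype is set by
            # transform_input_spec before this hook runs (per the excerpt).
            u = torch.arange(
                0.5 / nint, 1.0, 1.0 / nint, dtype=self.dtype, device=device
            )
            # Warp through a smoothstep so consecutive points cluster near
            # the low/high ends of the action range.
            return u * u * (3 - 2 * u)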

Alternatives

Additional context

The relevant excerpt from ActionDiscretizer (imports and class wrapper restored for readability):

# Imports restored for context (not part of the original excerpt).
from enum import IntEnum
from textwrap import indent

import torch
from tensordict.utils import NestedKey
from torchrl.data import Bounded
from torchrl.envs.transforms import Transform


class ActionDiscretizer(Transform):
    class SamplingStrategy(IntEnum):
        # Where, within each interval, the continuous action is taken.
        MEDIAN = 0
        LOW = 1
        HIGH = 2
        RANDOM = 3

    def __init__(
        self,
        num_intervals: int | torch.Tensor,
        action_key: NestedKey = "action",
        out_action_key: NestedKey = None,
        sampling=None,
        categorical: bool = True,
    ):
        if out_action_key is None:
            out_action_key = action_key
        super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
        self.action_key = action_key
        self.out_action_key = out_action_key
        if not isinstance(num_intervals, torch.Tensor):
            self.num_intervals = num_intervals
        else:
            self.register_buffer("num_intervals", num_intervals)
        if sampling is None:
            sampling = self.SamplingStrategy.MEDIAN
        self.sampling = sampling
        self.categorical = categorical

    def __repr__(self):
        def _indent(s):
            return indent(s, 4 * " ")

        num_intervals = f"num_intervals={self.num_intervals}"
        action_key = f"action_key={self.action_key}"
        out_action_key = f"out_action_key={self.out_action_key}"
        sampling = f"sampling={self.sampling}"
        categorical = f"categorical={self.categorical}"
        return (
            f"{type(self).__name__}(\n{_indent(num_intervals)},\n{_indent(action_key)},"
            f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
        )

    def _custom_arange(self, nint, device):
        # Normalized grid of nint points in [0, 1): the left edge of each
        # interval.
        result = torch.arange(
            start=0.0,
            end=1.0,
            step=1 / nint,
            dtype=self.dtype,
            device=device,
        )
        result_ = result
        if self.sampling in (
            self.SamplingStrategy.HIGH,
            self.SamplingStrategy.MEDIAN,
        ):
            # Mirror the left edges into right edges: (1/nint, ..., 1).
            result_ = (1 - result).flip(0)
        if self.sampling == self.SamplingStrategy.MEDIAN:
            # Average left and right edges to get interval midpoints.
            result = (result + result_) / 2
        else:
            result = result_
        return result

    def transform_input_spec(self, input_spec):
        try:
            action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
            if not isinstance(action_spec, Bounded):
                raise TypeError(
                    f"action spec type {type(action_spec)} is not supported."
                )

            # Number of action dimensions (last dim of the spec shape).
            n_act = action_spec.shape
            if not n_act:
                n_act = 1
            else:
                n_act = n_act[-1]
            self.n_act = n_act

            self.dtype = action_spec.dtype
            interval = (action_spec.high - action_spec.low).unsqueeze(-1)

            num_intervals = self.num_intervals

            if isinstance(num_intervals, int):
                # Same number of bins for every action dimension.
                arange = (
                    self._custom_arange(num_intervals, action_spec.device).expand(
                        n_act, num_intervals
                    )
                    * interval
                )
                self.register_buffer(
                    "intervals", action_spec.low.unsqueeze(-1) + arange
                )
            else:
                # Per-dimension number of bins: one grid per action dim.
                arange = [
                    self._custom_arange(_num_intervals, action_spec.device) * interval
                    for _num_intervals, interval in zip(
                        num_intervals.tolist(), interval.unbind(-2)
                    )
                ]
                self.intervals = [
                    low + arange
                    for low, arange in zip(
                        action_spec.low.unsqueeze(-1).unbind(-2), arange
                    )
                ]
            # ... (remainder of transform_input_spec, including the matching
            # except clause, is omitted in the issue)
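
For concreteness, the grid logic in _custom_arange works out as follows for num_intervals=4 (a standalone check with hand-computed expectations; dtype/device handling stripped):

    import torch

    nint = 4
    left = torch.arange(0.0, 1.0, step=1 / nint)  # left edges -> LOW (and RANDOM base)
    right = (1 - left).flip(0)                    # right edges -> HIGH

    print(left)                # tensor([0.0000, 0.2500, 0.5000, 0.7500])
    print(right)               # tensor([0.2500, 0.5000, 0.7500, 1.0000])
    print((left + right) / 2)  # tensor([0.1250, 0.3750, 0.6250, 0.8750]) -> MEDIAN

These normalized grids are then scaled by (high - low) and offset by low per action dimension; that placement step is exactly what a custom sampler would want to control.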

Checklist

  • I have checked that there is no similar issue in the repo (required)