
[Feature Request] ProbabilisticActor in a Composite Action Space Environment #2167

Closed · 1 task done · Fixed by #2220
sakakibara-yuuki opened this issue May 21, 2024 · 1 comment
Labels: enhancement (New feature or request)

@sakakibara-yuuki:

Motivation

Using the PPO tutorial and issue #1473 as a guide, I created an environment with a composite action space.

I also created a module using ProbabilisticActor, based on #1473.
However, the values sampled by distribution_class were not written to the action entry.
(In the PPO tutorial and in the non-composite version of #1473, the values are written to action.)

When the action space is not composite:

import torch
import torch.distributions as d
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=["loc", "scale"])
actor = ProbabilisticActor(module,
                           in_keys=["loc", "scale"],
                           distribution_class=d.Normal)
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

Then I get an action entry:

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

When the action space is composite, as in #1473:

from tensordict.nn import CompositeDistribution

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6], x[..., 6:]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=[("params", "normal", "loc"),
                                    ("params", "normal", "scale"),
                                    ("params", "categ", "logits")])
actor = ProbabilisticActor(module,
                           in_keys=["params"],
                           distribution_class=CompositeDistribution,
                           distribution_kwargs={"distribution_map": {"normal": d.Normal,
                                                                     "categ": d.Categorical}})
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

Then I don't get an action entry:

TensorDict(
    fields={
        categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                categ: TensorDict(
                    fields={
                        logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                normal: TensorDict(
                    fields={
                        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Solution

I would like to obtain the following result (the output tensordict contains an action entry):

TensorDict(
    fields={
        action: TensorDict(
            fields={
                categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        params: TensorDict(
            fields={
                categ: TensorDict(
                    fields={
                        logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                normal: TensorDict(
                    fields={
                        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Could you provide guidance on how to correctly write the sampled values into the action key when using ProbabilisticActor in a composite action space?

Alternatives

A possible symptomatic remedy would be to append a TensorDictModule that moves the normal and categ keys under the action key after the ProbabilisticActor has run, as sketched below.
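For example, a minimal sketch of such a remapping module, reusing the actor and data from the composite snippet above (the lambda and the name actor_with_action are just illustrations, not an official API):

from tensordict.nn import TensorDictSequential

# Identity on values; only the key layout changes: the root-level "normal"
# and "categ" samples are copied under the "action" sub-tensordict.
# Note: TensorDictModule keeps its in_keys, so the root-level entries remain.
remap = TensorDictModule(
    lambda normal, categ: (normal, categ),
    in_keys=["normal", "categ"],
    out_keys=[("action", "normal"), ("action", "categ")],
)
actor_with_action = TensorDictSequential(actor, remap)
print(actor_with_action(data))  # now contains ("action", "normal") and ("action", "categ")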

Alternatively, including the action key in the module's out_keys and in the distribution_map of the ProbabilisticActor solves the problem, but it causes inconvenience when using PPOLoss.

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6], x[..., 6:]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=[("params", "action", "normal", "loc"),
                                    ("params", "action", "normal", "scale"),
                                    ("params", "action", "categ", "logits")])
actor = ProbabilisticActor(module,
                           in_keys=["params"],
                           out_keys=[("action", "normal"), ("action", "categ")],
                           distribution_class=CompositeDistribution,
                           distribution_kwargs={"distribution_map": {("action", "normal"): d.Normal,
                                                                     ("action", "categ"): d.Categorical}})
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

This may be because out_keys is not taken into account when out_tensors in the forward method of ProbabilisticTensorDictModule is an instance of TensorDictBase
(tensordict/nn/probabilistic.py, line 379):

    @dispatch(auto_batch_size=False)
    @set_skip_existing(None)
    def forward(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        _requires_sample: bool = True,
    ) -> TensorDictBase:
        if tensordict_out is None:
            tensordict_out = tensordict

        dist = self.get_dist(tensordict)
        if _requires_sample:
            out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
            if isinstance(out_tensors, TensorDictBase):
                tensordict_out.update(out_tensors)                                     # <= out_keys is NOT used here
                if self.return_log_prob:
                    tensordict_out = dist.log_prob(tensordict_out)
            else:
                if isinstance(out_tensors, Tensor):
                    out_tensors = (out_tensors,)
                tensordict_out.update(
                    {key: value for key, value in zip(self.out_keys, out_tensors)}     # <= out_keys IS used here
                )
                if self.return_log_prob:
                    log_prob = dist.log_prob(*out_tensors)
                    tensordict_out.set(self.log_prob_key, log_prob)
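For illustration only, a sketch of what honoring out_keys in the TensorDictBase branch might look like; it assumes the sampled leaves happen to align one-to-one with self.out_keys and is not the actual fix merged in #2220:

            if isinstance(out_tensors, TensorDictBase):
                if self.out_keys:
                    # Re-key each sampled leaf according to out_keys before
                    # writing it back into the output tensordict.
                    remapped = tensordict_out.empty()
                    sample_keys = list(out_tensors.keys(include_nested=True, leaves_only=True))
                    for sample_key, out_key in zip(sample_keys, self.out_keys):
                        remapped.set(out_key, out_tensors.get(sample_key))
                    tensordict_out.update(remapped)
                else:
                    tensordict_out.update(out_tensors)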

Additional context

I could not determine whether this is a bug or the intended behavior, but I decided to post it.
I apologize for my poor English.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@vmoens (Contributor) commented Jun 11, 2024