[Feature Request] ProbabilisticActor in a Composite Action Space Environment #2167

Closed
@sakakibara-yuuki

Description

Motivation

Using the PPO tutorial and issue #1473 as guides, I created an environment with a composite action space.

I also created a module using ProbabilisticActor, based on #1473.
However, the values sampled by distribution_class were not written to an action entry.
(In the PPO tutorial, and in the non-composite version of #1473, the sampled values are written to action.)

When the action space is not composite:

import torch
import torch.distributions as d
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=["loc", "scale"])
actor = ProbabilisticActor(module,
                           in_keys=["loc", "scale"],
                           distribution_class=d.Normal)
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

Then I get an action entry:

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

When the action space is composite, as in #1473:

import torch
import torch.distributions as d
from torch import nn
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from torchrl.modules import ProbabilisticActor

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6], x[..., 6:]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=[("params", "normal", "loc"),
                                    ("params", "normal", "scale"),
                                    ("params", "categ", "logits")])
actor = ProbabilisticActor(module,
                           in_keys=["params"],
                           distribution_class=CompositeDistribution,
                           distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical}})
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

Then I don't get an action entry:

TensorDict(
    fields={
        categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        params: TensorDict(
            fields={
                categ: TensorDict(
                    fields={
                        logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                normal: TensorDict(
                    fields={
                        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Solution

I would like to obtain the following result (the output tensordict contains an action entry):

TensorDict(
    fields={
        action: TensorDict(
            fields={
                categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
                normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        params: TensorDict(
            fields={
                categ: TensorDict(
                    fields={
                        logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                normal: TensorDict(
                    fields={
                        loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Could you provide guidance on how to correctly write the sampled values into the action key when using ProbabilisticActor in a composite action space?

Alternatives

A possible workaround would be to append a TensorDictModule after the ProbabilisticActor that moves the normal and categ keys under the action key.
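
The key shuffling that such a post-processing module would perform can be sketched as follows. This uses plain dicts as a stand-in for TensorDict (so it runs without tensordict installed); the key names "normal", "categ", and "action" follow the snippets above, and the function name pack_action is hypothetical. A real version would wrap the same moves in a TensorDictModule.

```python
def pack_action(td, keys=("normal", "categ")):
    # Move each sampled entry from the root of the tensordict-like mapping
    # under a single nested "action" entry, as the desired output shows.
    td["action"] = {k: td.pop(k) for k in keys}
    return td

out = pack_action({"normal": [0.1, 0.2, 0.3], "categ": 2, "params": {}})
# out["action"] == {"normal": [0.1, 0.2, 0.3], "categ": 2}
```

With a real TensorDict, the same moves would use nested keys such as ("action", "normal") so the result matches the desired output above.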

Alternatively, including the action key in the module's out_keys and in the distribution_map of ProbabilisticActor solves the problem, but it causes inconvenience when using PPOLoss.

import torch
import torch.distributions as d
from torch import nn
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from torchrl.modules import ProbabilisticActor

class Module(nn.Module):
    def forward(self, x):
        return x[..., :3], x[..., 3:6], x[..., 6:]

module = TensorDictModule(Module(),
                          in_keys=["x"],
                          out_keys=[("params", "action", "normal", "loc"),
                                    ("params", "action", "normal", "scale"),
                                    ("params", "action", "categ", "logits")])
actor = ProbabilisticActor(module,
                           out_keys=[("action", "normal"), ("action", "categ")],
                           in_keys=["params"],
                           distribution_class=CompositeDistribution,
                           distribution_kwargs={"distribution_map": {("action", "normal"): d.Normal, ("action", "categ"): d.Categorical}})
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))

This seems to be because out_keys is not taken into account when out_tensors in ProbabilisticTensorDictModule.forward is an instance of TensorDictBase.

(tensordict/nn/probabilistic.py line 379)

    @dispatch(auto_batch_size=False)
    @set_skip_existing(None)
    def forward(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        _requires_sample: bool = True,
    ) -> TensorDictBase:
        if tensordict_out is None:
            tensordict_out = tensordict

        dist = self.get_dist(tensordict)
        if _requires_sample:
            out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
            if isinstance(out_tensors, TensorDictBase):
                tensordict_out.update(out_tensors)  # <= out_keys is NOT used here
                if self.return_log_prob:
                    tensordict_out = dist.log_prob(tensordict_out)
            else:
                if isinstance(out_tensors, Tensor):
                    out_tensors = (out_tensors,)
                tensordict_out.update(
                    {key: value for key, value in zip(self.out_keys, out_tensors)}  # <= out_keys is used here
                )
                if self.return_log_prob:
                    log_prob = dist.log_prob(*out_tensors)
                    tensordict_out.set(self.log_prob_key, log_prob)
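
One possible shape of the fix this suggests, sketched with plain dicts as a stand-in for TensorDictBase (the names set_nested and update_honoring_out_keys are hypothetical, not tensordict API): when out_tensors is a tensordict, pair each sampled leaf with the configured out_keys instead of updating the destination verbatim.

```python
def set_nested(d, key, value):
    # Write value at a possibly-nested tuple key, creating intermediate dicts.
    if isinstance(key, str):
        key = (key,)
    for part in key[:-1]:
        d = d.setdefault(part, {})
    d[key[-1]] = value

def update_honoring_out_keys(dest, sampled, out_keys):
    # Match each sampled leaf name to an out_key by its last element, so that
    # e.g. "normal" lands at ("action", "normal") rather than at the root.
    by_leaf = {(k[-1] if isinstance(k, tuple) else k): k for k in out_keys}
    for name, value in sampled.items():
        set_nested(dest, by_leaf.get(name, name), value)
    return dest

dest = {"params": {}}
update_honoring_out_keys(dest, {"normal": [0.1], "categ": 3},
                         [("action", "normal"), ("action", "categ")])
# dest["action"] == {"normal": [0.1], "categ": 3}
```

With out_keys=[("action", "normal"), ("action", "categ")], the sampled values would then appear under action, which is the behavior requested in the Solution section.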
 

Additional context

I could not determine whether this is a bug or the intended behavior, but I decided to post it.
I apologize for my poor English.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)