Description
Motivation
Using PPO tutorial and #1473 issue as a guide, I created an environment with a composite action space.
I also created a module using ProbabilisticActor
based on #1473.
However, the values sampled by distribution_class were not written to the action
entry.
(In the PPO tutorial, and in the non-composite version of #1473 — i.e., when the action space is not composite — the sampled values are written to action.)
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6]
module = TensorDictModule(Module(),
in_keys=["x"],
out_keys=["loc", "scale"])
actor = ProbabilisticActor(module,
in_keys=["loc", "scale"],
distribution_class=d.Normal,
# distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical}}
)
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))
Then I get an action
entry:
TensorDict(
fields={
action: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
when action space is composite, in #1473.
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]
module = TensorDictModule(Module(),
in_keys=["x"],
out_keys=[("params", "normal", "loc"), ("params", "normal", "scale"), ("params", "categ", "logits")])
actor = ProbabilisticActor(module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical}}
)
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))
Then I do not get an action
entry!
TensorDict(
fields={
categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
params: TensorDict(
fields={
categ: TensorDict(
fields={
logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
normal: TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Solution
I would like to obtain the following result (the output tensordict contains an action
entry):
TensorDict(
fields={
action: TensorDict(
fields={
categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
params: TensorDict(
fields={
categ: TensorDict(
fields={
logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
normal: TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
Could you provide guidance on how to correctly write the sampled values into the action
key when using ProbabilisticActor
in a composite action space?
Alternatives
A possible symptomatic remedy would be to create a TensorDictModule that moves the normal
and categ
keys under the action
key after using the ProbabilisticActor.
Alternatively, including the action
key in the module's out_keys and in the distribution_map of ProbabilisticActor solves the problem, but it causes inconvenience when using PPOLoss.
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]
module = TensorDictModule(Module(),
in_keys=["x"],
out_keys=[("params", "action", "normal", "loc"), ("params", "action", "normal", "scale"), ("params", "action", "categ", "logits")])
actor = ProbabilisticActor(module,
out_keys=[("action", "normal"), ("action", "categ")],
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={"distribution_map": {("action", "normal"): d.Normal, ("action", "categ"): d.Categorical}}
)
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))
This may be due to the fact that out_keys is not taken into account when out_tensors
in the forward
method of ProbabilisticTensorDictModule
is an instance of TensorDictBase
.
(tensordict/nn/probabilistic.py, line 379)
@dispatch(auto_batch_size=False)
@set_skip_existing(None)
def forward(
self,
tensordict: TensorDictBase,
tensordict_out: TensorDictBase | None = None,
_requires_sample: bool = True,
) -> TensorDictBase:
if tensordict_out is None:
tensordict_out = tensordict
dist = self.get_dist(tensordict)
if _requires_sample:
out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
if isinstance(out_tensors, TensorDictBase):
tensordict_out.update(out_tensors) # <= no using out_keys
if self.return_log_prob:
tensordict_out = dist.log_prob(tensordict_out)
else:
if isinstance(out_tensors, Tensor):
out_tensors = (out_tensors,)
tensordict_out.update(
{key: value for key, value in zip(self.out_keys, out_tensors)} # <= using out_keys
)
if self.return_log_prob:
log_prob = dist.log_prob(*out_tensors)
tensordict_out.set(self.log_prob_key, log_prob)
Additional context
I could not determine whether this is a bug or intended behavior, but I decided to post it anyway.
Apologies for my imperfect English.
Checklist
- I have checked that there is no similar issue in the repo (required)