[Feature] Support lazy tensordict inputs in KL reward transform #2884

Closed · wants to merge 1 commit

11 changes: 6 additions & 5 deletions torchrl/envs/common.py
@@ -2150,17 +2150,18 @@ def _step_proc_data(self, next_tensordict_out):
             else next_tensordict_out.shape
         )
         for reward_key in self.reward_keys:
-            reward = next_tensordict_out.get(reward_key)
             expected_reward_shape = torch.Size(
                 [
                     *leading_batch_size,
                     *self.output_spec["full_reward_spec"][reward_key].shape,
                 ]
             )
-            actual_reward_shape = reward.shape
-            if actual_reward_shape != expected_reward_shape:
-                reward = reward.view(expected_reward_shape)
-                next_tensordict_out.set(reward_key, reward)
+            if all(s > 0 for s in expected_reward_shape):
+                reward = next_tensordict_out.get(reward_key, as_nested=True)
+                actual_reward_shape = reward.shape
+                if actual_reward_shape != expected_reward_shape:
+                    reward = reward.view(expected_reward_shape)
+                    next_tensordict_out.set(reward_key, reward)

         self._complete_done(self.full_done_spec, next_tensordict_out)

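The net effect of the change above is that the reward is only reshaped when every dimension of the expected shape is static; when the spec carries a dynamic dimension (as it does for ragged, lazily stacked LLM data), the check is skipped entirely. A minimal standalone sketch of that guard, using plain tensors rather than tensordict entries (the helper name maybe_reshape_reward and the concrete shapes are made up for illustration):

import torch


def maybe_reshape_reward(reward: torch.Tensor, expected_shape: torch.Size) -> torch.Tensor:
    # Only enforce the spec shape when it is fully static; a nested (ragged)
    # reward has no single shape it could be viewed into.
    if all(s > 0 for s in expected_shape) and reward.shape != expected_shape:
        reward = reward.view(expected_shape)
    return reward


# Dense case: the reward is reshaped to match the spec.
dense = maybe_reshape_reward(torch.zeros(4), torch.Size([4, 1]))
assert dense.shape == torch.Size([4, 1])

# Ragged case: the spec carries a -1 for the variable token dimension, so the
# nested reward is returned untouched and its shape is never queried (the
# condition short-circuits on the -1).
ragged = torch.nested.nested_tensor(
    [torch.zeros(3, 1), torch.zeros(5, 1)], layout=torch.strided
)
assert maybe_reshape_reward(ragged, torch.Size([2, -1, 1])) is ragged

In the actual change, the reward is fetched with next_tensordict_out.get(reward_key, as_nested=True), so ragged per-token rewards simply pass through unchanged.
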
147 changes: 104 additions & 43 deletions torchrl/envs/transforms/llm.py
@@ -12,11 +12,7 @@

import torch
from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictParams,
)
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.utils import _zip_strict, is_seq_of_nested_key
from torch import nn

@@ -657,9 +653,11 @@ def __init__(
in_keys=None,
out_keys=None,
requires_grad=False,
# TODO: adapt this to new API
log_prob_key: NestedKey = "sample_log_prob",
action_key: NestedKey = "action",
functional: bool = True,
action_key: NestedKey | None = None,
functional: bool | None = None,
device: torch.device | None = None,
):
if in_keys is None:
in_keys = self.DEFAULT_IN_KEYS
@@ -676,15 +674,16 @@
raise ValueError(
f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
)
# for convenience, convert out_keys to tuples
self._out_keys = [
out_key if isinstance(out_key, tuple) else (out_key,)
for out_key in self._out_keys
]
self._out_keys = [unravel_key(out_key) for out_key in self._out_keys]

# update the in_keys for dispatch etc
self.in_keys = self.in_keys + actor.in_keys
self.in_keys = [unravel_key(in_key) for in_key in self.in_keys]

if functional is None:
from torchrl.modules.llm import CategoricalSequential

functional = not isinstance(actor, CategoricalSequential)
self.functional = functional
# check that the model has parameters
if functional:
@@ -721,6 +720,7 @@ def _make_detached_param(x):

# self._buffers["actor_params"] = params.clone().detach()

self.device = device
self.action_key = action_key

# find the sample log-prob key
@@ -736,55 +736,102 @@ def find_sample_log_prob(module):
coef = torch.as_tensor(coef)
self.register_buffer("coef", coef)

def set_container(self, container: Transform | EnvBase) -> None:
result = super().set_container(container)
if self.action_key is None:
parent = getattr(self, "parent", None)
if parent is not None:
action_keys = parent.action_keys
if len(action_keys) != 1:
raise ValueError(
f"More than one action_key found. Please pass the `action_key` argument directly to {type(self).__name__}."
)
action_key = action_keys[0]
self.action_key = action_key
return result

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
with _set_missing_tolerance(self, True):
tensordict_reset = self._call(tensordict_reset)
tensordict_reset = self._step(tensordict_reset, tensordict_reset)
return tensordict_reset

def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
# run the actor on the tensordict
action = next_tensordict.get(self.action_key, None)
action_key = self.action_key
if action_key is None:
raise ValueError(
f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
f"or pass the action_key argument directly to {type(self).__name__} constructor."
)
action = tensordict.get(action_key, None)
if action is None:
if not self.missing_tolerance:
raise RuntimeError(
f"Action with key {action_key} not found data {tensordict}"
)
# being called after reset or without action, skipping
if self.out_keys[0] != ("reward",) and self.parent is not None:
if self.out_keys[0] != "reward" and self.parent is not None:
next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
return next_tensordict

if self.device is not None:
action = action.to(self.device)

if self.functional:
with self.frozen_params.to_module(self.functional_actor):
dist = self.functional_actor.get_dist(next_tensordict.clone(False))
dist = self.functional_actor.get_dist(tensordict.clone(False))
# get the log_prob given the original model
log_prob = dist.log_prob(action)
elif isinstance(
self.functional_actor,
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
):
# with self.frozen_params.to_module(self.functional_actor):
dist = self.functional_actor.get_dist(next_tensordict.copy())
# get the log_prob given the original model
log_prob = dist.log_prob(action)
else:
log_prob = self.functional_actor(next_tensordict.copy()).get(
self.sample_log_prob_key
elif hasattr(self.functional_actor, "log_prob"):
if self.device is not None:
td_device = tensordict.to(self.device)
else:
td_device = tensordict
log_prob = self.functional_actor.log_prob(
td_device, as_nested_tensor=True, layout=torch.strided
)
else:
log_prob = self.functional_actor(tensordict).get(self.sample_log_prob_key)

reward_key = self.in_keys[0]
reward = next_tensordict.get("next").get(reward_key)
curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
reward = next_tensordict.get(reward_key)
curr_log_prob = tensordict.get(
self.sample_log_prob_key, as_nested_tensor=True, layout=torch.strided
)
log_prob = log_prob.to(curr_log_prob.device)
curr_log_prob = curr_log_prob.unsqueeze(-1)
# log_prob = log_prob.unsqueeze(-1)

# we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
kl = (curr_log_prob - log_prob).view_as(reward)
next_tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl)
if not reward.is_nested and log_prob.is_nested:
reward = torch.nested.nested_tensor(
[rew.expand(lp.shape) for rew, lp in zip(reward, log_prob)],
layout=torch.strided,
)
if log_prob[0].shape != curr_log_prob[0].shape:
# Don't check shapes if nested
raise ValueError(
f"the log-probability tensor shapes must match, got cur_log_prob.shape={curr_log_prob[0].shape} and log_prob.shape={log_prob[0].shape}."
)
if reward is not None and reward.ndim != curr_log_prob.ndim:
raise ValueError(
"The number of dimensions of reward must be the same as the number of dimensions of the KL "
f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
)
kl = curr_log_prob - log_prob
if reward is None:
reward = 0
next_tensordict.set(self.out_keys[0], reward + self.coef * kl)
return next_tensordict

def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
with tensordict.unlock_():
return self._call(tensordict.set("next", next_tensordict)).pop("next")

forward = _call
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
next_td = tensordict.pop("next")
next_td = self._step(tensordict, next_td)
return tensordict.set("next", next_td)

def transform_output_spec(self, output_spec: Composite) -> Composite:
in_key = unravel_key(self.in_keys[0])
@@ -800,24 +847,38 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
reward_key = "reward"
else:
raise KeyError("Couldn't find the reward key.")

shape = output_spec["full_reward_spec"][reward_key].shape
shape = (*shape[:-2], -1, 1)
reward_spec = Unbounded(
device=output_spec.device,
shape=output_spec["full_reward_spec"][reward_key].shape,
shape=shape,
)
output_spec["full_reward_spec"] = Composite(
{reward_key: reward_spec},
shape=output_spec["full_reward_spec"].shape,
)
elif in_key == "reward":
# TODO: we should at least allow to make this a component of the reward specs, to avoid a call during reset
parent = self.parent
reward_spec = output_spec["full_reward_spec"][parent.reward_key].clone()
reward_spec = output_spec["full_reward_spec"][parent.reward_key]

shape = reward_spec.shape
shape = (*shape[:-2], -1, 1)
reward_spec = reward_spec.clone()
reward_spec.shape = torch.Size(shape)

# then we need to populate the output keys
observation_spec = output_spec["full_observation_spec"]
observation_spec[out_key] = reward_spec
else:
observation_spec = output_spec["full_observation_spec"]
reward_spec = observation_spec[in_key].clone()
reward_spec = observation_spec[in_key]

shape = reward_spec.shape
shape = (*shape[:-2], -1, 1)
reward_spec = reward_spec.clone()
reward_spec.shape = torch.Size(shape)

# then we need to populate the output keys
observation_spec[out_key] = reward_spec
return output_spec
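
Taken together, the reworked _step computes a per-token KL estimate between the sampling policy and the frozen reference actor and folds it into the reward, with nested tensors carrying the ragged token dimension. The sketch below illustrates that arithmetic with plain nested tensors, outside the transform: coef, the random log-prob values and the shapes are placeholders, and in the transform the reference log-probs come from the frozen actor while the reward and current log-probs are read from the tensordict:

import torch

# Per-token log-probs under the sampling policy and under the frozen reference
# actor, for two responses of lengths 3 and 5 (hence nested/ragged).
coef = 0.1
curr_log_prob = torch.nested.nested_tensor(
    [torch.randn(3, 1), torch.randn(5, 1)], layout=torch.strided
)
ref_log_prob = torch.nested.nested_tensor(
    [torch.randn(3, 1), torch.randn(5, 1)], layout=torch.strided
)

# Single-sample estimator of the KL, as in the code comment:
# log p(x) - log q(x) with x drawn from p.
kl = curr_log_prob - ref_log_prob

# The env emits one dense reward per response; broadcast it along the ragged
# token dimension so it can be combined with the per-token KL term, mirroring
# the expand-and-re-nest step in _step.
dense_reward = torch.zeros(2, 1, 1)
reward = torch.nested.nested_tensor(
    [r.expand(k.shape) for r, k in zip(dense_reward.unbind(0), kl.unbind(0))],
    layout=torch.strided,
)

shaped_reward = reward + coef * kl

The resulting shaped_reward follows the same formula the transform writes to its output key, reward + coef * kl; whether the KL term acts as a bonus or a penalty depends on the sign of coef.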