Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Structured dtype #1195

Open
wants to merge 1 commit into
base: gh/vmoens/46/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,20 @@ def _stack_uninit_params(list_of_params, dim=0, out=None):
)
out.batch_size = torch.Size([len(list_of_params)])
return out

def implements_for_tdtype(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Build a decorator registering an override of ``torch_function``.

    The callable passed to the returned decorator is stored in
    ``TDTYPE_HANDLED_FUNCTIONS`` under the ``torch_function`` key and is then
    returned unchanged, so the decorated name keeps pointing at the original
    function object.

    Args:
        torch_function: the torch callable to override.

    Returns:
        A decorator taking the override and returning it as-is.
    """
    # Function-scope import (presumably to avoid an import cycle between
    # ``tensordict._torch_func`` and ``tensordict.dtype`` -- TODO confirm).
    from tensordict.dtype import TDTYPE_HANDLED_FUNCTIONS as registry

    def decorator(func: Callable) -> Callable:
        registry[torch_function] = func
        return func

    # Same effect as stacking ``@functools.wraps(torch_function)`` on the
    # closure: its metadata is copied from ``torch_function``.
    return functools.wraps(torch_function)(decorator)

@implements_for_tdtype(torch.Tensor.view)
def view(tensor: torch.Tensor, dtype: Any) -> TensorDictBase:
    """Override of ``torch.Tensor.view`` registered for structured dtypes.

    Thin wrapper that forwards both arguments to ``StructDtype.view`` and
    returns its result.

    Args:
        tensor: the tensor to reinterpret.
        dtype: the structured dtype handed to ``StructDtype.view``.

    Returns:
        Whatever ``StructDtype.view(tensor, dtype)`` returns.
    """
    # Function-scope import, mirroring ``implements_for_tdtype`` which also
    # imports from ``tensordict.dtype`` lazily.
    from tensordict.dtype import StructDtype

    return StructDtype.view(tensor, dtype)
112 changes: 112 additions & 0 deletions tensordict/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
import orjson as json
from typing import Callable, Any


# Registry of torch functions that have a StructDtype-specific override.
# Keys are the original torch callables, values are their replacements;
# ``StructDtype.__torch_function__`` looks functions up here when dispatching.
TDTYPE_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}

class StructDtype:
# def __new__(cls, map=None):
# if isinstance(map, StructDtype):
# return map
# return super().__new__(cls)
def __init__(self, map=None):
if map is None:
map = {}
assert isinstance(map, dict)
self._maps = map

@classmethod
def from_td(cls, data: "TensorDictBase"):
from tensordict.base import _is_tensor_collection
self = cls()
map = self._maps
stack = deque()
stack.append((self, data))
while len(stack):
sdtype, local_data = stack.popleft()
map = sdtype._maps
# TODO: handle lazy stacks here
for k, v in local_data.items():
cls = type(v)
if _is_tensor_collection(cls):
# TODO: handle different dtypes here
# TODO: handle LazyStacks here
newmap = map[k] = StructDtype({})
stack.append((newmap, v))
else:
map[k] = {
"shape": v.shape,
"dtype": v.dtype,
}
return self

def items(self, include_nested: bool=False, leaves_only: bool=False):
stack = deque()
stack.append(self)
while len(stack):
node = stack.popleft()
for k, v in node._maps.items():
if isinstance(v, StructDtype):
if include_nested:
stack.append(v)
if not leaves_only:
yield (k, v)
else:
yield k, v

def values(self, include_nested: bool=False, leaves_only: bool=False):
yield from (_, v in self.items(include_nested=include_nested, leaves_only=leaves_only))

def keys(self, include_nested: bool=False, leaves_only: bool=False):
yield from (k, _ in self.items(include_nested=include_nested, leaves_only=leaves_only))

# def json(self):
# return json.dumps(metadata_dict)

@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple[type, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Callable:
if kwargs is None:
kwargs = {}
if func not in TDTYPE_HANDLED_FUNCTIONS:
return NotImplemented
return TDTYPE_HANDLED_FUNCTIONS[func](*args, **kwargs)


@classmethod
def view(cls, tensor, dtype):
from tensordict import TensorDict
ns = []
shapes = []
dts = []
keys = []
stack = deque()
stack.append((dtype.items(), ()))
tensor_itemsize = tensor.dtype.itemsize
while len(stack):
items, prefix = stack.popleft()
for k, dt in items:
currentk = prefix + (k,)
if isinstance(dt, StructDtype):
stack.append((dt.items(), currentk))
continue
assert currentk not in keys, (currentk, keys)
keys.append(currentk)
s = dt["shape"]
dt = dt["dtype"]
shapes.append(s)
dts.append(dt)
nelts = (dt.itemsize * s.numel()) // tensor_itemsize
ns.append(nelts)

return TensorDict({k: v.view(dt).view(shape) for k, v, dt, shape in zip(keys, tensor.split(ns), dts, shapes, strict=True)})
Loading