Describe the bug
`torch.vmap` seems to be incompatible with `tensordict.TensorDictBase` inputs when `chunk_size` is not `None`.
To Reproduce
```python
import torch
from tensordict import tensorclass

@tensorclass
class Data:
    a: torch.Tensor
    b: torch.Tensor

def AplusB(data):
    return data.a+data.b

data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
result = torch.vmap(AplusB, chunk_size=1)(data)
print(result)
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 13
     10     return data.a+data.b
     12 data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
---> 13 result = torch.vmap(AplusB, chunk_size=1)(data)
     14 print(result)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:317, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    312 batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
    313     in_dims, args, func
    314 )
    316 if chunk_size is not None:
--> 317     chunks_flat_args = _get_chunked_inputs(
    318         flat_args, flat_in_dims, batch_size, chunk_size
    319     )
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:359, in _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
--> 359 flat_args_chunks = tuple(
    360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:360, in <genexpr>(.0)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
    359 flat_args_chunks = tuple(
--> 360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/tensordict/tensorclass.py:1098, in _getattr(self, item)
   1096         return out.data if hasattr(out, "data") else out.tolist()
   1097     return _wrap_method(self, item, out)
-> 1098 raise AttributeError(item)

AttributeError: tensor_split
```
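With `chunk_size` set, `vmap_impl` sends every flattened argument through `_get_chunked_inputs` (the `vmap.py:356-371` frames above). The pure-Python sketch below shows how the split indices come out for this reproduction. It re-implements torch's private `get_chunk_sizes`, whose body is not shown in the traceback, so its "full chunks plus one remainder chunk" behaviour is an assumption; `split_indices` is a hypothetical helper name.

```python
import itertools


def get_chunk_sizes(batch_size: int, chunk_size: int) -> list[int]:
    # Assumption: as many full chunks as fit, then one smaller remainder chunk if needed.
    n_full, remainder = divmod(batch_size, chunk_size)
    sizes = [chunk_size] * n_full
    if remainder:
        sizes.append(remainder)
    return sizes


def split_indices(batch_size: int, chunk_size: int) -> tuple[int, ...]:
    # Hypothetical helper mirroring `split_idxs = tuple(itertools.accumulate(chunk_sizes))`.
    return tuple(itertools.accumulate(get_chunk_sizes(batch_size, chunk_size)))


print(split_indices(10, 1))  # (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
print(split_indices(10, 4))  # (4, 8, 10)
```

Each such tuple is then passed to `t.tensor_split(split_idxs, dim=in_dim)` for every argument whose `in_dim` is not `None` (line 360). So once `chunk_size` is set, every batched input must expose `tensor_split`, and the tensorclass instance here does not.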
Expected behavior
The call should complete without error and return `data.a + data.b`, a tensor of shape `[10]`.
Screenshots
None.
System info
- Linux Mint 22
- conda
- python=3.12
- torch=2.5.1+cu124
Additional context
Might be related to #823.
Reason and Possible fixes
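A likely cause, inferred from the traceback only: the `Data` instance reaches `_get_chunked_inputs` as a single flattened argument, so vmap calls `tensor_split` on it directly. The tensorclass `_getattr` fallback finds no such attribute and ends in `raise AttributeError(item)` (`tensorclass.py:1098`). The plain-Python sketch below reproduces that failure mode. `FakeTensor`, `FakeTensorclass` and `chunk_inputs` are hypothetical stand-ins, not torch or tensordict APIs.

```python
# Hypothetical stand-ins: these are NOT torch or tensordict classes.
class FakeTensor:
    """1-D list-backed 'tensor' with a tensor_split-like method."""

    def __init__(self, values):
        self.values = list(values)

    def tensor_split(self, indices, dim=0):
        # Like torch.tensor_split with a sequence of indices: n indices -> n + 1 pieces.
        # `dim` is accepted for signature parity but ignored (1-D only).
        bounds = [0, *indices, len(self.values)]
        return tuple(FakeTensor(self.values[lo:hi]) for lo, hi in zip(bounds, bounds[1:]))


class FakeTensorclass:
    """Container whose __getattr__ fallback ends in `raise AttributeError(item)`."""

    def __init__(self, **fields):
        self._fields = fields

    def __getattr__(self, item):
        fields = self.__dict__.get("_fields", {})
        if item in fields:
            return fields[item]
        raise AttributeError(item)


def chunk_inputs(flat_args, flat_in_dims, split_idxs):
    # Same structure as the generator expression at vmap.py:359-367.
    return tuple(
        t.tensor_split(split_idxs, dim=in_dim) if in_dim is not None else [t] * len(split_idxs)
        for t, in_dim in zip(flat_args, flat_in_dims)
    )


split_idxs = (4, 8, 10)  # a batch of 10 split with chunk_size=4
(pieces,) = chunk_inputs([FakeTensor(range(10))], [0], split_idxs)
print([p.values for p in pieces])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], []]

data = FakeTensorclass(a=FakeTensor(range(10)), b=FakeTensor(range(10)))
try:
    chunk_inputs([data], [0], split_idxs)
except AttributeError as err:
    print(f"AttributeError: {err}")  # AttributeError: tensor_split
```

If this diagnosis is right, one possible workaround (not verified against tensordict) is to vmap over the tensor fields directly, so that only plain tensors reach `tensor_split`: `torch.vmap(lambda a, b: a + b, chunk_size=1)(data.a, data.b)`.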
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)