[BUG] torch.vmap fails when chunk_size is set to a positive integer #1091

Open
@busFred

Description

Describe the bug

torch.vmap seems to be incompatible with tensordict inputs (tensordict.TensorDictBase, or a @tensorclass instance as in the example below) when chunk_size is not None.

To Reproduce

Steps to reproduce the behavior.

```python
import torch
from tensordict import tensorclass

@tensorclass
class Data:
    a: torch.Tensor
    b: torch.Tensor

def AplusB(data):
    return data.a+data.b

data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
result = torch.vmap(AplusB, chunk_size=1)(data)
print(result)
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 13
     10     return data.a+data.b
     12 data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
---> 13 result = torch.vmap(AplusB, chunk_size=1)(data)
     14 print(result)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:317, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    312 batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
    313     in_dims, args, func
    314 )
    316 if chunk_size is not None:
--> 317     chunks_flat_args = _get_chunked_inputs(
    318         flat_args, flat_in_dims, batch_size, chunk_size
    319     )
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:359, in _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
--> 359 flat_args_chunks = tuple(
    360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:360, in <genexpr>(.0)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
    359 flat_args_chunks = tuple(
--> 360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/tensordict/tensorclass.py:1098, in _getattr(self, item)
   1096         return out.data if hasattr(out, "data") else out.tolist()
   1097     return _wrap_method(self, item, out)
-> 1098 raise AttributeError(item)

AttributeError: tensor_split
```

Expected behavior

torch.vmap(AplusB, chunk_size=1)(data) should run without raising an error and return the same result as the call without chunk_size.

Screenshots

None.

System info

Describe the characteristic of your environment:

  • Linux Mint 22
  • conda
  • python=3.12
  • torch=2.5.1+cu124

Additional context

Possibly related to #823.

Reason and Possible fixes
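
The traceback suggests the cause: when chunk_size is set, torch's _get_chunked_inputs calls t.tensor_split(split_idxs, dim=in_dim) on every flattened input, and here t is the tensorclass instance itself. The tensorclass has no tensor_split attribute, so its _getattr raises AttributeError: tensor_split. Until this is handled, one possible workaround is to do the chunking by hand with slicing and call torch.vmap without chunk_size on each slice. The helper below, chunked_vmap, is hypothetical (not part of torch or tensordict) and is only a sketch. It assumes, as the report implies, that torch.vmap accepts the tensorclass when chunk_size is None, that data[start:stop] slices along the leading batch dimension, and that func returns a single tensor.

```python
import torch


def chunked_vmap(func, data, chunk_size, n):
    """Hypothetical workaround: apply torch.vmap one slice at a time.

    `n` is the size of the leading batch dimension of `data`.
    Slicing (data[start:stop]) is used instead of tensor_split, so any
    input that supports basic slicing along dim 0 can be passed in.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    outputs = []
    for start in range(0, n, chunk_size):
        # The last slice may be shorter than chunk_size.
        chunk = data[start:start + chunk_size]
        # Plain vmap (chunk_size=None) over the current slice only.
        outputs.append(torch.vmap(func)(chunk))
    # Assumes func returns a single tensor, so per-chunk results
    # can be concatenated back along the batch dimension.
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    x = torch.randn(10, 3)
    chunked = chunked_vmap(lambda row: row.sum(), x, chunk_size=3, n=x.shape[0])
    unchunked = torch.vmap(lambda row: row.sum())(x)
    print(torch.allclose(chunked, unchunked))
```

For the reproduction above, the call would presumably be chunked_vmap(AplusB, data, chunk_size=1, n=data.batch_size[0]). Slicing was chosen over tensor_split because indexing along the batch dimension is a basic operation on tensorclass instances, whereas tensor_split is the method the traceback shows to be missing.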

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug