Describe the bug
Defining a `__setattr__` method inside a `@tensorclass` seems to make the class unable to traverse its fields in some cases: the fields are not shown when an instance is printed.
Probably related: when `__setattr__` is defined, `auto_batch_size_` and `auto_device_` do not work either (I suppose they use the same logic as `print` to traverse the fields).
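The supposition that printing and `auto_batch_size_`/`auto_device_` share one field traversal can be illustrated with a toy model. This is a hypothetical, pure-Python stand-in (`toy_tensorclass` is an invented name, not tensordict's API or implementation). It assumes the decorator routes field assignments into a backing store through its own `__setattr__`, installed only when the class does not define one. Under that assumption, a user `__setattr__` that calls `super().__setattr__` reaches `object.__setattr__`, the value lands in the instance `__dict__`, and anything that walks the backing store misses it.

```python
from typing import Any


def toy_tensorclass(cls):
    """Hypothetical stand-in for @tensorclass; NOT tensordict's real code."""
    fields = list(cls.__annotations__)

    def __init__(self) -> None:
        # Backing store: the only place the field traversal in __repr__ looks.
        object.__setattr__(self, "_store", {})

    def _routing_setattr(self, name: str, value: Any) -> None:
        # Declared fields are routed into the backing store.
        if name in fields:
            self._store[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails.
        store = object.__getattribute__(self, "_store")
        if name in store:
            return store[name]
        raise AttributeError(name)

    def __repr__(self) -> str:
        body = ", ".join(f"{k}={v!r}" for k, v in self._store.items())
        return f"{cls.__name__}({body})"

    cls.__init__ = __init__
    cls.__getattr__ = __getattr__
    cls.__repr__ = __repr__
    # Assumed mechanism: the routing __setattr__ is installed only when the
    # class does not define its own, so a user override calling
    # super().__setattr__ falls through to object.__setattr__.
    if "__setattr__" not in cls.__dict__:
        cls.__setattr__ = _routing_setattr
    return cls


@toy_tensorclass
class PoC:
    boring_variable: Any


@toy_tensorclass
class PoC_bug:
    boring_variable: Any

    def __setattr__(self, name, value):
        super().__setattr__(name, value)


poc, poc_bug = PoC(), PoC_bug()
poc.boring_variable = 1.0
poc_bug.boring_variable = 1.0
print(poc)                      # PoC(boring_variable=1.0)
print(poc_bug)                  # PoC_bug()  <- field missing from the traversal
print(poc_bug.boring_variable)  # 1.0, stored in the instance __dict__ instead
```

In this toy, `poc_bug.boring_variable` still reads back correctly, yet `repr` (and, following the supposition above, anything like `auto_batch_size_` or `auto_device_` that walks the same store) never sees it. Whether tensordict actually works this way would need to be checked against its source.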
To Reproduce
```python
import torch
from torch import Tensor
from tensordict import tensorclass


@tensorclass()
class PoC:
    boring_variable: Tensor | None = None


@tensorclass()
class PoC_bug:
    boring_variable: Tensor | None = None

    def __setattr__(self, name, value):
        super().__setattr__(name, value)


poc = PoC()
poc_bug = PoC_bug()

poc.boring_variable = torch.rand(2, 3)
poc_bug.boring_variable = torch.rand(2, 3)

print(poc)
"""
PoC(
    boring_variable=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
"""
print(poc_bug)
"""
PoC_bug(
    ,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
"""

assert poc.batch_size == torch.Size([])
assert poc_bug.batch_size == torch.Size([])
assert poc.device is None
assert poc_bug.device is None

poc.auto_batch_size_(1)
poc_bug.auto_batch_size_(1)
assert poc.batch_size == torch.Size([2])
assert poc_bug.batch_size == torch.Size([])  # bug: batch size is not inferred

poc.auto_device_()
poc_bug.auto_device_()
assert poc.device == torch.device("cpu")
assert poc_bug.device is None  # bug: device is not inferred
```
Expected behavior
`PoC` and `PoC_bug` should behave the same.
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)