
Speedups & fixes #9

Merged · 2 commits · May 12, 2024
6 changes: 6 additions & 0 deletions changelog.md
@@ -1,3 +1,9 @@
+# v0.3.4
+
+- Fix a data_dims access issue
+- Marginally improve the speed of handling FoldedTensors in standard torch operations
+- Use default torch types (e.g. `torch.float32` or `torch.int64`)
+
# v0.3.3

- Handle empty inputs (e.g. `as_folded_tensor([[[], []], [[]]])`) by returning an empty tensor
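A quick illustration of the dtype entry above (a sketch, not part of the PR; it assumes no explicit `dtype` is passed):

```python
import torch
from foldedtensor import as_folded_tensor

# Python floats now map to torch's default float dtype (float32 on a
# stock build) instead of a hard-coded float64.
ft = as_folded_tensor([[1.0, 2.0], [3.0]])
print(ft.dtype)  # torch.float32
```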
115 changes: 77 additions & 38 deletions foldedtensor/__init__.py
@@ -5,6 +5,8 @@
import torch
from torch.autograd import Function

+from . import _C

np_to_torch_dtype = {
torch.bool: bool,
torch.uint8: np.uint8,
@@ -19,9 +21,29 @@
torch.complex128: np.complex128,
}

-from . import _C
+pass_through_functions = {
+    torch.Tensor._grad.__get__,
+    torch.Tensor.grad,
+    torch.Tensor._base.__get__,
+    torch.Tensor.__repr__,
+    torch.Tensor.__str__,
+    torch.Tensor.__format__,
+    torch.Tensor.shape.__get__,
+    torch.Tensor.size.__get__,
+    torch.Tensor.dtype.__get__,
+    torch.Tensor.device.__get__,
+}
+if hasattr(torch._C, "TensorBase"):
+    pass_through_functions.add(torch._C.TensorBase.size)
+else:
+    pass_through_functions.add(torch.Tensor.size)
+
+try:
+    DisableTorchFunctionSubclass = torch._C.DisableTorchFunctionSubclass
+except AttributeError:
+    DisableTorchFunctionSubclass = torch._C.DisableTorchFunction

__version__ = "0.3.3"
__version__ = "0.3.4"
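Context for reviewers, since two compat shims land here: `pass_through_functions` lists cheap accessors that should bypass the refolding logic, and the `try`/`except` picks whichever torch-function guard the installed torch exposes. A minimal standalone sketch of the same pattern (the `Quiet` class and `_guard` name are illustrative, not from this PR):

```python
import torch

try:  # newer torch builds
    _guard = torch._C.DisableTorchFunctionSubclass
except AttributeError:  # older releases
    _guard = torch._C.DisableTorchFunction

class Quiet(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        with _guard():
            # The call below does not re-enter this hook.
            return func(*args, **kwargs)

t = torch.ones(2).as_subclass(Quiet)
print(t + t)  # goes through __torch_function__ exactly once per op
```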


# noinspection PyMethodOverriding
@@ -71,7 +93,6 @@ def backward(ctx, grad_output):
ctx.lengths,
ctx.old_data_dims,
)
-# new_data_flat.index_put_({new_indexer}, old_data_flat.index_select(0, old_indexer));
shape_suffix = grad_output.shape[len(ctx.new_data_dims) :]
grad_input = torch.zeros(
(*shape_prefix, *shape_suffix), dtype=grad_output.dtype, device=device
Expand All @@ -80,20 +101,13 @@ def backward(ctx, grad_output):
-1, *shape_suffix
).index_select(0, ctx.output_indexer)
return grad_input, None
-# return FoldedTensor(
-#     data=refolded_data,
-#     lengths=ctx.lengths,
-#     data_dims=ctx.old_data_dims,
-#     full_names=full_names,
-#     indexer=indexer,
-# )


type_to_dtype_dict = {
-    int: torch.int64,
-    float: torch.float64,
+    int: torch.tensor([0]).dtype,
+    float: torch.tensor([0.0]).dtype,
    bool: torch.bool,
-    None: torch.float64,
+    None: torch.tensor([0.0]).dtype,
}
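The replaced constants are subtle but deliberate: `torch.tensor([0]).dtype` and `torch.tensor([0.0]).dtype` resolve to torch's default dtypes at import time, rather than pinning `float64`. A quick check of that behaviour (assumes a stock build):

```python
import torch

print(torch.tensor([0]).dtype)    # torch.int64
print(torch.tensor([0.0]).dtype)  # torch.float32, the default float dtype

torch.set_default_dtype(torch.float64)
print(torch.tensor([0.0]).dtype)  # torch.float64 now
torch.set_default_dtype(torch.float32)  # restore
```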


@@ -151,16 +165,18 @@ def as_folded_tensor(
)
if (data_dims[-1] + 1) != len(full_names):
raise ValueError(
"The last dimension of `data_dims` must be the last variable dimension."
"The last dimension of `data_dims` must be the last "
"variable dimension."
)
elif full_names is not None:
data_dims = tuple(range(len(full_names)))
if isinstance(data, torch.Tensor) and lengths is not None:
data_dims = data_dims or tuple(range(len(lengths)))
np_indexer, shape = _C.make_refolding_indexer(lengths, data_dims)
-assert shape == list(
-    data.shape[: len(data_dims)]
-), f"Shape inferred from lengths is not compatible with data dims: {shape}, {data.shape}, {len(data_dims)}"
+assert shape == list(data.shape[: len(data_dims)]), (
+    f"Shape inferred from lengths is not compatible with data dims: {shape}, "
+    f"{data.shape}, {len(data_dims)}"
+)
result = FoldedTensor(
data=data,
lengths=lengths,
@@ -208,6 +224,23 @@ def as_folded_tensor(
return result


+def _postprocess_func_result(result, input):
+    if (
+        input is not None
+        and input.shape[: len(input.data_dims)] != result.shape[: len(input.data_dims)]
+    ):
+        return result
+
+    return FoldedTensor(
+        data=result,
+        lengths=input.lengths,
+        data_dims=input.data_dims,
+        full_names=input.full_names,
+        indexer=input.indexer,
+        mask=input._mask,
+    )
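In words: the new helper re-attaches the folding metadata only when an op leaves the leading (folded) dimensions intact; shape-changing results fall through unwrapped. A hedged sketch of the preserved case:

```python
import torch
from foldedtensor import as_folded_tensor

ft = as_folded_tensor([[0, 1, 2], [3, 4]], dtype=torch.float)

doubled = ft * 2  # elementwise, so the leading shape (2, 3) is unchanged
print(type(doubled).__name__)         # FoldedTensor
print(doubled.lengths == ft.lengths)  # True: metadata carried over
```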


# noinspection PyUnresolvedReferences,PyInitNewSignature
class FoldedTensor(torch.Tensor):
"""
@@ -296,46 +329,52 @@ def to(self, *args, **kwargs):

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
-result = super().__torch_function__(func, types, args, kwargs)
+if kwargs is None:
+    kwargs = {}
+if func in pass_through_functions:
+    with DisableTorchFunctionSubclass():
+        return func(*args, **kwargs)
+with DisableTorchFunctionSubclass():
+    result = func(*args, **kwargs)

if func is torch.Tensor.share_memory_:
self = args[0]
self.indexer.share_memory_()
if self._mask is not None:
self._mask.share_memory_()

-if not isinstance(result, torch.Tensor):
-    return result
+return self

ft = None
-for arg in (*args, *(kwargs or {}).values()):
+for arg in (*args, *kwargs.values()):
if isinstance(arg, FoldedTensor):
assert (
ft is None or ft.data_dims == arg.data_dims
), "Cannot perform operation on FoldedTensors with different structure"
), "Cannot perform operation on FoldedTensors with different structures"
ft = arg
-if isinstance(arg, (list, tuple)):
+elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, FoldedTensor):
-assert (
-    ft is None or ft.data_dims == item.data_dims
-), "Cannot perform operation on FoldedTensors with different structure"
+assert ft is None or ft.data_dims == item.data_dims, (
+    "Cannot perform operation on FoldedTensors with "
+    "different structures"
+)
ft = item

+if isinstance(result, torch.Tensor):
+    return _postprocess_func_result(result, ft)

if (
-    ft is not None
-    and ft.shape[: len(ft.data_dims)] != result.shape[: len(ft.data_dims)]
+    isinstance(result, (tuple, list))
+    and len(result)
+    and isinstance(result[0], torch.Tensor)
):
-return result.as_subclass(torch.Tensor)
+return type(result)(
+    _postprocess_func_result(item, ft)
+    if isinstance(item, FoldedTensor)
+    else item
+    for item in result
+)

-result = FoldedTensor(
-    data=result,
-    lengths=ft.lengths,
-    data_dims=ft.data_dims,
-    full_names=ft.full_names,
-    indexer=ft.indexer,
-    mask=ft._mask,
-)
return result
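The visible effect of the rewrite, sketched against the new test below: tensor results are re-wrapped through `_postprocess_func_result`, and tuple results such as `Tensor.max` are now post-processed item by item instead of being returned untouched:

```python
import torch
from foldedtensor import as_folded_tensor

ft = as_folded_tensor([[0, 1, 2], [3, 4]], dtype=torch.float)

values, indices = ft.max(-1)  # a (values, indices) tuple-like result
print(values.tolist())   # [2.0, 4.0]
print(indices.tolist())  # [2, 1]
```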

def clone(self):
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -9,7 +9,7 @@ readme = "README.md"
urls.homepage = "https://github.com/aphp/foldedtensor/"
urls.repository = "https://github.com/aphp/foldedtensor/"
dynamic = ["version"]
-requires-python = ">3.7.6,<4.0"
+requires-python = ">3.7.1,<4.0"

dependencies = [
"torch>1.0.0",
@@ -82,7 +82,6 @@ fix = true
exclude = [
".git",
"__pycache__",
"__init__.py",
".mypy_cache",
".pytest_cache",
".venv",
@@ -98,5 +97,6 @@ select = [
]
fixable = ["E", "F", "W", "I"]

-[isort]
-known-first-party = ["foldedtensor"]
+[tool.ruff.isort]
+known-first-party = ["foldedtensor"]
+known-third-party = ["build"]
13 changes: 13 additions & 0 deletions tests/test_folded_tensor.py
@@ -404,3 +404,16 @@ def test_imbalanced_sequence_2():
)

assert "'int' object is not iterable" in str(e.value)


+def test_max():
+    ft = as_folded_tensor(
+        [
+            [0, 1, 2],
+            [3, 4],
+        ],
+        dtype=torch.float,
+    )
+    values, indices = ft.max(-1)
+    assert (values == torch.tensor([2, 4])).all()
+    assert (indices == torch.tensor([2, 1])).all()
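One caveat the test sidesteps: padding positions hold zeros, so `max` is only safe here because every real value is non-negative. A mask-aware variant for general data (a sketch, assuming the boolean `mask` property the library exposes for padding):

```python
import torch
from foldedtensor import as_folded_tensor

ft = as_folded_tensor([[-3.0, -1.0, -2.0], [-4.0, -5.0]])
# A plain ft.max(-1) would wrongly report the 0.0 padding as the max
# of the second row.
masked = ft.masked_fill(~ft.mask, float("-inf"))
values, indices = masked.max(-1)
print(values.tolist())   # [-1.0, -4.0]
print(indices.tolist())  # [1, 0]
```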