Skip to content

Commit

Permalink
fix: handle empty inputs + better errors with varying deepness
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 14, 2024
1 parent d1f24cd commit d154ec6
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 17 deletions.
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# v0.3.3

- Handle empty inputs (e.g. `as_folded_tensor([[[], []], [[]]])`) by returning an empty tensor
- Correctly bubble up errors when converting inputs with a varying nesting depth (e.g. `as_folded_tensor([1, [2, 3]])`)

# v0.3.2

- Allow to use `as_folded_tensor` with no args, as a simple padding function
Expand Down
42 changes: 27 additions & 15 deletions foldedtensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,20 @@ def backward(ctx, grad_output):

def get_metadata(nested_data):
    """Infer the nesting depth and leaf type of a nested sequence.

    Lazily walks ``nested_data`` and short-circuits at the first leaf:
    assuming the structure nests regularly, one leaf is enough to know
    the element type, while ``deepness`` records the deepest list/tuple
    level entered along the way. Structures with no leaves at all (e.g.
    ``[[[], []], [[]]]``) are fully traversed, so their depth is still
    measured correctly.

    Parameters
    ----------
    nested_data: Any
        A (possibly empty) nested list/tuple structure, or a bare scalar.

    Returns
    -------
    Tuple[int, type]
        The nesting depth (0 for a bare scalar) and the type of the first
        leaf encountered (``NoneType`` when there is no leaf).
    """
    item = None
    deepness = 0

    def rec(seq, depth=0):
        # Generator yielding once per leaf; the single ``next`` below only
        # consumes the first yield, so traversal stops at the first leaf.
        nonlocal item, deepness
        if isinstance(seq, (list, tuple)):
            depth += 1
            deepness = max(deepness, depth)
            for item in seq:
                yield from rec(item, depth)
        else:
            yield

    # Advance to the first leaf if any; empty structures yield nothing but
    # ``deepness`` was still updated while recursing into them.
    next(rec(nested_data), 0)
    return deepness, type(item)


def as_folded_tensor(
Expand Down Expand Up @@ -139,20 +143,24 @@ def as_folded_tensor(
device: Optional[Unit[str, torch.device]]
The device of the output tensor
"""
if data_dims is not None:
data_dims = tuple(
dim if isinstance(dim, int) else full_names.index(dim) for dim in data_dims
)
if (data_dims[-1] + 1) != len(full_names):
raise ValueError(
"The last dimension of `data_dims` must be the last variable dimension."
if full_names is not None:
if data_dims is not None:
data_dims = tuple(
dim if isinstance(dim, int) else full_names.index(dim)
for dim in data_dims
)
elif full_names is not None:
data_dims = tuple(range(len(full_names)))
if (data_dims[-1] + 1) != len(full_names):
raise ValueError(
"The last dimension of `data_dims` must be the last variable dimension."
)
elif full_names is not None:
data_dims = tuple(range(len(full_names)))
if isinstance(data, torch.Tensor) and lengths is not None:
data_dims = data_dims or tuple(range(len(lengths)))
np_indexer, shape = _C.make_refolding_indexer(lengths, data_dims)
assert shape == list(data.shape[: len(data_dims)])
assert shape == list(
data.shape[: len(data_dims)]
), f"Shape inferred from lengths is not compatible with data dims: {shape}, {data.shape}, {len(data_dims)}"
result = FoldedTensor(
data=data,
lengths=lengths,
Expand All @@ -165,6 +173,8 @@ def as_folded_tensor(
# raise ValueError("dtype must be provided when `data` is a sequence")
if data_dims is None or dtype is None:
deepness, inferred_dtype = get_metadata(data)
else:
deepness = len(full_names) if full_names is not None else len(data_dims)
if data_dims is None:
data_dims = tuple(range(deepness))
if dtype is None:
Expand All @@ -177,6 +187,8 @@ def as_folded_tensor(
)
indexer = torch.from_numpy(indexer)
padded = torch.from_numpy(padded)
# In case of empty sequences, lengths are not computed correctly
lengths = (list(lengths) + [[0]] * deepness)[:deepness]
result = FoldedTensor(
data=padded,
lengths=lengths,
Expand Down
10 changes: 8 additions & 2 deletions foldedtensor/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ nested_py_list_to_padded_np_array(
std::vector<std::tuple<std::vector<int64_t>, int64_t, PyObject *>> operations;

// Index from the data dimension to the dim in the theoretically fully padded list
std::vector<int64_t> data_dim_map(*std::max_element(data_dims.begin(), data_dims.end()) + 1, -1);
int64_t max_depth = 0;
if (data_dims.size() > 0) {
max_depth = *std::max_element(data_dims.begin(), data_dims.end()) + 1;
}
std::vector<int64_t> data_dim_map(max_depth, -1);
for (unsigned long i = 0; i < data_dims.size(); i++) {
data_dim_map[data_dims[i]] = i;
}
Expand Down Expand Up @@ -279,7 +283,9 @@ nested_py_list_to_padded_np_array(
// Set the element in the array and move to the next element
// Since array elements can be of any size, we use the element size (in
// bytes) to move from one element to the next
PyArray_SETITEM(padded_array.ptr(), array_ptr, items[i]);
if (PyArray_SETITEM(padded_array.ptr(), array_ptr, items[i]) < 0) {
throw py::error_already_set();
}
array_ptr += itemsize;

// Assign the current index to the indexer and move to the next element
Expand Down
37 changes: 37 additions & 0 deletions tests/test_folded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,40 @@ def test_share_memory(ft):
assert cloned.is_shared()
assert cloned.indexer.is_shared()
assert cloned.mask.is_shared()


def test_empty_sequence():
    """Nested input whose innermost lists are all empty pads to a 0-width last dim."""
    nested = [
        [[], [], []],
        [[], []],
    ]
    ft = as_folded_tensor(nested, dtype=torch.float)
    assert ft.shape == (2, 3, 0)


def test_imbalanced_sequence_1():
    """A bare scalar followed by a nested list raises the numpy-style ValueError."""
    data = [3, [0, 1, 2]]
    with pytest.raises(ValueError) as excinfo:
        as_folded_tensor(data, dtype=torch.float)
    assert "setting an array element with a sequence." in str(excinfo.value)


def test_imbalanced_sequence_2():
    """A nested list followed by a bare scalar raises TypeError when iterated."""
    data = [[0, 1, 2], 3]
    with pytest.raises(TypeError) as excinfo:
        as_folded_tensor(data, dtype=torch.float)
    assert "'int' object is not iterable" in str(excinfo.value)

0 comments on commit d154ec6

Please sign in to comment.