Skip to content

Commit bde6b11

Browse files
authored
Add torch.tensor to named_data_store
Differential Revision: D85992938 Pull Request resolved: #15504
1 parent 40b304f commit bde6b11

File tree

5 files changed

+151
-20
lines changed

5 files changed

+151
-20
lines changed

exir/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ runtime.python_library(
8686
],
8787
deps = [
8888
":scalar_type",
89+
":tensor",
8990
]
9091
)
9192

exir/_serialize/_named_data_store.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
# pyre-strict
88

99
import hashlib
10+
1011
from dataclasses import dataclass
12+
from typing import Dict, List, Optional, Union
1113

12-
# from dataclasses import dataclass
13-
from typing import Dict, List, Optional
14+
import torch
1415

1516
from executorch.exir._serialize.data_serializer import DataEntry
1617
from executorch.exir.tensor_layout import TensorLayout
@@ -137,7 +138,7 @@ def _add_named_data_to_map(
137138
def add_named_data(
138139
self,
139140
key: str,
140-
data: bytes,
141+
data: Union[bytes, torch.Tensor],
141142
alignment: Optional[int] = 1,
142143
external_tag: Optional[str] = None,
143144
tensor_layout: Optional[TensorLayout] = None,
@@ -146,7 +147,7 @@ def add_named_data(
146147
Adds a named blob to the NamedDataStore.
147148
Args:
148149
key (str): key associated with the data.
149-
data (bytes): Bytes being requested to be serialized.
150+
data (Union[bytes, torch.Tensor]): bytes or a torch.Tensor to serialize. Note: if a tensor is passed, it must have a contiguous memory layout (default or channels_last). The tensor_layout is inferred from the tensor; if a tensor_layout is also passed in, it must match the inferred layout, otherwise a ValueError is raised.
150151
alignment (int): alignment for bytes to be serialized with.
151152
external_tag (Optional[str]): the external filename that this data is saved to.
152153
tensor_layout (Optional[TensorLayout]): layout of the tensor, if applicable.
@@ -161,14 +162,25 @@ def add_named_data(
161162
if alignment <= 0:
162163
raise ValueError(f"Alignment must be greater than 0, received {alignment}.")
163164

165+
if isinstance(data, torch.Tensor):
166+
real_tensor_layout = TensorLayout.from_tensor(data)
167+
if tensor_layout is not None and not (real_tensor_layout == tensor_layout):
168+
raise ValueError(
169+
f"Tensor {key} is a torch.Tensor, with tensor_layout {real_tensor_layout}. The provided tensor layout {tensor_layout} does not match."
170+
)
171+
tensor_layout = real_tensor_layout
172+
byte_data = bytes(data.untyped_storage())
173+
else:
174+
byte_data = data
175+
164176
if external_tag is None:
165177
self._add_named_data_to_map(
166-
key, data, alignment, self.pte_data, tensor_layout
178+
key, byte_data, alignment, self.pte_data, tensor_layout
167179
)
168180
else:
169181
self._add_named_data_to_map(
170182
key,
171-
data,
183+
byte_data,
172184
alignment,
173185
self.external_data.setdefault(external_tag, {}),
174186
tensor_layout,

exir/_serialize/test/test_named_data_store.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import unittest
1010

11+
import torch
12+
1113
from executorch.exir._serialize._named_data_store import NamedDataStore
1214
from executorch.exir._serialize.data_serializer import DataEntry
1315
from executorch.exir.scalar_type import ScalarType
@@ -36,6 +38,53 @@ def test_add(self) -> None:
3638
self.assertEqual(output.external_data["file1"]["key2"], DataEntry(1, 16, None))
3739
self.assertEqual(output.external_data["file1"]["key3"], DataEntry(2, 16, None))
3840

41+
def test_add_torch_tensor(self) -> None:
    """Add two tensors to the store and check the stored bytes and layouts.

    Covers a default-contiguous int tensor and a channels_last float
    tensor. Neither is given an external tag, so both entries are expected
    under ``pte_data`` and ``external_data`` is expected to stay empty.
    """
    named_store = NamedDataStore()
    int_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
    nhwc_tensor = torch.randn(2, 3, 4, 5).contiguous(
        memory_format=torch.channels_last
    )

    named_store.add_named_data("key0", int_tensor, None, None)
    named_store.add_named_data("key1", nhwc_tensor, 16, None)

    result = named_store.get_named_data_store_output()

    # Each buffer is expected to equal the tensor's whole untyped storage.
    self.assertEqual(len(result.buffers), 2)
    self.assertEqual(result.buffers[0], bytes(int_tensor.untyped_storage()))
    self.assertEqual(result.buffers[1], bytes(nhwc_tensor.untyped_storage()))

    # No TensorLayout was passed in, yet the expected entries carry one;
    # key0 was added with alignment None and is expected to report 1.
    self.assertEqual(len(result.pte_data), 2)
    expected_key0 = DataEntry(
        0, 1, TensorLayout(ScalarType.INT, [2, 2], [0, 1])
    )
    self.assertEqual(result.pte_data["key0"], expected_key0)
    expected_key1 = DataEntry(
        1, 16, TensorLayout(ScalarType.FLOAT, [2, 3, 4, 5], [0, 2, 3, 1])
    )
    self.assertEqual(result.pte_data["key1"], expected_key1)
    self.assertEqual(len(result.external_data), 0)
66+
67+
def test_add_invalid_torch_tensor_layout(self) -> None:
68+
store = NamedDataStore()
69+
t0 = torch.tensor([[1, 2], [3, 4]], dtype=torch.int)
70+
# TensorLayout does not match the torch.tensor.
71+
self.assertRaises(
72+
ValueError,
73+
store.add_named_data,
74+
"key",
75+
t0,
76+
1,
77+
None,
78+
TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1]),
79+
)
80+
81+
def test_add_invalid_torch_tensor_dim_order(self) -> None:
82+
store = NamedDataStore()
83+
t0 = torch.randn(2, 3, 4, 5)
84+
t0 = t0.permute(0, 2, 3, 1)
85+
# Non-contiguous tensor is not supported.
86+
self.assertRaises(ValueError, store.add_named_data, "key", t0, 1, None)
87+
3988
def test_add_duplicate_name_and_data(self) -> None:
4089
store = NamedDataStore()
4190
store.add_named_data("key", b"data", None, None)

exir/tensor_layout.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from dataclasses import dataclass
1010
from typing import List
1111

12+
import torch
13+
1214
from executorch.exir.scalar_type import ScalarType
15+
from executorch.exir.tensor import dim_order_from_stride, scalar_type_enum
1316

1417

1518
# Note: keep this in sync with the TensorLayout definition in
@@ -19,3 +22,18 @@ class TensorLayout:
1922
scalar_type: ScalarType
2023
sizes: List[int]
2124
dim_order: List[int]
25+
26+
@classmethod
def from_tensor(cls, tensor: torch.Tensor) -> "TensorLayout":
    """Build a TensorLayout describing ``tensor``.

    Args:
        tensor: the tensor to describe. It must be contiguous in either
            ``torch.contiguous_format`` or ``torch.channels_last``.

    Returns:
        A TensorLayout whose ``scalar_type`` is derived from the tensor's
        dtype, ``sizes`` from its shape, and ``dim_order`` from its
        strides.

    Raises:
        ValueError: if the tensor is contiguous in neither memory format.
    """
    # Guard clause: reject anything that is neither default-contiguous
    # nor channels_last-contiguous.
    if not tensor.is_contiguous(
        memory_format=torch.contiguous_format
    ) and not tensor.is_contiguous(memory_format=torch.channels_last):
        raise ValueError(
            "Tensor is not contiguous. Please call .contiguous() before creating the TensorLayout."
        )

    scalar_type = scalar_type_enum(tensor.dtype)
    sizes = list(tensor.shape)
    dim_order = list(dim_order_from_stride(tensor.stride()))
    return TensorLayout(
        scalar_type=scalar_type,
        sizes=sizes,
        dim_order=dim_order,
    )

extension/flat_tensor/test/test_serialize.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import math
1111
import unittest
1212

13-
from typing import List, Optional
13+
from typing import Dict, List, Optional
14+
15+
import torch
1416

1517
from executorch.exir._serialize._cord import Cord
18+
from executorch.exir._serialize._named_data_store import NamedDataStore
1619

1720
from executorch.exir._serialize.data_serializer import (
1821
DataEntry,
@@ -90,6 +93,22 @@ def check_tensor_layout(
9093
self.assertEqual(expected.sizes, actual.sizes)
9194
self.assertEqual(expected.dim_order, actual.dim_order)
9295

96+
def _check_named_data_entries(
    self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry]
) -> None:
    """Assert that two named-data maps hold equivalent entries.

    Both maps must have the same keys. For each key, every dataclass field
    of the reference entry except ``alignment`` must equal the same field
    on the actual entry.

    Args:
        reference: expected key -> DataEntry mapping.
        actual: key -> DataEntry mapping under test.
    """
    self.assertEqual(reference.keys(), actual.keys())

    ignored_fields = {"alignment"}  # Fields to ignore in comparison.
    for key, expected_entry in reference.items():
        observed_entry = actual[key]
        compared_names = (
            field.name
            for field in dataclasses.fields(expected_entry)
            if field.name not in ignored_fields
        )
        for name in compared_names:
            self.assertEqual(
                getattr(expected_entry, name),
                getattr(observed_entry, name),
                f"Named data record {key}.{name} does not match.",
            )
111+
93112
def test_serialize(self) -> None:
94113
config = FlatTensorConfig()
95114
serializer: DataSerializer = FlatTensorSerializer(config)
@@ -245,19 +264,51 @@ def test_round_trip(self) -> None:
245264
f"Buffer at index {i} does not match.",
246265
)
247266

248-
self.assertEqual(
249-
TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys()
267+
self._check_named_data_entries(
268+
TEST_DATA_PAYLOAD.named_data, deserialized_payload.named_data
250269
)
251270

252-
SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison.
253-
for key in TEST_DATA_PAYLOAD.named_data.keys():
254-
reference = TEST_DATA_PAYLOAD.named_data[key]
255-
actual = deserialized_payload.named_data[key]
271+
def test_deserialize_to_named_data_store_output(self) -> None:
272+
store = NamedDataStore()
273+
external_tag = "model"
274+
275+
tensor_layout = TensorLayout(ScalarType.FLOAT, [1, 2], [0, 1])
276+
store.add_named_data(
277+
"key0",
278+
b"data0",
279+
alignment=1,
280+
external_tag=external_tag,
281+
tensor_layout=tensor_layout,
282+
)
283+
store.add_named_data(
284+
"key1",
285+
torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
286+
alignment=1,
287+
external_tag=external_tag,
288+
)
256289

257-
for field in dataclasses.fields(reference):
258-
if field.name not in SKIP_FIELDS:
259-
self.assertEqual(
260-
getattr(reference, field.name),
261-
getattr(actual, field.name),
262-
f"Named data record {key}.{field.name} does not match.",
263-
)
290+
output = store.get_named_data_store_output()
291+
self.assertEqual(len(output.buffers), 2)
292+
self.assertEqual(len(output.pte_data), 0)
293+
self.assertEqual(len(output.external_data), 1)
294+
self.assertEqual(len(output.external_data[external_tag]), 2)
295+
296+
# Serialize and deserialize.
297+
config = FlatTensorConfig()
298+
serializer: DataSerializer = FlatTensorSerializer(config)
299+
data_payload = DataPayload(
300+
buffers=output.buffers, named_data=output.external_data[external_tag]
301+
)
302+
serialized_data = serializer.serialize(data_payload)
303+
304+
output2 = serializer.deserialize_to_named_data_store_output(
305+
bytes(serialized_data), external_tag
306+
)
307+
308+
self.assertEqual(output.buffers, output2.buffers)
309+
self.assertEqual(len(output.pte_data), 0)
310+
self.assertEqual(len(output2.pte_data), 0)
311+
312+
self._check_named_data_entries(
313+
output.external_data[external_tag], output2.external_data[external_tag]
314+
)

0 commit comments

Comments
 (0)