Skip to content

Commit

Permalink
axe numpy_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
bchess committed Apr 26, 2024
1 parent f200987 commit 03af4b8
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions tensorizer/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import collections.abc
import concurrent.futures
import contextlib
import ctypes
import dataclasses
import enum
import functools
Expand Down Expand Up @@ -3778,7 +3779,7 @@ def __init__(
self.module_index = module_index
self.tensor_type: TensorType = tensor_type
self.name = _TensorPath.wrap_(name)
self.dtype: Optional[str] = None # _prepare_for_write_numpy_tensor
self.dtype: Optional[str] = None # _prepare_for_write_dtype
self.shape = tensor.size()
self.data_length = tensor.nbytes
# self.file_offset # intentionally omitted, handled by _write_headers()
Expand All @@ -3789,9 +3790,6 @@ def __init__(
)

# Additional payloads that get set and used during the prepare_for_write procedures
self.numpy_tensor: Optional[_NumpyTensor] = (
None # $et in _prepare_for_write_numpy_tensor
)
self.header: Optional[_TensorHeaderSerializer] = (
None # $et in _prepare_for_write_headers
)
Expand All @@ -3805,6 +3803,13 @@ def __init__(
# They are often chained from one step of the process to the next
self.tensor_data_task: Optional[_Future] = None

@property
def tensor_memoryview(self) -> memoryview:
    """
    Return a zero-copy, byte-level ``memoryview`` of ``self.tensor``'s data.

    The view is constructed directly from the tensor's raw data pointer
    via ``ctypes``; no tensor data is copied. The returned view does not
    own the memory it points at, so ``self.tensor`` must outlive it.

    Returns:
        A ``memoryview`` spanning ``element_size() * nelement()`` bytes
        starting at ``self.tensor.data_ptr()``.

    Raises:
        BufferError: If the tensor is not contiguous. The byte span
            computed here is only valid for a dense layout; for a
            strided or expanded tensor it would expose bytes in the
            wrong order or read past the end of the tensor's storage.
    """
    tensor = self.tensor
    if not tensor.is_contiguous():
        raise BufferError(
            "Cannot create a memoryview of a discontiguous tensor"
        )
    nbytes = tensor.element_size() * tensor.nelement()
    # A ctypes char array placed at the tensor's address exposes the
    # buffer protocol over exactly nbytes bytes of the tensor's memory.
    raw_bytes = (ctypes.c_char * nbytes).from_address(tensor.data_ptr())
    return memoryview(raw_bytes)

def set_min_file_version_number(self, version_number):
    """
    Raise the recorded minimum file version to at least ``version_number``.

    The stored ``min_file_version`` never decreases: it is replaced only
    when ``version_number`` is strictly greater than the current value.
    """
    current = self.min_file_version
    self.min_file_version = (
        version_number if version_number > current else current
    )

Expand Down Expand Up @@ -3841,8 +3846,7 @@ def _bulk_write(self, write_specs: Iterable[_WriteSpec], incremental=False):
try:
self._prepare_for_write_contiguous(write_specs)
self._prepare_for_write_meta(write_specs)
self._prepare_for_write_numpy_tensor(write_specs)
self._prepare_for_write_opaque(write_specs)
self._prepare_for_write_dtype(write_specs)
if self._encrypted:
self._prepare_for_write_encryption(write_specs)
self._prepare_for_write_headers(write_specs)
Expand Down Expand Up @@ -4172,33 +4176,31 @@ def make_contiguous(write_spec, dependency):
make_contiguous, w, w.tensor_data_task
)

def _prepare_for_write_numpy_tensor(
self, write_specs: Sequence[_WriteSpec]
):
# For each write spec: wait on any pending tensor_data_task, build a
# _NumpyTensor from the spec's tensor, and record its numpy dtype on
# the spec. Raises ValueError when the converted buffer's byte size
# differs from the size the tensor itself reports.
for w in write_specs:
# all futures are resolved here. This step is not multi-threaded.
if w.tensor_data_task is not None:
w.tensor_data_task.result(_TIMEOUT)
w.tensor_data_task = None
w.numpy_tensor = _NumpyTensor.from_tensor(w.tensor)
# The dtype stored on the spec comes from the numpy-side representation
w.dtype = w.numpy_tensor.numpy_dtype
# Consistency check between the converted buffer and tensor.nbytes
if w.numpy_tensor.data.data.nbytes != w.tensor.nbytes:
raise ValueError(
f"Cannot serialize tensor {w.name!r}:"
f" buffer size of underlying memory ({w.numpy_tensor.data.data.nbytes})"
f" doesn't match reported size ({w.tensor.nbytes})"
)
def _prepare_for_write_dtype(self, write_specs: Sequence[_WriteSpec]):
# Fills in w.dtype for every write spec.
# Dtypes for which _NumpyTensor._is_asymmetric() is true are stored as
# "<V{element_size}" + OPAQUE_DTYPE_SEP + str(torch dtype) and bump the
# spec's minimum file version to OPAQUE_TENSORIZER_VERSION; every other
# dtype is stored as its numpy dtype string.
# Memo for this call: str(torch dtype) -> numpy dtype string
torch_dtype_to_numpy_dtype_cache: Dict[str, str] = {}

# NOTE(review): the _prepare_for_write_opaque lines below appear
# interleaved with the body of _prepare_for_write_dtype, and the nesting
# cannot be recovered from this view -- confirm against the repository.
def _prepare_for_write_opaque(
self, write_specs: Sequence[_WriteSpec]
) -> None:
for w in write_specs:
if not w.numpy_tensor.is_opaque: # type: ignore
continue
# The datatype name needs to contain both the numpy dtype that the
# data is serialized as and the original torch dtype.
w.dtype += OPAQUE_DTYPE_SEP + w.numpy_tensor.torch_dtype # type: ignore
w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION)
tensor_dtype_str = str(w.tensor.dtype)
if _NumpyTensor._is_asymmetric(w.tensor.dtype):
# is opaque
# The dtype string is a raw void type sized to one element, followed by
# the separator and the original torch dtype name.
w.dtype = (
f"<V{w.tensor.element_size():d}"
+ OPAQUE_DTYPE_SEP
+ tensor_dtype_str
)
w.set_min_file_version_number(OPAQUE_TENSORIZER_VERSION)
else:
# Try the memo first; None signals that this dtype has not been seen yet
w.dtype = torch_dtype_to_numpy_dtype_cache.get(
tensor_dtype_str, None
)
if w.dtype is None:
# Cache miss: take the numpy dtype string from an empty CPU tensor of
# the same torch dtype, and store it in both the memo and the spec.
torch_dtype_to_numpy_dtype_cache[tensor_dtype_str] = (
w.dtype
) = (
torch.tensor((), dtype=w.tensor.dtype, device="cpu")
.numpy()
.dtype.str
)

@staticmethod
def _do_clone(write_spec, dependency: Optional[_Future]):
Expand Down Expand Up @@ -4250,7 +4252,6 @@ def _prepare_for_write_encryption(
w.tensor_data_task = _FutureGroup(clone_tasks)

for w in write_specs:
assert w.numpy_tensor is not None
w.include_crc32 = False

if w.data_length == 0:
Expand All @@ -4263,7 +4264,7 @@ def _prepare_for_write_encryption(
w.tensor_data_task.result(_TIMEOUT)
w.tensor_data_task = None

tensor_memory: memoryview = w.numpy_tensor.tensor_memory
tensor_memory: memoryview = w.tensor_memoryview
chunked = _Chunked(
total_size=tensor_memory.nbytes,
chunk_size=self._crypt_chunk_size,
Expand Down Expand Up @@ -4383,9 +4384,7 @@ def compute_crc32(
if dependency is not None:
dependency.result(_TIMEOUT)
header_crc32 = write_spec.header.compute_crc32()
crc32 = zlib.crc32(
write_spec.numpy_tensor.tensor_memory, header_crc32
)
crc32 = zlib.crc32(write_spec.tensor_memoryview, header_crc32)
write_spec.header.add_crc32(crc32)

def compute_sha256(
Expand All @@ -4395,7 +4394,7 @@ def compute_sha256(
if dependency is not None:
dependency.result(_TIMEOUT)
sha256 = write_spec.header.compute_sha256()
sha256.update(write_spec.numpy_tensor.tensor_memory)
sha256.update(write_spec.tensor_memoryview)
write_spec.header.add_sha256(sha256.digest())

for w in write_specs:
Expand Down Expand Up @@ -4480,7 +4479,7 @@ def commit_tensor_data(
bytes_written = 0
else:
bytes_written = self._pwrite(
write_spec.numpy_tensor.tensor_memory,
write_spec.tensor_memoryview,
write_spec.header.data_offset, # type: ignore
verify=write_spec.header.data_length, # type: ignore
)
Expand Down

0 comments on commit 03af4b8

Please sign in to comment.