Skip to content

Commit

Permalink
python: re-add comparison of output metadata against user-provided
Browse files Browse the repository at this point in the history
  • Loading branch information
SteVwonder authored and salasoom committed Apr 17, 2019
1 parent 8d612d8 commit 8d64d04
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
6 changes: 5 additions & 1 deletion docs/source/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,8 @@ can be a numpy array or a pointer to memory large enough to hold the
decompressed data. Regardless if ``out`` is provided or its type,
``_decompress`` always returns a numpy array. If ``out`` is not provided, the
array is allocated for the user, and if ``out`` is provided, then the returned
numpy is just a pointer to or wrapper around the user-supplied ``out``.
numpy is just a pointer to or wrapper around the user-supplied ``out``. If
``out`` is a numpy array, then the shape and type of the numpy array must match
the required arguments ``shape`` and ``ztype``. If you want to avoid this
constraint check, use ``out=ndarray.data`` rather than ``out=ndarray`` when
calling ``_decompress``.
35 changes: 29 additions & 6 deletions python/zfp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,38 @@ cpdef np.ndarray _decompress(
if out is None:
output = np.asarray(_decompress_with_view(field, stream))
else:
dtype = zfp.ztype_to_dtype(ztype)
if isinstance(out, np.ndarray):
output = out
else:
header_dtype = ztype_to_dtype(field[0]._type)
header_shape = (field[0].nw, field[0].nz, field[0].ny, field[0].nx)
header_shape = [x for x in header_shape if x > 0]

output = np.frombuffer(out, dtype=header_dtype)
output = output.reshape(header_shape)
# check that numpy and user-provided types match
if out.dtype != dtype:
raise ValueError(
"Out ndarray has dtype {} but decompression is using "
"{}. Use out=ndarray.data to avoid this check.".format(
out.dtype,
dtype
)
)

# check that numpy and user-provided shape match
numpy_shape = out.shape
user_shape = [x for x in shape if x > 0]
if not all(
[x == y for x, y in
zip_longest(numpy_shape, user_shape)
]
):
raise ValueError(
"Out ndarray has shape {} but decompression is using "
"{}. Use out=ndarray.data to avoid this check.".format(
numpy_shape,
user_shape
)
)
else:
output = np.frombuffer(out, dtype=dtype)
output = output.reshape(shape)

_decompress_with_user_array(field, stream, <void *>output.data)

Expand Down

0 comments on commit 8d64d04

Please sign in to comment.