diff --git a/docs/source/python.rst b/docs/source/python.rst index 28b5519e8..0806ff10d 100644 --- a/docs/source/python.rst +++ b/docs/source/python.rst @@ -112,4 +112,8 @@ can be a numpy array or a pointer to memory large enough to hold the decompressed data. Regardless if ``out`` is provided or its type, ``_decompress`` always returns a numpy array. If ``out`` is not provided, the array is allocated for the user, and if ``out`` is provided, then the returned -numpy is just a pointer to or wrapper around the user-supplied ``out``. +numpy is just a pointer to or wrapper around the user-supplied ``out``. If +``out`` is a numpy array, then the shape and type of the numpy array must match +the required arguments ``shape`` and ``ztype``. If you want to avoid this +constraint check, use ``out=ndarray.data`` rather than ``out=ndarray`` when +calling ``_decompress``. diff --git a/python/zfp.pyx b/python/zfp.pyx index e673f6ec5..d0e4f3305 100644 --- a/python/zfp.pyx +++ b/python/zfp.pyx @@ -286,15 +286,38 @@ cpdef np.ndarray _decompress( if out is None: output = np.asarray(_decompress_with_view(field, stream)) else: + dtype = zfp.ztype_to_dtype(ztype) if isinstance(out, np.ndarray): output = out - else: - header_dtype = ztype_to_dtype(field[0]._type) - header_shape = (field[0].nw, field[0].nz, field[0].ny, field[0].nx) - header_shape = [x for x in header_shape if x > 0] - output = np.frombuffer(out, dtype=header_dtype) - output = output.reshape(header_shape) + # check that numpy and user-provided types match + if out.dtype != dtype: + raise ValueError( + "Out ndarray has dtype {} but decompression is using " + "{}. Use out=ndarray.data to avoid this check.".format( + out.dtype, + dtype + ) + ) + + # check that numpy and user-provided shape match + numpy_shape = out.shape + user_shape = [x for x in shape if x > 0] + if not all( + [x == y for x, y in + zip_longest(numpy_shape, user_shape) + ] + ): + raise ValueError( + "Out ndarray has shape {} but decompression is using " + "{}. Use out=ndarray.data to avoid this check.".format( + numpy_shape, + user_shape + ) + ) + else: + output = np.frombuffer(out, dtype=dtype) + output = output.reshape(shape) _decompress_with_user_array(field, stream, output.data)