From bbeedd470ecac727c42e97648c0f27bfc312af30 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Fri, 22 Sep 2023 11:04:56 -0700 Subject: [PATCH] Added support for int4 casting to wider integers such as int8 Addes support to cast np.float32 and np.float64 into int4 PiperOrigin-RevId: 567667633 --- CHANGELOG.md | 8 ++++- ml_dtypes/__init__.py | 2 +- ml_dtypes/_src/int4_numpy.h | 71 +++++++++++++++++++++++++++++++++++-- pyproject.toml | 2 +- 4 files changed, 77 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a304fd1..49789d01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.3.1] - 2023-09-22 + +* Added support for int4 casting to wider integers such as int8 +* Addes support to cast np.float32 and np.float64 into int4 + ## [0.3.0] - 2023-09-19 * Dropped support for Python 3.8, following [NEP 29]. @@ -44,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Initial release -[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...HEAD +[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.1...HEAD +[0.3.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...v0.3.1 [0.3.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.1.0...v0.2.0 [0.1.0]: https://github.com/jax-ml/ml_dtypes/releases/tag/v0.1.0 diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 60c297c1..7546ba96 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '0.3.0' # Keep in sync with pyproject.toml:version +__version__ = '0.3.1' # Keep in sync with pyproject.toml:version __all__ = [ '__version__', 'bfloat16', diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 7f23fbc1..dba8b9f9 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef ML_DTYPES_INT4_NUMPY_H_ #define ML_DTYPES_INT4_NUMPY_H_ +#include + // Must be included first // clang-format off #include "_src/numpy.h" @@ -28,6 +30,8 @@ limitations under the License. namespace ml_dtypes { +constexpr char kOutOfRange[] = "out of range value cannot be converted to int4"; + template struct Int4TypeDescriptor { static int Dtype() { return npy_type; } @@ -114,8 +118,7 @@ bool CastToInt4(PyObject* arg, T* output) { } if (d < static_cast(T::lowest()) || d > static_cast(T::highest())) { - PyErr_SetString(PyExc_OverflowError, - "out of range value cannot be converted to int4"); + PyErr_SetString(PyExc_OverflowError, kOutOfRange); } *output = T(d); return true; @@ -131,9 +134,37 @@ bool CastToInt4(PyObject* arg, T* output) { if (PyArray_IsScalar(arg, Integer)) { int64_t v; PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64)); + + if (!(std::numeric_limits::min() <= v && + v <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } *output = T(v); return true; } + if (PyArray_IsScalar(arg, Float)) { + float f; + PyArray_ScalarAsCtype(arg, &f); + if (!(std::numeric_limits::min() <= f && + f <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } + *output = T(static_cast<::int8_t>(f)); + return true; + } + if (PyArray_IsScalar(arg, Double)) { + double d; + PyArray_ScalarAsCtype(arg, &d); + if (!(std::numeric_limits::min() <= d && + d <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } + *output = T(static_cast<::int8_t>(d)); + return true; + } return false; } @@ -652,7 +683,41 @@ bool RegisterInt4Casts() { } // Safe casts from T to other types - // TODO(phawkins): add integer types + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT8, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT32, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT64, + NPY_NOSCALAR) < 0) { + return false; + } + + if (std::is_same_v) { + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT8, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT32, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT64, + NPY_NOSCALAR) < 0) { + return false; + } + } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_FLOAT, NPY_NOSCALAR) < 0) { return false; diff --git a/pyproject.toml b/pyproject.toml index 65780ed1..a353d010 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ml_dtypes" -version = "0.3.0" # Keep in sync with ml_dtypes/__init__.py:__version__ +version = "0.3.1" # Keep in sync with ml_dtypes/__init__.py:__version__ description = "" readme = "README.md" requires-python = ">=3.9"