From 5443c55575bb264ee2713e3228d688056bba7be7 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Wed, 20 Sep 2023 21:38:11 -0700 Subject: [PATCH] Add float8 & int4 numpy integration PiperOrigin-RevId: 567178215 --- ml_dtypes/_src/int4_numpy.h | 66 +++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 7f23fbc1..aed0f00a 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -28,6 +28,8 @@ limitations under the License. namespace ml_dtypes { +constexpr char kOutOfRange[] = "out of range value cannot be converted to int4"; + template struct Int4TypeDescriptor { static int Dtype() { return npy_type; } @@ -114,8 +116,7 @@ bool CastToInt4(PyObject* arg, T* output) { } if (d < static_cast(T::lowest()) || d > static_cast(T::highest())) { - PyErr_SetString(PyExc_OverflowError, - "out of range value cannot be converted to int4"); + PyErr_SetString(PyExc_OverflowError, kOutOfRange); } *output = T(d); return true; @@ -131,9 +132,37 @@ bool CastToInt4(PyObject* arg, T* output) { if (PyArray_IsScalar(arg, Integer)) { int64_t v; PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64)); + + if (!(std::numeric_limits::min() <= v && + v <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } *output = T(v); return true; } + if (PyArray_IsScalar(arg, Float)) { + float f; + PyArray_ScalarAsCtype(arg, &f); + if (!(std::numeric_limits::min() <= f && + f <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } + *output = T(static_cast<::int8_t>(f)); + return true; + } + if (PyArray_IsScalar(arg, Double)) { + double d; + PyArray_ScalarAsCtype(arg, &d); + if (!(std::numeric_limits::min() <= d && + d <= std::numeric_limits::max())) { + PyErr_SetString(PyExc_OverflowError, kOutOfRange); + return false; + } + *output = T(static_cast<::int8_t>(d)); + return true; + } return false; } @@ -652,7 +681,38 @@ bool RegisterInt4Casts() { } // Safe casts from T to other types - // TODO(phawkins): add integer types + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT8, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT8, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT16, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT32, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT32, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_INT64, + NPY_NOSCALAR) < 0) { + return false; + } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_UINT64, + NPY_NOSCALAR) < 0) { + return false; + } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_FLOAT, NPY_NOSCALAR) < 0) { return false;