Skip to content

Commit

Permalink
Added support for int4 casting to wider integers such as int8
Browse files Browse the repository at this point in the history
Addes support to cast np.float32 and np.float64 into int4

PiperOrigin-RevId: 567667633
  • Loading branch information
ChromeHearts authored and The ml_dtypes Authors committed Sep 22, 2023
1 parent fc69958 commit bbeedd4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 6 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.3.1] - 2023-09-22

* Added support for int4 casting to wider integers such as int8
* Addes support to cast np.float32 and np.float64 into int4

## [0.3.0] - 2023-09-19

* Dropped support for Python 3.8, following [NEP 29].
Expand All @@ -44,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...HEAD
[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.1...HEAD
[0.3.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...v0.3.1
[0.3.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.1.0...v0.2.0
[0.1.0]: https://github.com/jax-ml/ml_dtypes/releases/tag/v0.1.0
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.3.0' # Keep in sync with pyproject.toml:version
__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
__all__ = [
'__version__',
'bfloat16',
Expand Down
71 changes: 68 additions & 3 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#ifndef ML_DTYPES_INT4_NUMPY_H_
#define ML_DTYPES_INT4_NUMPY_H_

#include <type_traits>

// Must be included first
// clang-format off
#include "_src/numpy.h"
Expand All @@ -28,6 +30,8 @@ limitations under the License.

namespace ml_dtypes {

constexpr char kOutOfRange[] = "out of range value cannot be converted to int4";

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -114,8 +118,7 @@ bool CastToInt4(PyObject* arg, T* output) {
}
if (d < static_cast<double>(T::lowest()) ||
d > static_cast<double>(T::highest())) {
PyErr_SetString(PyExc_OverflowError,
"out of range value cannot be converted to int4");
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
}
*output = T(d);
return true;
Expand All @@ -131,9 +134,37 @@ bool CastToInt4(PyObject* arg, T* output) {
if (PyArray_IsScalar(arg, Integer)) {
int64_t v;
PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64));

if (!(std::numeric_limits<T>::min() <= v &&
v <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(v);
return true;
}
if (PyArray_IsScalar(arg, Float)) {
float f;
PyArray_ScalarAsCtype(arg, &f);
if (!(std::numeric_limits<T>::min() <= f &&
f <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(f));
return true;
}
if (PyArray_IsScalar(arg, Double)) {
double d;
PyArray_ScalarAsCtype(arg, &d);
if (!(std::numeric_limits<T>::min() <= d &&
d <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(d));
return true;
}
return false;
}

Expand Down Expand Up @@ -652,7 +683,41 @@ bool RegisterInt4Casts() {
}

// Safe casts from T to other types
// TODO(phawkins): add integer types
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT64,
NPY_NOSCALAR) < 0) {
return false;
}

if (std::is_same_v<uint4, T>) {
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT64,
NPY_NOSCALAR) < 0) {
return false;
}
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
NPY_NOSCALAR) < 0) {
return false;
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ml_dtypes"
version = "0.3.0" # Keep in sync with ml_dtypes/__init__.py:__version__
version = "0.3.1" # Keep in sync with ml_dtypes/__init__.py:__version__
description = ""
readme = "README.md"
requires-python = ">=3.9"
Expand Down

0 comments on commit bbeedd4

Please sign in to comment.