Skip to content

Commit

Permalink
Merge pull request numpy#26372 from seberg/can-cast-numpy-scalar
Browse files Browse the repository at this point in the history
BUG: Make sure that NumPy scalars are supported by can_cast
  • Loading branch information
charris authored May 6, 2024
2 parents 2e354ee + 2e02cb7 commit dd33199
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 10 deletions.
10 changes: 10 additions & 0 deletions doc/source/release/2.0.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,16 @@ to achieve the previous behavior.

(`gh-25712 <https://github.com/numpy/numpy/pull/25712>`__)

``np.can_cast`` cannot be called on Python int, float, or complex
-----------------------------------------------------------------
``np.can_cast`` cannot be called with Python int, float, or complex instances
anymore. This is because NEP 50 means that the result of ``can_cast`` must
not depend on the value passed in.
Unfortunately, for Python scalars whether a cast should be considered
``"same_kind"`` or ``"safe"`` may depend on the context and value so that
this is currently not implemented.
In some cases, this means you may have to add a specific path for:
``if type(obj) in (int, float, complex): ...``.


**Content from release note snippets in doc/release/upcoming_changes:**
Expand Down
3 changes: 2 additions & 1 deletion numpy/_core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "conversion_utils.h" /* for PyArray_TypestrConvert */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "descriptor.h"
#include "multiarraymodule.h"
#include "alloc.h"
#include "assert.h"
#include "npy_buffer.h"
Expand Down Expand Up @@ -2701,7 +2702,7 @@ arraydescr_reduce(PyArray_Descr *self, PyObject *NPY_UNUSED(args))
Py_DECREF(ret);
return NULL;
}
obj = PyObject_GetAttrString(mod, "dtype");
obj = PyObject_GetAttr(mod, npy_ma_str_dtype);
Py_DECREF(mod);
if (obj == NULL) {
Py_DECREF(ret);
Expand Down
44 changes: 35 additions & 9 deletions numpy/_core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3488,6 +3488,36 @@ array_can_cast_safely(PyObject *NPY_UNUSED(self),
if (PyArray_Check(from_obj)) {
ret = PyArray_CanCastArrayTo((PyArrayObject *)from_obj, d2, casting);
}
else if (PyArray_IsScalar(from_obj, Generic)) {
/*
* TODO: `PyArray_IsScalar` should not be required for new dtypes.
* weak-promotion branch is in practice identical to dtype one.
*/
if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) {
PyObject *descr = PyObject_GetAttr(from_obj, npy_ma_str_dtype);
if (descr == NULL) {
goto finish;
}
if (!PyArray_DescrCheck(descr)) {
Py_DECREF(descr);
PyErr_SetString(PyExc_TypeError,
"numpy_scalar.dtype did not return a dtype instance.");
goto finish;
}
ret = PyArray_CanCastTypeTo((PyArray_Descr *)descr, d2, casting);
Py_DECREF(descr);
}
else {
/* need to convert to object to consider old value-based logic */
PyArrayObject *arr;
arr = (PyArrayObject *)PyArray_FROM_O(from_obj);
if (arr == NULL) {
goto finish;
}
ret = PyArray_CanCastArrayTo(arr, d2, casting);
Py_DECREF(arr);
}
}
else if (PyArray_IsPythonNumber(from_obj)) {
PyErr_SetString(PyExc_TypeError,
"can_cast() does not support Python ints, floats, and "
Expand All @@ -3496,15 +3526,6 @@ array_can_cast_safely(PyObject *NPY_UNUSED(self),
"explicitly allow them again in the future.");
goto finish;
}
else if (PyArray_IsScalar(from_obj, Generic)) {
PyArrayObject *arr;
arr = (PyArrayObject *)PyArray_FROM_O(from_obj);
if (arr == NULL) {
goto finish;
}
ret = PyArray_CanCastArrayTo(arr, d2, casting);
Py_DECREF(arr);
}
/* Otherwise use CanCastTypeTo */
else {
if (!PyArray_DescrConverter2(from_obj, &d1) || d1 == NULL) {
Expand Down Expand Up @@ -4772,6 +4793,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_convert = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_preserve = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_convert_if_no_array = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_cpu = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_dtype = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_err_msg_substr = NULL;

static int
Expand Down Expand Up @@ -4850,6 +4872,10 @@ intern_strings(void)
if (npy_ma_str_cpu == NULL) {
return -1;
}
npy_ma_str_dtype = PyUnicode_InternFromString("dtype");
if (npy_ma_str_dtype == NULL) {
return -1;
}
npy_ma_str_array_err_msg_substr = PyUnicode_InternFromString(
"__array__() got an unexpected keyword argument 'copy'");
if (npy_ma_str_array_err_msg_substr == NULL) {
Expand Down
1 change: 1 addition & 0 deletions numpy/_core/src/multiarray/multiarraymodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_convert;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_preserve;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_convert_if_no_array;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_cpu;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_err_msg_substr;

#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */
11 changes: 11 additions & 0 deletions numpy/_core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,17 @@ def test_can_cast_values(self):
assert_(np.can_cast(fi.min, dt))
assert_(np.can_cast(fi.max, dt))

@pytest.mark.parametrize("dtype",
list("?bhilqBHILQefdgFDG") + [rational])
def test_can_cast_scalars(self, dtype):
# Basic test to ensure that scalars are supported in can-cast
# (does not check behavior exhaustively).
dtype = np.dtype(dtype)
scalar = dtype.type(0)

assert np.can_cast(scalar, "int64") == np.can_cast(dtype, "int64")
assert np.can_cast(scalar, "float32", casting="unsafe")


# Custom exception class to test exception propagation in fromiter
class NIterError(Exception):
Expand Down

0 comments on commit dd33199

Please sign in to comment.