Skip to content

Commit 92b8195

Browse files
Merge pull request #526 from pydata/add_complex
Add complex arguments for new boolean functions
2 parents 11b1951 + b93a4bb commit 92b8195

File tree

10 files changed

+215
-28
lines changed

10 files changed

+215
-28
lines changed

ADDFUNCS.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,23 @@ Add clauses to generate the FUNC_CODES from the ``functions.hpp`` header, making
171171
};
172172
#endif
173173
174+
Some functions (e.g. ``fmod``, ``isnan``) are not available in MKL, and so must be hard-coded here as well:
175+
176+
.. code-block:: cpp
177+
178+
#ifdef USE_VML
179+
/* no isnan, isfinite or isinf in VML */
180+
static void vdIsfinite(MKL_INT n, const double* x1, bool* dest)
181+
{
182+
MKL_INT j;
183+
for (j=0; j<n; j++) {
184+
dest[j] = isfinited(x1[j]);
185+
};
186+
};
187+
#endif
188+
189+
The complex case is slightlñy different (see other examples in the same file).
190+
174191
Add case handling to the ``check_program`` function
175192

176193
.. code-block:: cpp
@@ -219,4 +236,6 @@ In many cases this process will not be very smooth since one relies on the inter
219236

220237
* Depending on the new function signature (above all if the out type is different to the input types), one may have to edit the ``__init__`` function in the ``FuncNode`` class in ``expressions.py``.
221238

239+
* Functions which accept and/or return complex arguments must be added to the ``complex_functions.hpp`` file (take care when adding them in ``interpreter.cpp`` and ``interp_body.cpp``, since their signatures are usually a bit different).
240+
222241
* Depending on MSVC support, namespace clashes, casting problems, it may be necessary to make various changes to ``numexpr/numexpr_config.hpp`` and ``numexpr/msvc_function_stubs.hpp``. For example, in PR #523, non-clashing wrappers were introduced for ``isnan`` and ``isfinite`` since the float versions ``isnanf, isfinitef`` were inconsistently defined (and output ints) - depending on how strict the platform interpreter is, the implicit cast from int to bool was acceptable or not for example. In addition, the base functions were in different namespaces or had different names across platforms.

doc/user_guide.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
NumExpr 2.8 User Guide
1+
NumExpr 2.12 User Guide
22
======================
33

44
The NumExpr package supplies routines for the fast evaluation of
@@ -201,6 +201,8 @@ The next are the current supported set:
201201

202202
* :code:`where(bool, number1, number2): number` -- number1 if the bool condition
203203
is true, number2 otherwise.
204+
* :code:`{isinf, isnan, isfinite}(float|complex): bool` -- returns element-wise True
205+
for ``inf`` or ``NaN``, ``NaN``, not ``inf`` respectively.
204206
* :code:`{sin,cos,tan}(float|complex): float|complex` -- trigonometric sine,
205207
cosine or tangent.
206208
* :code:`{arcsin,arccos,arctan}(float|complex): float|complex` -- trigonometric

numexpr/complex_functions.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,4 +424,33 @@ nc_abs(std::complex<double> *x, std::complex<double> *r)
424424
r->imag(0);
425425
}
426426

427+
static bool
428+
nc_isinf(std::complex<double> *x)
429+
{
430+
double xr=x->real(), xi=x->imag();
431+
bool bi,br;
432+
bi = isinfd(xi);
433+
br = isinfd(xr);
434+
return bi || br;
435+
}
436+
437+
static bool
438+
nc_isnan(std::complex<double> *x)
439+
{
440+
double xr=x->real(), xi=x->imag();
441+
bool bi,br;
442+
bi = isnand(xi);
443+
br = isnand(xr);
444+
return bi || br;
445+
}
446+
447+
static bool
448+
nc_isfinite(std::complex<double> *x)
449+
{
450+
double xr=x->real(), xi=x->imag();
451+
bool bi,br;
452+
bi = isfinited(xi);
453+
br = isfinited(xr);
454+
return bi && br;
455+
}
427456
#endif // NUMEXPR_COMPLEX_FUNCTIONS_HPP

numexpr/functions.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,17 @@ FUNC_CCC(FUNC_CCC_LAST, NULL, NULL)
166166
#undef ELIDE_FUNC_CCC
167167
#undef FUNC_CCC
168168
#endif
169+
170+
// complex -> boolean functions
171+
#ifndef FUNC_BC
172+
#define ELIDE_FUNC_BC
173+
#define FUNC_BC(...)
174+
#endif // use wrappers as there is name collision with isnanf in std
175+
FUNC_BC(FUNC_ISNAN_BC, "isnan_bc", nc_isnan, vzIsnan)
176+
FUNC_BC(FUNC_ISFINITE_BC, "isfinite_bc", nc_isfinite, vzIsfinite)
177+
FUNC_BC(FUNC_ISINF_BC, "isinf_bc", nc_isinf, vzIsinf)
178+
FUNC_BC(FUNC_BC_LAST, NULL, NULL, NULL)
179+
#ifdef ELIDE_FUNC_BC
180+
#undef ELIDE_FUNC_BC
181+
#undef FUNC_BC
182+
#endif

numexpr/interp_body.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,17 @@
469469
VEC_ARG1(b_dest = functions_bd[arg2](d1));
470470
#endif
471471

472+
case OP_FUNC_BCN:
473+
#ifdef USE_VML
474+
VEC_ARG1_VML(functions_bc_vml[arg2](BLOCK_SIZE,
475+
(const MKL_Complex16*)x1, (bool*)dest));
476+
#else
477+
VEC_ARG1(ca.real(c1r);
478+
ca.imag(c1i);
479+
b_dest = functions_bc[arg2](&ca));
480+
#endif
481+
482+
472483
/* Reductions */
473484
case OP_SUM_IIN: VEC_ARG1(i_reduce += i1);
474485
case OP_SUM_LLN: VEC_ARG1(l_reduce += l1);

numexpr/interpreter.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,31 @@ FuncBFPtr functions_bf[] = {
220220
};
221221
#endif
222222

223+
#ifdef USE_VML
224+
/* no isnan, isfinite or isinf in VML */
225+
static void vfIsfinite(MKL_INT n, const float* x1, bool* dest)
226+
{
227+
MKL_INT j;
228+
for (j=0; j<n; j++) {
229+
dest[j] = isfinitef_(x1[j]);
230+
};
231+
};
232+
static void vfIsinf(MKL_INT n, const float* x1, bool* dest)
233+
{
234+
MKL_INT j;
235+
for (j=0; j<n; j++) {
236+
dest[j] = isinff_(x1[j]);
237+
};
238+
};
239+
static void vfIsnan(MKL_INT n, const float* x1, bool* dest)
240+
{
241+
MKL_INT j;
242+
for (j=0; j<n; j++) {
243+
dest[j] = isnanf_(x1[j]);
244+
};
245+
};
246+
#endif
247+
223248
#ifdef USE_VML
224249
typedef void (*FuncBFPtr_vml)(MKL_INT, const float*, bool*);
225250
FuncBFPtr_vml functions_bf_vml[] = {
@@ -236,6 +261,31 @@ FuncBDPtr functions_bd[] = {
236261
#undef FUNC_BD
237262
};
238263

264+
#ifdef USE_VML
265+
/* no isnan, isfinite or isinf in VML */
266+
static void vdIsfinite(MKL_INT n, const double* x1, bool* dest)
267+
{
268+
MKL_INT j;
269+
for (j=0; j<n; j++) {
270+
dest[j] = isfinited(x1[j]);
271+
};
272+
};
273+
static void vdIsinf(MKL_INT n, const double* x1, bool* dest)
274+
{
275+
MKL_INT j;
276+
for (j=0; j<n; j++) {
277+
dest[j] = isinfd(x1[j]);
278+
};
279+
};
280+
static void vdIsnan(MKL_INT n, const double* x1, bool* dest)
281+
{
282+
MKL_INT j;
283+
for (j=0; j<n; j++) {
284+
dest[j] = isnand(x1[j]);
285+
};
286+
};
287+
#endif
288+
239289
#ifdef USE_VML
240290
typedef void (*FuncBDPtr_vml)(MKL_INT, const double*, bool*);
241291
FuncBDPtr_vml functions_bd_vml[] = {
@@ -245,6 +295,47 @@ FuncBDPtr_vml functions_bd_vml[] = {
245295
};
246296
#endif
247297

298+
typedef bool (*FuncBCPtr)(std::complex<double>*);
299+
FuncBCPtr functions_bc[] = {
300+
#define FUNC_BC(fop, s, f, ...) f,
301+
#include "functions.hpp"
302+
#undef FUNC_BC
303+
};
304+
305+
#ifdef USE_VML
306+
/* no isnan, isfinite or isinf in VML */
307+
static void vzIsfinite(MKL_INT n, const MKL_Complex16* x1, bool* dest)
308+
{
309+
MKL_INT j;
310+
for (j=0; j<n; j++) {
311+
dest[j] = isfinited(x1[j].real) && isfinited(x1[j].imag);
312+
};
313+
};
314+
static void vzIsinf(MKL_INT n, const MKL_Complex16* x1, bool* dest)
315+
{
316+
MKL_INT j;
317+
for (j=0; j<n; j++) {
318+
dest[j] = isinfd(x1[j].real) || isinfd(x1[j].imag);
319+
};
320+
};
321+
static void vzIsnan(MKL_INT n, const MKL_Complex16* x1, bool* dest)
322+
{
323+
MKL_INT j;
324+
for (j=0; j<n; j++) {
325+
dest[j] = isnand(x1[j].real) || isnand(x1[j].imag);
326+
};
327+
};
328+
#endif
329+
330+
#ifdef USE_VML
331+
typedef void (*FuncBCPtr_vml)(MKL_INT, const MKL_Complex16[], bool*);
332+
FuncBCPtr_vml functions_bc_vml[] = {
333+
#define FUNC_BC(fop, s, f, f_vml) f_vml,
334+
#include "functions.hpp"
335+
#undef FUNC_BC
336+
};
337+
#endif
338+
248339
#ifdef USE_VML
249340
/* Fake vdConj function just for casting purposes inside numexpr */
250341
static void vdConj(MKL_INT n, const double* x1, double* dest)
@@ -517,7 +608,14 @@ check_program(NumExprObject *self)
517608
PyErr_Format(PyExc_RuntimeError, "invalid program: funccode out of range (%i) at %i", arg, argloc);
518609
return -1;
519610
}
520-
} else if (op >= OP_REDUCTION) {
611+
}
612+
else if (op == OP_FUNC_BCN) {
613+
if (arg < 0 || arg >= FUNC_BC_LAST) {
614+
PyErr_Format(PyExc_RuntimeError, "invalid program: funccode out of range (%i) at %i", arg, argloc);
615+
return -1;
616+
}
617+
}
618+
else if (op >= OP_REDUCTION) {
521619
;
522620
} else {
523621
PyErr_Format(PyExc_RuntimeError, "invalid program: internal checker error processing %i", argloc);

numexpr/interpreter.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ enum FuncBDCodes {
4242
#undef FUNC_BD
4343
};
4444

45+
enum FuncBCCodes {
46+
#define FUNC_BC(fop, ...) fop,
47+
#include "functions.hpp"
48+
#undef FUNC_BC
49+
};
50+
4551
enum FuncDDDCodes {
4652
#define FUNC_DDD(fop, ...) fop,
4753
#include "functions.hpp"

numexpr/module.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,15 @@ PyInit_interpreter(void) {
508508
#define FUNC_DD(name, sname, ...) add_func(name, sname);
509509
#define FUNC_BF(name, sname, ...) add_func(name, sname);
510510
#define FUNC_BD(name, sname, ...) add_func(name, sname);
511+
#define FUNC_BC(name, sname, ...) add_func(name, sname);
511512
#define FUNC_DDD(name, sname, ...) add_func(name, sname);
512513
#define FUNC_CC(name, sname, ...) add_func(name, sname);
513514
#define FUNC_CCC(name, sname, ...) add_func(name, sname);
514515
#include "functions.hpp"
515516
#undef FUNC_CCC
516517
#undef FUNC_CC
517518
#undef FUNC_DDD
519+
#undef FUNC_BC
518520
#undef FUNC_BD
519521
#undef FUNC_BF
520522
#undef FUNC_DD

numexpr/opcodes.hpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -150,43 +150,44 @@ OPCODE(105, OP_CONTAINS_BSS, "contains_bss", Tb, Ts, Ts, T0)
150150
//Boolean outputs
151151
OPCODE(106, OP_FUNC_BDN, "func_bdn", Tb, Td, Tn, T0)
152152
OPCODE(107, OP_FUNC_BFN, "func_bfn", Tb, Tf, Tn, T0)
153+
OPCODE(108, OP_FUNC_BCN, "func_bcn", Tb, Tc, Tn, T0)
153154

154155
// Reductions always have to be at the end - parts of the code
155156
// use > OP_REDUCTION to decide whether operation is a reduction
156-
OPCODE(108, OP_REDUCTION, NULL, T0, T0, T0, T0)
157+
OPCODE(109, OP_REDUCTION, NULL, T0, T0, T0, T0)
157158

158159
/* Last argument in a reduction is the axis of the array the
159160
reduction should be applied along. */
160161

161-
OPCODE(109, OP_SUM_IIN, "sum_iin", Ti, Ti, Tn, T0)
162-
OPCODE(110, OP_SUM_LLN, "sum_lln", Tl, Tl, Tn, T0)
163-
OPCODE(111, OP_SUM_FFN, "sum_ffn", Tf, Tf, Tn, T0)
164-
OPCODE(112, OP_SUM_DDN, "sum_ddn", Td, Td, Tn, T0)
165-
OPCODE(113, OP_SUM_CCN, "sum_ccn", Tc, Tc, Tn, T0)
166-
167-
OPCODE(114, OP_PROD, NULL, T0, T0, T0, T0)
168-
OPCODE(115, OP_PROD_IIN, "prod_iin", Ti, Ti, Tn, T0)
169-
OPCODE(116, OP_PROD_LLN, "prod_lln", Tl, Tl, Tn, T0)
170-
OPCODE(117, OP_PROD_FFN, "prod_ffn", Tf, Tf, Tn, T0)
171-
OPCODE(118, OP_PROD_DDN, "prod_ddn", Td, Td, Tn, T0)
172-
OPCODE(119, OP_PROD_CCN, "prod_ccn", Tc, Tc, Tn, T0)
173-
174-
OPCODE(120, OP_MIN, NULL, T0, T0, T0, T0)
175-
OPCODE(121, OP_MIN_IIN, "min_iin", Ti, Ti, Tn, T0)
176-
OPCODE(122, OP_MIN_LLN, "min_lln", Tl, Tl, Tn, T0)
177-
OPCODE(123, OP_MIN_FFN, "min_ffn", Tf, Tf, Tn, T0)
178-
OPCODE(124, OP_MIN_DDN, "min_ddn", Td, Td, Tn, T0)
179-
180-
OPCODE(125, OP_MAX, NULL, T0, T0, T0, T0)
181-
OPCODE(126, OP_MAX_IIN, "max_iin", Ti, Ti, Tn, T0)
182-
OPCODE(127, OP_MAX_LLN, "max_lln", Tl, Tl, Tn, T0)
183-
OPCODE(128, OP_MAX_FFN, "max_ffn", Tf, Tf, Tn, T0)
184-
OPCODE(129, OP_MAX_DDN, "max_ddn", Td, Td, Tn, T0)
162+
OPCODE(110, OP_SUM_IIN, "sum_iin", Ti, Ti, Tn, T0)
163+
OPCODE(111, OP_SUM_LLN, "sum_lln", Tl, Tl, Tn, T0)
164+
OPCODE(112, OP_SUM_FFN, "sum_ffn", Tf, Tf, Tn, T0)
165+
OPCODE(113, OP_SUM_DDN, "sum_ddn", Td, Td, Tn, T0)
166+
OPCODE(114, OP_SUM_CCN, "sum_ccn", Tc, Tc, Tn, T0)
167+
168+
OPCODE(115, OP_PROD, NULL, T0, T0, T0, T0)
169+
OPCODE(116, OP_PROD_IIN, "prod_iin", Ti, Ti, Tn, T0)
170+
OPCODE(117, OP_PROD_LLN, "prod_lln", Tl, Tl, Tn, T0)
171+
OPCODE(118, OP_PROD_FFN, "prod_ffn", Tf, Tf, Tn, T0)
172+
OPCODE(119, OP_PROD_DDN, "prod_ddn", Td, Td, Tn, T0)
173+
OPCODE(120, OP_PROD_CCN, "prod_ccn", Tc, Tc, Tn, T0)
174+
175+
OPCODE(121, OP_MIN, NULL, T0, T0, T0, T0)
176+
OPCODE(122, OP_MIN_IIN, "min_iin", Ti, Ti, Tn, T0)
177+
OPCODE(123, OP_MIN_LLN, "min_lln", Tl, Tl, Tn, T0)
178+
OPCODE(124, OP_MIN_FFN, "min_ffn", Tf, Tf, Tn, T0)
179+
OPCODE(125, OP_MIN_DDN, "min_ddn", Td, Td, Tn, T0)
180+
181+
OPCODE(126, OP_MAX, NULL, T0, T0, T0, T0)
182+
OPCODE(127, OP_MAX_IIN, "max_iin", Ti, Ti, Tn, T0)
183+
OPCODE(128, OP_MAX_LLN, "max_lln", Tl, Tl, Tn, T0)
184+
OPCODE(129, OP_MAX_FFN, "max_ffn", Tf, Tf, Tn, T0)
185+
OPCODE(130, OP_MAX_DDN, "max_ddn", Td, Td, Tn, T0)
185186

186187
/*
187188
When we get to 255, will maybe have to change code again
188189
(change latin_1 encoding in necompiler.py, use something
189190
other than unsigned char for OPCODE table)
190191
*/
191192
/* Should be the last opcode */
192-
OPCODE(130, OP_END, NULL, T0, T0, T0, T0)
193+
OPCODE(131, OP_END, NULL, T0, T0, T0, T0)

numexpr/tests/test_numexpr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,11 @@ def test_bool_funcs(self):
709709
assert np.all(evaluate("isnan(a)") == np.isnan(a))
710710
assert np.all(evaluate("isfinite(a)") == np.isfinite(a))
711711
assert np.all(evaluate("isinf(a)") == np.isinf(a))
712+
a = a.astype(np.complex128)
713+
assert a.dtype == np.complex128
714+
assert np.all(evaluate("isnan(a)") == np.isnan(a))
715+
assert np.all(evaluate("isfinite(a)") == np.isfinite(a))
716+
assert np.all(evaluate("isinf(a)") == np.isinf(a))
712717

713718
if 'sparc' not in platform.machine():
714719
# Execution order set here so as to not use too many threads

0 commit comments

Comments
 (0)