Skip to content

Commit e47738c

Browse files
authored
Merge pull request #173 from SwayamInSync/logaddexp
2 parents 1b4c800 + 5356cc7 commit e47738c

File tree

4 files changed

+185
-2
lines changed

4 files changed

+185
-2
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,48 @@ quad_copysign(const Sleef_quad *in1, const Sleef_quad *in2)
584584
return Sleef_copysignq1(*in1, *in2);
585585
}
586586

587+
static inline Sleef_quad
588+
quad_logaddexp(const Sleef_quad *x, const Sleef_quad *y)
589+
{
590+
// logaddexp(x, y) = log(exp(x) + exp(y))
591+
// Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y)))
592+
593+
// Handle NaN
594+
if (Sleef_iunordq1(*x, *y)) {
595+
return Sleef_iunordq1(*x, *x) ? *x : *y;
596+
}
597+
598+
// Handle infinities
599+
// If both are -inf, result is -inf
600+
Sleef_quad neg_inf = Sleef_negq1(QUAD_POS_INF);
601+
if (Sleef_icmpeqq1(*x, neg_inf) && Sleef_icmpeqq1(*y, neg_inf)) {
602+
return neg_inf;
603+
}
604+
605+
// If either is +inf, result is +inf
606+
if (Sleef_icmpeqq1(*x, QUAD_POS_INF) || Sleef_icmpeqq1(*y, QUAD_POS_INF)) {
607+
return QUAD_POS_INF;
608+
}
609+
610+
// If one is -inf, result is the other value
611+
if (Sleef_icmpeqq1(*x, neg_inf)) {
612+
return *y;
613+
}
614+
if (Sleef_icmpeqq1(*y, neg_inf)) {
615+
return *x;
616+
}
617+
618+
// Numerically stable computation
619+
Sleef_quad diff = Sleef_subq1_u05(*x, *y);
620+
Sleef_quad abs_diff = Sleef_fabsq1(diff);
621+
Sleef_quad neg_abs_diff = Sleef_negq1(abs_diff);
622+
Sleef_quad exp_term = Sleef_expq1_u10(neg_abs_diff);
623+
Sleef_quad log1p_term = Sleef_log1pq1_u10(exp_term);
624+
625+
Sleef_quad max_val = Sleef_icmpgtq1(*x, *y) ? *x : *y;
626+
return Sleef_addq1_u05(max_val, log1p_term);
627+
}
628+
587629
// Binary long double operations
588630
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);
589631

@@ -680,6 +722,43 @@ ld_copysign(const long double *in1, const long double *in2)
680722
return copysignl(*in1, *in2);
681723
}
682724

725+
static inline long double
726+
ld_logaddexp(const long double *x, const long double *y)
727+
{
728+
// logaddexp(x, y) = log(exp(x) + exp(y))
729+
// Numerically stable implementation: max(x, y) + log1p(exp(-abs(x - y)))
730+
731+
// Handle NaN
732+
if (isnan(*x) || isnan(*y)) {
733+
return isnan(*x) ? *x : *y;
734+
}
735+
736+
// Handle infinities
737+
// If both are -inf, result is -inf
738+
if (isinf(*x) && *x < 0 && isinf(*y) && *y < 0) {
739+
return -INFINITY;
740+
}
741+
742+
// If either is +inf, result is +inf
743+
if ((isinf(*x) && *x > 0) || (isinf(*y) && *y > 0)) {
744+
return INFINITY;
745+
}
746+
747+
// If one is -inf, result is the other value
748+
if (isinf(*x) && *x < 0) {
749+
return *y;
750+
}
751+
if (isinf(*y) && *y < 0) {
752+
return *x;
753+
}
754+
755+
// Numerically stable computation
756+
long double diff = *x - *y;
757+
long double abs_diff = fabsl(diff);
758+
long double max_val = (*x > *y) ? *x : *y;
759+
return max_val + log1pl(expl(-abs_diff));
760+
}
761+
683762
// comparison quad functions
684763
typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *);
685764

quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,5 +240,8 @@ init_quad_binary_ops(PyObject *numpy)
240240
if (create_quad_binary_ufunc<quad_copysign, ld_copysign>(numpy, "copysign") < 0) {
241241
return -1;
242242
}
243+
if (create_quad_binary_ufunc<quad_logaddexp, ld_logaddexp>(numpy, "logaddexp") < 0) {
244+
return -1;
245+
}
243246
return 0;
244247
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
| multiply |||
1111
| matmul |||
1212
| divide |||
13-
| logaddexp | | |
13+
| logaddexp | | |
1414
| logaddexp2 | | |
1515
| true_divide | | |
1616
| floor_divide | | |

quaddtype/tests/test_quaddtype.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def test_logarithmic_functions(op, val):
367367
# Check sign for zero results
368368
if float_result == 0.0:
369369
assert np.signbit(float_result) == np.signbit(
370-
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
370+
quad_result), f"Zero sign mismatch"
371371

372372

373373
@pytest.mark.parametrize("val", [
@@ -390,6 +390,7 @@ def test_logarithmic_functions(op, val):
390390
])
391391
def test_log1p(val):
392392
"""Comprehensive test for log1p function"""
393+
op = "log1p"
393394
quad_val = QuadPrecision(val)
394395
float_val = float(val)
395396

@@ -427,6 +428,106 @@ def test_log1p(val):
427428
assert np.signbit(float_result) == np.signbit(
428429
quad_result), f"Zero sign mismatch for {op}({val})"
429430

431+
432+
@pytest.mark.parametrize("x", [
433+
# Regular values
434+
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
435+
# Large values (test numerical stability)
436+
"100.0", "1000.0", "-100.0", "-1000.0",
437+
# Small values
438+
"1e-10", "-1e-10", "1e-20", "-1e-20",
439+
# Special values
440+
"inf", "-inf", "nan", "-nan", "-0.0"
441+
])
442+
@pytest.mark.parametrize("y", [
443+
# Regular values
444+
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
445+
# Large values
446+
"100.0", "1000.0", "-100.0", "-1000.0",
447+
# Small values
448+
"1e-10", "-1e-10", "1e-20", "-1e-20",
449+
# Special values
450+
"inf", "-inf", "nan", "-nan", "-0.0"
451+
])
452+
def test_logaddexp(x, y):
453+
"""Comprehensive test for logaddexp function: log(exp(x) + exp(y))"""
454+
quad_x = QuadPrecision(x)
455+
quad_y = QuadPrecision(y)
456+
float_x = float(x)
457+
float_y = float(y)
458+
459+
quad_result = np.logaddexp(quad_x, quad_y)
460+
float_result = np.logaddexp(float_x, float_y)
461+
462+
# Handle NaN cases
463+
if np.isnan(float_result):
464+
assert np.isnan(float(quad_result)), \
465+
f"Expected NaN for logaddexp({x}, {y}), got {float(quad_result)}"
466+
return
467+
468+
# Handle infinity cases
469+
if np.isinf(float_result):
470+
assert np.isinf(float(quad_result)), \
471+
f"Expected inf for logaddexp({x}, {y}), got {float(quad_result)}"
472+
if not np.isnan(float_result):
473+
assert np.sign(float_result) == np.sign(float(quad_result)), \
474+
f"Infinity sign mismatch for logaddexp({x}, {y})"
475+
return
476+
477+
# For finite results, check with appropriate tolerance
478+
# logaddexp is numerically sensitive, especially for large differences
479+
if abs(float_x - float_y) > 50:
480+
# When values differ greatly, result should be close to max(x, y)
481+
rtol = 1e-10
482+
atol = 1e-10
483+
else:
484+
rtol = 1e-13
485+
atol = 1e-15
486+
487+
np.testing.assert_allclose(
488+
float(quad_result), float_result,
489+
rtol=rtol, atol=atol,
490+
err_msg=f"Value mismatch for logaddexp({x}, {y})"
491+
)
492+
493+
494+
def test_logaddexp_special_properties():
495+
"""Test special mathematical properties of logaddexp"""
496+
# logaddexp(x, x) = x + log(2)
497+
x = QuadPrecision("2.0")
498+
result = np.logaddexp(x, x)
499+
expected = float(x) + np.log(2.0)
500+
np.testing.assert_allclose(float(result), expected, rtol=1e-14)
501+
502+
# logaddexp(x, -inf) = x
503+
x = QuadPrecision("5.0")
504+
result = np.logaddexp(x, QuadPrecision("-inf"))
505+
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)
506+
507+
# logaddexp(-inf, x) = x
508+
result = np.logaddexp(QuadPrecision("-inf"), x)
509+
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)
510+
511+
# logaddexp(-inf, -inf) = -inf
512+
result = np.logaddexp(QuadPrecision("-inf"), QuadPrecision("-inf"))
513+
assert np.isinf(float(result)) and float(result) < 0
514+
515+
# logaddexp(inf, anything) = inf
516+
result = np.logaddexp(QuadPrecision("inf"), QuadPrecision("100.0"))
517+
assert np.isinf(float(result)) and float(result) > 0
518+
519+
# logaddexp(anything, inf) = inf
520+
result = np.logaddexp(QuadPrecision("100.0"), QuadPrecision("inf"))
521+
assert np.isinf(float(result)) and float(result) > 0
522+
523+
# Commutativity: logaddexp(x, y) = logaddexp(y, x)
524+
x = QuadPrecision("3.0")
525+
y = QuadPrecision("5.0")
526+
result1 = np.logaddexp(x, y)
527+
result2 = np.logaddexp(y, x)
528+
np.testing.assert_allclose(float(result1), float(result2), rtol=1e-14)
529+
530+
430531
def test_inf():
431532
assert QuadPrecision("inf") > QuadPrecision("1e1000")
432533
assert np.signbit(QuadPrecision("inf")) == 0

0 commit comments

Comments
 (0)