Skip to content

Commit

Permalink
fix!: wavg() and wsum() Numeric functions return null when pr…
Browse files Browse the repository at this point in the history
…ovided a vector containing only `null` values (deephaven#5524)

* Initial commit of wavg and wsum Numeric output changes.

* Correct update_by rolling_wavg calculation.

* Correct update_by rolling_wavg calculation.
  • Loading branch information
lbooker42 authored Jun 4, 2024
1 parent 3209be3 commit fd08f4a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
21 changes: 21 additions & 0 deletions engine/function/src/templates/Numeric.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -2599,6 +2599,7 @@ public class Numeric {
}

long vsum = 0;
long nullCount = 0;

try (
final ${pt.vectorIterator} vi = values.iterator();
Expand All @@ -2610,10 +2611,16 @@ public class Numeric {

if (!isNull(c) && !isNull(w)) {
vsum += c * (long) w;
} else {
nullCount++;
}
}
}

if (nullCount == values.size()) {
return NULL_LONG;
}

return vsum;
}
<#else>
Expand All @@ -2629,6 +2636,7 @@ public class Numeric {
}

double vsum = 0;
long nullCount = 0;

try (
final ${pt.vectorIterator} vi = values.iterator();
Expand Down Expand Up @@ -2660,10 +2668,16 @@ public class Numeric {
<#else>
vsum += c * (double) w;
</#if>
} else {
nullCount++;
}
}
}

if (nullCount == values.size()) {
return NULL_DOUBLE;
}

return vsum;
}
</#if>
Expand Down Expand Up @@ -2733,6 +2747,7 @@ public class Numeric {

double vsum = 0;
double wsum = 0;
long nullCount = 0;

try (
final ${pt.vectorIterator} vi = values.iterator();
Expand All @@ -2750,10 +2765,16 @@ public class Numeric {
if (!isNull(c) && !isNull(w)) {
vsum += c * w;
wsum += w;
} else {
nullCount++;
}
}
}

if (nullCount == values.size()) {
return NULL_DOUBLE;
}

return vsum / wsum;
}

Expand Down
12 changes: 12 additions & 0 deletions engine/function/src/templates/TestNumeric.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,10 @@ public class TestNumeric extends BaseArrayTestCase {
assertEquals(NULL_LONG, wsum((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null));

assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3}));

assertEquals(1*4+2*5+3*6, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}}));
assertEquals(NULL_LONG, wsum((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_LONG, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null));
Expand All @@ -1169,6 +1173,10 @@ public class TestNumeric extends BaseArrayTestCase {
assertEquals(NULL_DOUBLE, wsum((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null));

assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3}));

assertEquals(1.0*4.0+2.0*5.0+3.0*6.0, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}}));
assertEquals(NULL_DOUBLE, wsum((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_DOUBLE, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null));
Expand Down Expand Up @@ -1215,6 +1223,10 @@ public class TestNumeric extends BaseArrayTestCase {
assertEquals(NULL_DOUBLE, wavg((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null));

assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}}));
assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3}));

assertEquals((1.0*4.0+2.0*5.0+3.0*6.0)/(4.0+5.0+6.0), wavg(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}}));
assertEquals(NULL_DOUBLE, wavg((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6}));
assertEquals(NULL_DOUBLE, wavg(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ public void pop(int count) {
@Override
public void writeToOutputChunk(int outIdx) {
if (windowValues.size() == nullCount) {
// Looks weird but returning NaN is consistent with Numeric#wavg and AggWAvg
outputValues.set(outIdx, Double.NaN);
outputValues.set(outIdx, NULL_DOUBLE);
} else {
final double weightedValSum = windowValues.evaluate();
final double weightSum = windowWeightValues.evaluate();
Expand Down

0 comments on commit fd08f4a

Please sign in to comment.