Skip to content

Commit

Permalink
fix: Mask over doubles
Browse files Browse the repository at this point in the history
NODATA in Double context was returning Integer.MinValue instead
  • Loading branch information
echeipesh committed Feb 25, 2023
1 parent 82de60b commit a3f9bc8
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ case class InverseMaskByDefined(targetTile: Expression, maskTile: Expression)
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
val result = maskEval(targetTile, mask,
{ (v, m) => if (isNoData(m)) v else NODATA },
{ (v, m) => if (isNoData(m)) v else NODATA }
{ (v, m) => if (isNoData(m)) v else Double.NaN }
)
toInternalRow(result, targetCtx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, mask

val result = maskEval(targetTile, mask,
{ (v, m) => if (m != maskValue) NODATA else v },
{ (v, m) => if (m != maskValue) NODATA else v }
{ (v, m) => if (m != maskValue) Double.NaN else v }
)
toInternalRow(result, targetCtx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ case class MaskByDefined(targetTile: Expression, maskTile: Expression)
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
val result = maskEval(targetTile, mask,
{ (v, m) => if (isNoData(m)) NODATA else v },
{ (v, m) => if (isNoData(m)) NODATA else v }
{ (v, m) => if (isNoData(m)) Double.NaN else v }
)
toInternalRow(result, targetCtx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ case class MaskByValue(targetTile: Expression, maskTile: Expression, maskValue:

val result = maskEval(targetTile, mask,
{ (v, m) => if (m == maskValue) NODATA else v },
{ (v, m) => if (m == maskValue) NODATA else v }
{ (v, m) => if (m == maskValue) Double.NaN else v }
)
toInternalRow(result, targetCtx)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues

val result = maskEval(targetTile, mask,
{ (v, m) => if (maskValues.contains(m)) NODATA else v },
{ (v, m) => if (maskValues.contains(m)) NODATA else v }
{ (v, m) => if (maskValues.contains(m)) Double.NaN else v }
)

toInternalRow(result, targetCtx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,31 @@ class MaskingFunctionsSpec extends TestEnvironment {
checkDocs("rf_mask_by_value")
}

it("should mask_by_value") {
val values = (0 to 16)
val tile: Tile = DoubleArrayTile(values.map(_.toDouble).toArray, 4, 4)
// array([[ 0, 1, 2, 3],
// [ 4, 5, 6, 7],
// [ 8, 9, 10, 11],
// [12, 13, 14, 15]])
val mask: Tile = IntArrayTile(values.map(x => x % 2 * 4).toArray, 4, 4)
// array([[0, 4, 0, 4],
// [0, 4, 0, 4],
// [0, 4, 0, 4],
// [0, 4, 0, 4]])

import spark.implicits._
val df = List((tile, mask)).toDF("tile", "mask")

val (maskedTile, inverseMaskedTile) = df.select(
rf_mask_by_value(col("tile"), col("mask"), lit(4), inverse=false).alias("m1"),
rf_mask_by_value(col("tile"), col("mask"), lit(4), inverse=true).alias("m2")
).as[(Tile, Tile)].first()

maskedTile.findMinMax shouldBe (0, 14)
inverseMaskedTile.findMinMax shouldBe (1, 15)
}

it("should mask by value for value 0.") {
import spark.implicits._
// maskingTile has -4, ND, and -15 values. Expect mask by value with 0 to not change the
Expand Down

0 comments on commit a3f9bc8

Please sign in to comment.