Skip to content

Commit

Permalink
clean and generalize clip transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 26, 2024
1 parent b9f76b0 commit 0a23d1e
Showing 1 changed file with 14 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,6 @@

public class ClipTransformation extends AbstractTensorPixelTransformation
{

private static final class ClipFunction implements FloatUnaryOperator
{

private final float min;

private final float max;

private ClipFunction( final double min, final double max )
{
this.min = (float) min;
this.max = (float) max;
}

@Override
public final float applyAsFloat( final float in )
{
return ( in > max )
? max
: ( in < min )
? min
: in;
}
}

private static String name = "clip";
private Double min;
Expand All @@ -59,6 +35,15 @@ public final float applyAsFloat( final float in )
public ClipTransformation()
{
super(name);
super.setFloatUnitaryOperator(v -> v >= max ? max.floatValue() : (v < min ? min.floatValue() : v));
super.setDoubleUnitaryOperator(v -> v >= max ? max.doubleValue() : (v < min ? min.doubleValue() : v));
super.setByteUnitaryOperator(v -> v >= max ? max.byteValue() : (v < min ? min.byteValue() : v));
super.setUByteUnitaryOperator(v -> v >= max ? max.intValue() : (v < min ? min.intValue() : v));
super.setShortUnitaryOperator(v -> v >= max ? max.shortValue() : (v < min ? min.shortValue() : v));
super.setUShortUnitaryOperator(v -> v >= max ? max.intValue() : (v < min ? min.intValue() : v));
super.setIntUnitaryOperator(v -> v >= max ? max.intValue() : (v < min ? min.intValue() : v));
super.setUIntUnitaryOperator(v -> v >= max ? max.longValue() : (v < min ? min.longValue() : v));
super.setLongUnitaryOperator(v -> v >= max ? max.longValue() : (v < min ? min.longValue() : v));
}

public void setMin(Object min) {
Expand Down Expand Up @@ -91,23 +76,22 @@ public void setMax(Object max) {

public void checkRequiredArgs() {
if (min == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "min"));
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "min"));
} else if (max == null) {
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, "max"));
throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, name, "max"));
}
}

public < R extends RealType< R > & NativeType< R > > Tensor< FloatType > apply( final Tensor< R > input )
{
checkRequiredArgs();
super.setFloatUnitaryOperator(new ClipFunction( min, max ) );
return super.apply(input);
}

public void applyInPlace( final Tensor< FloatType > input )
public < R extends RealType< R > & NativeType< R > >
void applyInPlace( final Tensor< R > input )
{
checkRequiredArgs();
super.setFloatUnitaryOperator(new ClipFunction( min, max ) );
super.apply(input);
super.applyInPlace(input);
}
}

0 comments on commit 0a23d1e

Please sign in to comment.