
Commit

improve robustness when handling big tensors
carlosuc3m committed Nov 27, 2023
1 parent a7ea0c8 commit c7964a7
Showing 1 changed file with 52 additions and 12 deletions.
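
Note on the new guard: each build method in this commit now rejects shapes whose total element count does not fit in a signed 32-bit int before the data is flattened, throwing an IllegalArgumentException that names Integer.MAX_VALUE as the per-tensor limit. The check is delegated to CommonUtils.int32Overflows, whose implementation is not part of this diff; the sketch below is only an assumption of what such a check looks like, with an illustrative helper name rather than the actual CommonUtils code.

import java.util.Arrays;

public class Int32OverflowCheckSketch {

	// Illustrative helper (assumed behavior of CommonUtils.int32Overflows):
	// returns true when the product of the dimensions cannot be used as the
	// length of a flat Java primitive array.
	static boolean elementCountOverflowsInt32(long[] shape) {
		long count = 1;
		for (long dim : shape) {
			try {
				count = Math.multiplyExact(count, dim);
			}
			catch (ArithmeticException e) {
				return true; // product does not even fit in a long
			}
			if (count > Integer.MAX_VALUE)
				return true; // a byte[]/int[]/float[]/double[] of this size cannot be allocated
		}
		return false;
	}

	public static void main(String[] args) {
		System.out.println(Arrays.toString(new long[] { 1024, 1024, 512 })
			+ " overflows: " + elementCountOverflowsInt32(new long[] { 1024, 1024, 512 }));
		System.out.println(Arrays.toString(new long[] { 100000, 100000 })
			+ " overflows: " + elementCountOverflowsInt32(new long[] { 100000, 100000 }));
	}
}

Shapes that pass such a check have an element count that fits in an int, so the flat arrays allocated in the methods below are safe to create.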
@@ -21,26 +21,34 @@
package io.bioimage.modelrunner.tensorflow.v1.tensor;

import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.CommonUtils;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.Arrays;

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.img.Img;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;

/**
* A TensorFlow {@link Tensor} builder for {@link Img} and
* {@link io.bioimage.modelrunner.tensor.Tensor} objects.
@@ -79,8 +87,8 @@ Tensor<?> build(io.bioimage.modelrunner.tensor.Tensor<T> tensor) {
*/
public static Tensor<?> build(RandomAccessibleInterval<?> rai)
{
if (Util.getTypeFromInterval(rai) instanceof ByteType) {
return buildByte((RandomAccessibleInterval<ByteType>) rai);
if (Util.getTypeFromInterval(rai) instanceof UnsignedByteType) {
return buildUByte((RandomAccessibleInterval<UnsignedByteType>) rai);
}
else if (Util.getTypeFromInterval(rai) instanceof IntType) {
return buildInt((RandomAccessibleInterval<IntType>) rai);
@@ -105,20 +113,28 @@ else if (Util.getTypeFromInterval(rai) instanceof DoubleType) {
* @param tensor The {@link RandomAccessibleInterval} to be converted.
* @return The {@link Tensor} created from the sequence.
*/
private static Tensor<UInt8> buildByte(
RandomAccessibleInterval<ByteType> tensor)
private static Tensor<UInt8> buildUByte(
RandomAccessibleInterval<UnsignedByteType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final byte[] flatArr = new byte[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<UnsignedByteType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().getByte();
}
ByteBuffer buff = ByteBuffer.wrap(flatArr);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, ogShape, buff);
return ndarray;
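
A side note on the ByteType to UnsignedByteType switch above: UnsignedByteType keeps its value in a signed Java byte internally, and getByte() returns that raw byte, so the bit pattern copied into the ByteBuffer is exactly what TensorFlow reads back as UInt8. A small, self-contained illustration (not part of the commit):

import net.imglib2.type.numeric.integer.UnsignedByteType;

public class UnsignedByteRoundTripSketch {
	public static void main(String[] args) {
		UnsignedByteType px = new UnsignedByteType(200);
		byte raw = px.getByte();      // -56 as a signed Java byte
		int recovered = raw & 0xff;   // 200 again when interpreted as unsigned
		System.out.println(raw + " -> " + recovered);
	}
}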
@@ -136,16 +152,24 @@ private static Tensor<Integer> buildInt(
RandomAccessibleInterval<IntType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final int[] flatArr = new int[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
IntBuffer buff = IntBuffer.wrap(flatArr);
Tensor<Integer> ndarray = Tensor.create(ogShape, buff);
return ndarray;
@@ -163,16 +187,24 @@ private static Tensor<Float> buildFloat(
RandomAccessibleInterval<FloatType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final float[] flatArr = new float[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
FloatBuffer buff = FloatBuffer.wrap(flatArr);
Tensor<Float> ndarray = Tensor.create(ogShape, buff);
return ndarray;
@@ -190,16 +222,24 @@ private static Tensor<Double> buildDouble(
RandomAccessibleInterval<DoubleType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
if (CommonUtils.int32Overflows(ogShape))
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final double[] flatArr = new double[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );

Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
int i = 0;
while (cursor.hasNext()) {
cursor.fwd();
flatArr[i ++] = cursor.get().get();
}
DoubleBuffer buff = DoubleBuffer.wrap(flatArr);
Tensor<Double> ndarray = Tensor.create(ogShape, buff);
return ndarray;
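For context, a hypothetical call site (the builder class name below is assumed; this extract shows the package and the methods, not the class declaration). It builds a small ImgLib2 image and hands it to the static build method shown above; shapes whose element count overflows an int now fail fast with the IllegalArgumentException before any flat array is allocated.

import java.util.Arrays;

import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.real.FloatType;

import org.tensorflow.Tensor;

// assumed class name inside the package shown in the diff:
// io.bioimage.modelrunner.tensorflow.v1.tensor.TensorBuilder

public class TensorBuilderUsageSketch {
	public static void main(String[] args) {
		// 64x64 float image filled with a constant value.
		Img<FloatType> img = ArrayImgs.floats(64, 64);
		img.forEach(px -> px.set(0.5f));

		// Dispatches to buildFloat(...) and wraps the flattened data
		// into a TensorFlow 1.x Tensor; closed via try-with-resources.
		try (Tensor<?> tf = TensorBuilder.build(img)) {
			System.out.println("TF tensor shape: " + Arrays.toString(tf.shape()));
		}
	}
}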
