Skip to content

Commit

Permalink
improve eficciecy of imglib2 creation from tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent 2f6f234 commit 0dc2d50
Showing 1 changed file with 48 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@
*/
package io.bioimage.modelrunner.tensorflow.v1.tensor;

import io.bioimage.modelrunner.utils.IndexingUtils;
import io.bioimage.modelrunner.tensor.Utils;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
Expand All @@ -40,7 +40,7 @@
import org.tensorflow.types.UInt8;

/**
* A {@link Img} builder for TensorFlow {@link Tensor} objects.
* A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects.
* Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor})
* from Tensorflow 1 {@link Tensor}
*
Expand All @@ -54,153 +54,110 @@ public final class ImgLib2Builder {
private ImgLib2Builder() {}

/**
* Creates a {@link Img} from a given {@link Tensor}
* Creates a {@link RandomAccessibleInterval} from a given {@link Tensor}
*
* @param <T> the type of the image
* @param tensor The tensor data is read from.
* @return The Img built from the tensor.
* @return The RandomAccessibleInterval built from the tensor.
* @throws IllegalArgumentException If the tensor type is not supported.
*/
@SuppressWarnings("unchecked")
public static < T extends Type< T > > Img<T> build(Tensor<?> tensor)
public static < T extends Type< T > > RandomAccessibleInterval<T> build(Tensor<?> tensor)
throws IllegalArgumentException
{
// Create an Img of the same type of the tensor
switch (tensor.dataType()) {
case UINT8:
return (Img<T>) buildFromTensorUByte((Tensor<UInt8>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorUByte((Tensor<UInt8>) tensor);
case INT32:
return (Img<T>) buildFromTensorInt((Tensor<Integer>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorInt((Tensor<Integer>) tensor);
case FLOAT:
return (Img<T>) buildFromTensorFloat((Tensor<Float>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorFloat((Tensor<Float>) tensor);
case DOUBLE:
return (Img<T>) buildFromTensorDouble((Tensor<Double>) tensor);
return (RandomAccessibleInterval<T>) buildFromTensorDouble((Tensor<Double>) tensor);
default:
throw new IllegalArgumentException("Unsupported tensor type: " + tensor
.dataType());
}
}

/**
* Builds a {@link Img} from a unsigned byte-typed {@link Tensor}.
* Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link Tensor}.
*
* @param tensor The tensor data is read from.
* @return The Img built from the tensor, of type {@link UnsignedByteType}.
* @return The RandomAccessibleInterval built from the tensor, of type {@link UnsignedByteType}.
*/
private static Img<UnsignedByteType> buildFromTensorUByte(Tensor<UInt8> tensor) {
long[] tensorShape = tensor.shape();
final ArrayImgFactory<UnsignedByteType> factory = new ArrayImgFactory<>(new UnsignedByteType());
final Img<UnsignedByteType> outputImg = factory.create(tensorShape);
Cursor<UnsignedByteType> tensorCursor = outputImg.cursor();
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(Tensor<UInt8> tensor) {
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {
totalSize *= i;
}
for (long i : tensorShape) totalSize *= i;
byte[] flatArr = new byte[totalSize];
ByteBuffer outBuff = ByteBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
byte val = flatArr[flatPos];
if (val < 0)
tensorCursor.get().set(256 + (int) val);
else
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<UnsignedByteType> rai = ArrayImgs.unsignedBytes(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned integer-typed {@link Tensor}.
* Builds a {@link RandomAccessibleInterval} from a unsigned integer-typed {@link Tensor}.
*
* @param tensor The tensor data is read from.
* @return The sequence built from the tensor, of type {@link IntType}.
* @return The RandomAccessibleInterval built from the tensor, of type {@link IntType}.
*/
private static Img<IntType> buildFromTensorInt(Tensor<Integer> tensor) {
long[] tensorShape = tensor.shape();
final ArrayImgFactory<IntType> factory = new ArrayImgFactory<>(new IntType());
final Img<IntType> outputImg = factory.create(tensorShape);
Cursor<IntType> tensorCursor = outputImg.cursor();
private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<Integer> tensor) {
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {
totalSize *= i;
}
for (long i : tensorShape) totalSize *= i;
int[] flatArr = new int[totalSize];
IntBuffer outBuff = IntBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
int val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<IntType> rai = ArrayImgs.ints(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned float-typed {@link Tensor}.
* Builds a {@link RandomAccessibleInterval} from a unsigned float-typed {@link Tensor}.
*
* @param tensor The tensor data is read from.
* @return The Img built from the tensor, of type {@link FloatType}.
* @return The RandomAccessibleInterval built from the tensor, of type {@link FloatType}.
*/
private static Img<FloatType> buildFromTensorFloat(Tensor<Float> tensor) {
long[] tensorShape = tensor.shape();
final ArrayImgFactory<FloatType> factory = new ArrayImgFactory<>(new FloatType());
final Img<FloatType> outputImg = factory.create(tensorShape);
Cursor<FloatType> tensorCursor = outputImg.cursor();
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<Float> tensor) {
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {
totalSize *= i;
}
for (long i : tensorShape) totalSize *= i;
float[] flatArr = new float[totalSize];
FloatBuffer outBuff = FloatBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
float val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
outBuff = null;;
RandomAccessibleInterval<FloatType> rai = ArrayImgs.floats(flatArr, tensorShape);
return Utils.transpose(rai);
}

/**
* Builds a {@link Img} from a unsigned double-typed {@link Tensor}.
* Builds a {@link RandomAccessibleInterval} from a unsigned double-typed {@link Tensor}.
*
* @param tensor The tensor data is read from.
* @return The Img built from the tensor, of type {@link DoubleType}.
* @return The RandomAccessibleInterval built from the tensor, of type {@link DoubleType}.
*/
private static Img<DoubleType> buildFromTensorDouble(Tensor<Double> tensor) {
long[] tensorShape = tensor.shape();
final ArrayImgFactory<DoubleType> factory = new ArrayImgFactory<>(new DoubleType());
final Img<DoubleType> outputImg = factory.create(tensorShape);
Cursor<DoubleType> tensorCursor = outputImg.cursor();
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor<Double> tensor) {
long[] arrayShape = tensor.shape();
long[] tensorShape = new long[arrayShape.length];
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
int totalSize = 1;
for (long i : tensorShape) {
totalSize *= i;
}
for (long i : tensorShape) totalSize *= i;
double[] flatArr = new double[totalSize];
DoubleBuffer outBuff = DoubleBuffer.wrap(flatArr);
tensor.writeTo(outBuff);
outBuff = null;
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
double val = flatArr[flatPos];
tensorCursor.get().set(val);
}
return outputImg;
RandomAccessibleInterval<DoubleType> rai = ArrayImgs.doubles(flatArr, tensorShape);
return Utils.transpose(rai);
}
}

0 comments on commit 0dc2d50

Please sign in to comment.