Skip to content

Commit

Permalink
improve tensor conversion from imglib2
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 26, 2023
1 parent 0dc2d50 commit e646c85
Showing 1 changed file with 56 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
*/
package io.bioimage.modelrunner.tensorflow.v1.tensor;

import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.IndexingUtils;

import java.nio.ByteBuffer;
Expand All @@ -29,6 +30,7 @@

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.img.Img;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
Expand Down Expand Up @@ -103,36 +105,24 @@ else if (Util.getTypeFromInterval(rai) instanceof DoubleType) {
* {@link RandomAccessibleInterval} and the desired dimension order for the
* resulting tensor.
*
* @param imgTensor The {@link RandomAccessibleInterval} to be converted.
* @param tensor The {@link RandomAccessibleInterval} to be converted.
* @return The {@link Tensor} created from the sequence.
*/
private static Tensor<UInt8> buildByte(
RandomAccessibleInterval<ByteType> imgTensor)
RandomAccessibleInterval<ByteType> tensor)
{
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<ByteType> tensorCursor;
if (imgTensor instanceof IntervalView) tensorCursor =
((IntervalView<ByteType>) imgTensor).cursor();
else if (imgTensor instanceof Img) tensorCursor =
((Img<ByteType>) imgTensor).cursor();
else throw new IllegalArgumentException("The data of the " + Tensor.class +
" has " + "to be an instance of " + Img.class + " or " +
IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) {
flatSize *= dd;
}
byte[] flatArr = new byte[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
byte val = tensorCursor.get().getByte();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final byte[] flatArr = new byte[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
ByteBuffer buff = ByteBuffer.wrap(flatArr);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, imgTensor
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, tensor
.dimensionsAsLongArray(), buff);
return ndarray;
}
Expand All @@ -142,36 +132,24 @@ else throw new IllegalArgumentException("The data of the " + Tensor.class +
* {@link RandomAccessibleInterval} and the desired dimension order for the
* resulting tensor.
*
* @param imgTensor The {@link RandomAccessibleInterval} to be converted.
* @param tensor The {@link RandomAccessibleInterval} to be converted.
* @return The {@link Tensor} created from the sequence.
*/
private static Tensor<Integer> buildInt(
RandomAccessibleInterval<IntType> imgTensor)
RandomAccessibleInterval<IntType> tensor)
{
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<IntType> tensorCursor;
if (imgTensor instanceof IntervalView) tensorCursor =
((IntervalView<IntType>) imgTensor).cursor();
else if (imgTensor instanceof Img) tensorCursor = ((Img<IntType>) imgTensor)
.cursor();
else throw new IllegalArgumentException("The data of the " + Tensor.class +
" has " + "to be an instance of " + Img.class + " or " +
IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) {
flatSize *= dd;
}
int[] flatArr = new int[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
int val = tensorCursor.get().getInt();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final int[] flatArr = new int[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
IntBuffer buff = IntBuffer.wrap(flatArr);
Tensor<Integer> ndarray = Tensor.create(imgTensor.dimensionsAsLongArray(),
Tensor<Integer> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
return ndarray;
}
Expand All @@ -181,76 +159,52 @@ else throw new IllegalArgumentException("The data of the " + Tensor.class +
* {@link RandomAccessibleInterval} and the desired dimension order for the
* resulting tensor.
*
* @param imgTensor The {@link RandomAccessibleInterval} to be converted.
* @param tensor The {@link RandomAccessibleInterval} to be converted.
* @return The {@link Tensor} created from the sequence.
*/
private static Tensor<Float> buildFloat(
RandomAccessibleInterval<FloatType> imgTensor)
RandomAccessibleInterval<FloatType> tensor)
{
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<FloatType> tensorCursor;
if (imgTensor instanceof IntervalView) tensorCursor =
((IntervalView<FloatType>) imgTensor).cursor();
else if (imgTensor instanceof Img) tensorCursor =
((Img<FloatType>) imgTensor).cursor();
else throw new IllegalArgumentException("The data of the " + Tensor.class +
" has " + "to be an instance of " + Img.class + " or " +
IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) {
flatSize *= dd;
}
float[] flatArr = new float[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
float val = tensorCursor.get().getRealFloat();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final float[] flatArr = new float[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
FloatBuffer buff = FloatBuffer.wrap(flatArr);
Tensor<Float> tensor = Tensor.create(imgTensor.dimensionsAsLongArray(),
Tensor<Float> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
return tensor;
return ndarray;
}

/**
* Creates a double-typed {@link Tensor} based on the provided
* {@link RandomAccessibleInterval} and the desired dimension order for the
* resulting tensor.
*
* @param imgTensor The {@link RandomAccessibleInterval} to be converted.
* @param tensor The {@link RandomAccessibleInterval} to be converted.
* @return The {@link Tensor} created from the sequence.
*/
private static Tensor<Double> buildDouble(
RandomAccessibleInterval<DoubleType> imgTensor)
RandomAccessibleInterval<DoubleType> tensor)
{
long[] tensorShape = imgTensor.dimensionsAsLongArray();
Cursor<DoubleType> tensorCursor;
if (imgTensor instanceof IntervalView) tensorCursor =
((IntervalView<DoubleType>) imgTensor).cursor();
else if (imgTensor instanceof Img) tensorCursor =
((Img<DoubleType>) imgTensor).cursor();
else throw new IllegalArgumentException("The data of the " + Tensor.class +
" has " + "to be an instance of " + Img.class + " or " +
IntervalView.class);
long flatSize = 1;
for (long dd : imgTensor.dimensionsAsLongArray()) {
flatSize *= dd;
}
double[] flatArr = new double[(int) flatSize];
while (tensorCursor.hasNext()) {
tensorCursor.fwd();
long[] cursorPos = tensorCursor.positionAsLongArray();
int flatPos = IndexingUtils.multidimensionalIntoFlatIndex(cursorPos,
tensorShape);
double val = tensorCursor.get().getRealFloat();
flatArr[flatPos] = val;
}
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
int size = 1;
for (long ll : tensorShape) size *= ll;
final double[] flatArr = new double[size];
int[] sArr = new int[tensorShape.length];
for (int i = 0; i < sArr.length; i ++)
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
DoubleBuffer buff = DoubleBuffer.wrap(flatArr);
Tensor<Double> tensor = Tensor.create(imgTensor.dimensionsAsLongArray(),
Tensor<Double> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
return tensor;
return ndarray;
}
}

0 comments on commit e646c85

Please sign in to comment.