Skip to content

Commit

Permalink
correct important error wrt to converting images to tensros
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 27, 2023
1 parent bc048fd commit e1b8bdc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,11 @@ public static void fillOutputTensors(
if (outputNDArrays.size() != outputTensors.size())
throw new RunModelException(outputNDArrays.size(), outputTensors.size());
for (int i = 0; i < outputNDArrays.size(); i++) {
outputTensors.get(i).setData(ImgLib2Builder.build(outputNDArrays.get(i)));
try {
outputTensors.get(i).setData(ImgLib2Builder.build(outputNDArrays.get(i)));
} catch (IllegalArgumentException ex) {
throw new RunModelException(ex.toString());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
package io.bioimage.modelrunner.tensorflow.v1.tensor;

import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.IndexingUtils;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;

import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.blocks.PrimitiveBlocks;
import net.imglib2.img.Img;
Expand All @@ -39,7 +37,6 @@
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;

import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
Expand Down Expand Up @@ -111,6 +108,7 @@ else if (Util.getTypeFromInterval(rai) instanceof DoubleType) {
private static Tensor<UInt8> buildByte(
RandomAccessibleInterval<ByteType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -122,8 +120,7 @@ private static Tensor<UInt8> buildByte(
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
ByteBuffer buff = ByteBuffer.wrap(flatArr);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, tensor
.dimensionsAsLongArray(), buff);
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, ogShape, buff);
return ndarray;
}

Expand All @@ -138,6 +135,7 @@ private static Tensor<UInt8> buildByte(
private static Tensor<Integer> buildInt(
RandomAccessibleInterval<IntType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -149,8 +147,7 @@ private static Tensor<Integer> buildInt(
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
IntBuffer buff = IntBuffer.wrap(flatArr);
Tensor<Integer> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
Tensor<Integer> ndarray = Tensor.create(ogShape, buff);
return ndarray;
}

Expand All @@ -165,6 +162,7 @@ private static Tensor<Integer> buildInt(
private static Tensor<Float> buildFloat(
RandomAccessibleInterval<FloatType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -176,8 +174,7 @@ private static Tensor<Float> buildFloat(
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
FloatBuffer buff = FloatBuffer.wrap(flatArr);
Tensor<Float> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
Tensor<Float> ndarray = Tensor.create(ogShape, buff);
return ndarray;
}

Expand All @@ -192,6 +189,7 @@ private static Tensor<Float> buildFloat(
private static Tensor<Double> buildDouble(
RandomAccessibleInterval<DoubleType> tensor)
{
long[] ogShape = tensor.dimensionsAsLongArray();
tensor = Utils.transpose(tensor);
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
long[] tensorShape = tensor.dimensionsAsLongArray();
Expand All @@ -203,8 +201,7 @@ private static Tensor<Double> buildDouble(
sArr[i] = (int) tensorShape[i];
blocks.copy( new long[tensorShape.length], flatArr, sArr );
DoubleBuffer buff = DoubleBuffer.wrap(flatArr);
Tensor<Double> ndarray = Tensor.create(tensor.dimensionsAsLongArray(),
buff);
Tensor<Double> ndarray = Tensor.create(ogShape, buff);
return ndarray;
}
}

0 comments on commit e1b8bdc

Please sign in to comment.