Skip to content

Commit

Permalink
adapt examples
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 30, 2024
1 parent 4a1d917 commit 994ee30
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

Expand Down Expand Up @@ -68,7 +70,8 @@ public class ExampleLoadAndRunModel {
* @throws Exception if there is any error downloading the engines or the models
* or running the model
*/
public static void main(String[] args) throws LoadEngineException, Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void main(String[] args) throws LoadEngineException, Exception {

// Tag for the DL framework (engine) that wants to be used
String framework = "torchscript";
Expand Down Expand Up @@ -109,8 +112,8 @@ public static void main(String[] args) throws LoadEngineException, Exception {
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input0", "bcyx", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -123,8 +126,8 @@ public static void main(String[] args) throws LoadEngineException, Exception {
new FloatType());*/
final Img< FloatType > img2 = imgFactory.create( 1, 2, 512, 512 );
Tensor<FloatType> outTensor = Tensor.build("output0", "bcyx", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

Expand Down Expand Up @@ -91,7 +93,8 @@ public static void main(String[] args) throws LoadEngineException, Exception {
* @throws LoadEngineException if there is any error loading an engine
* @throws Exception if there is any exception running the model
*/
public static void loadAndRunPt2(String modelFolder, String modelSource) throws LoadEngineException, Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void loadAndRunPt2(String modelFolder, String modelSource) throws LoadEngineException, Exception {
// Tag for the DL framework (engine) that wants to be used
String framework = "torchscript";
// Version of the engine
Expand Down Expand Up @@ -123,8 +126,8 @@ public static void loadAndRunPt2(String modelFolder, String modelSource) throws
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input0", "bcyx", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -137,8 +140,8 @@ public static void loadAndRunPt2(String modelFolder, String modelSource) throws
new FloatType());*/
final Img< FloatType > img2 = imgFactory.create( 1, 2, 512, 512 );
Tensor<FloatType> outTensor = Tensor.build("output0", "bcyx", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand All @@ -161,7 +164,8 @@ public static void loadAndRunPt2(String modelFolder, String modelSource) throws
* @throws LoadEngineException if there is any error loading an engine
* @throws Exception if there is any exception running the model
*/
public static void loadAndRunPt1(String modelFolder, String modelSource) throws LoadEngineException, Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void loadAndRunPt1(String modelFolder, String modelSource) throws LoadEngineException, Exception {
// Tag for the DL framework (engine) that wants to be used
String framework = "torchscript";
// Version of the engine
Expand Down Expand Up @@ -193,8 +197,8 @@ public static void loadAndRunPt1(String modelFolder, String modelSource) throws
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input0", "bcyx", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -207,8 +211,8 @@ public static void loadAndRunPt1(String modelFolder, String modelSource) throws
new FloatType());*/
final Img< FloatType > img2 = imgFactory.create( 1, 2, 512, 512 );
Tensor<FloatType> outTensor = Tensor.build("output0", "bcyx", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

Expand Down Expand Up @@ -93,7 +95,8 @@ public static void main(String[] args) throws LoadEngineException, Exception {
* @throws LoadEngineException if there is any error loading an engine
* @throws Exception if there is any exception running the model
*/
public static void loadAndRunTf2() throws LoadEngineException, Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void loadAndRunTf2() throws LoadEngineException, Exception {
// Tag for the DL framework (engine) that wants to be used
String framework = "tensorflow_saved_model_bundle";
// Version of the engine
Expand Down Expand Up @@ -130,8 +133,8 @@ public static void loadAndRunTf2() throws LoadEngineException, Exception {
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input_1", "bxyc", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -140,8 +143,8 @@ public static void loadAndRunTf2() throws LoadEngineException, Exception {
// defining the dimensions and data type
Tensor<FloatType> outTensor0 = Tensor.buildBlankTensor(
"conv2d_19", "bxyc", new long[] {1, 512, 512, 3}, new FloatType());
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor0);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor0);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand All @@ -160,7 +163,8 @@ public static void loadAndRunTf2() throws LoadEngineException, Exception {
* @throws LoadEngineException if there is any error loading an engine
* @throws Exception if there is any exception running the model
*/
public static void loadAndRunTf1() throws LoadEngineException, Exception {
public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
void loadAndRunTf1() throws LoadEngineException, Exception {
// Tag for the DL framework (engine) that wants to be used
String framework = "tensorflow_saved_model_bundle";
// Version of the engine
Expand Down Expand Up @@ -197,8 +201,8 @@ public static void loadAndRunTf1() throws LoadEngineException, Exception {
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input", "byxc", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -207,8 +211,8 @@ public static void loadAndRunTf1() throws LoadEngineException, Exception {
// defining the dimensions and data type
final Img< FloatType > img2 = imgFactory.create( 1, 512, 512, 33 );
Tensor<FloatType> outTensor = Tensor.build("output", "byxc", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand Down

0 comments on commit 994ee30

Please sign in to comment.