Skip to content

Commit

Permalink
corect small errors in the example
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 5, 2023
1 parent 1a457f3 commit 55c788f
Showing 1 changed file with 28 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
*
* It also requires the installation of a TF2 and a TF1 engine.
*
* The example code downloads all the needed artifacts, thus excuting the whole
* The example code downloads all the needed artifacts, thus executing the whole
* example might take some time.
*
* @author Carlos Garcia Lopez de Haro
Expand All @@ -67,6 +67,20 @@ public class ExampleLoadTensorflow1Tensorflow2 {
private static final String ENGINES_DIR = new File(CWD, "engines").getAbsolutePath();
private static final String MODELS_DIR = new File(CWD, "models").getAbsolutePath();

/**
* Run the test
* @param <T>
* type of the tensors
* @param args
* arguments of the main method
* @throws LoadEngineException if there is any exception loading the engine
* @throws Exception if there is any exception in the tests
*/
public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadEngineException, Exception {
loadAndRunTf1();
loadAndRunTf2();
}

/**
* Loads a TF2 model and runs it
* @throws LoadEngineException if there is any error loading an engine
Expand Down Expand Up @@ -105,10 +119,10 @@ public static void loadAndRunTf2() throws LoadEngineException, Exception {
Model model = loadModel(modelFolder, null, engineInfo);
// Create an image that will be the backend of the Input Tensor
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 512, 512, 1 );
final Img< FloatType > img1 = imgFactory.create( 1, 1, 3, 64, 64 );
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input", "bcyx", img1);
Tensor<FloatType> inpTensor = Tensor.build("input0", "bczyx", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);

Expand All @@ -117,10 +131,13 @@ public static void loadAndRunTf2() throws LoadEngineException, Exception {
/// Regard that output tensors can be built empty without allocating memory
// or allocating memory by creating the tensor with a sample empty image, or by
// defining the dimensions and data type
final Img< FloatType > img2 = imgFactory.create( 1, 512, 512, 33 );
Tensor<FloatType> outTensor = Tensor.build("output", "bcyx", img2);
Tensor<FloatType> outTensor0 = Tensor.buildBlankTensor(
"output0", "bczyx", new long[] {1, 1, 3, 64, 64}, new FloatType());
final Img< FloatType > img2 = imgFactory.create( 1, 2, 3, 64, 64 );
Tensor<FloatType> outTensor1 = Tensor.build("output1", "bczyx", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);
outputs.add(outTensor0);
outputs.add(outTensor1);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand Down Expand Up @@ -171,11 +188,11 @@ public static void loadAndRunTf1() throws LoadEngineException, Exception {
// Load the corresponding model, for Tensorflow the arg model_source is not needed
Model model = loadModel(modelFolder, null, engineInfo);
// Create an image that will be the backend of the Input Tensor
final ImgFactory< FloatType > imgFactory = new CellImgFactory<>( new FloatType(), 5 );
final Img< FloatType > img1 = imgFactory.create( 1, 512, 512, 1 );
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 512, 512, 3 );
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input0", "bcyx", img1);
Tensor<FloatType> inpTensor = Tensor.build("input0", "byxc", img1);
List<Tensor<?>> inputs = new ArrayList<Tensor<?>>();
inputs.add(inpTensor);

Expand All @@ -184,8 +201,8 @@ public static void loadAndRunTf1() throws LoadEngineException, Exception {
/// Regard that output tensors can be built empty without allocating memory
// or allocating memory by creating the tensor with a sample empty image, or by
// defining the dimensions and data type
final Img< FloatType > img2 = imgFactory.create( 1, 512, 512, 1 );
Tensor<FloatType> outTensor = Tensor.build("output0", "bcyx", img2);
final Img< FloatType > img2 = imgFactory.create( 1, 512, 512, 33 );
Tensor<FloatType> outTensor = Tensor.build("output0", "byxc", img2);
List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
outputs.add(outTensor);

Expand All @@ -201,20 +218,6 @@ public static void loadAndRunTf1() throws LoadEngineException, Exception {
System.out.print("Success running Tensorflow 1!!");
}

/**
* Run the test
* @param <T>
* type of the tensors
* @param args
* arguments of the main method
* @throws LoadEngineException if there is any exception loading the engine
* @throws Exception if there is any exception in the tests
*/
public static < T extends RealType< T > & NativeType< T > > void main(String[] args) throws LoadEngineException, Exception {
loadAndRunTf1();
loadAndRunTf2();
}

/**
* Method that creates the {@link EngineInfo} object.
* @param engine
Expand Down

0 comments on commit 55c788f

Please sign in to comment.