[Filter] util to get nth info
Use the util function to get the nth tensor-info pointer.
Also, remove an unnecessary header include in the filter subplugin.

Signed-off-by: Jaeyun Jung <[email protected]>
jaeyun-jung authored and jijoongmoon committed Oct 16, 2024
1 parent 2a074ff commit 07ccc2c
Showing 1 changed file with 13 additions and 17 deletions.
nnstreamer/tensor_filter/tensor_filter_nntrainer.cc (13 additions, 17 deletions)
@@ -23,8 +23,6 @@
 
 #include <nntrainer_error.h>
 
-#include "ml-api-common.h"
-#include "nnstreamer.h"
 #include "nnstreamer_plugin_api.h"
 #include "tensor_filter_nntrainer.hh"
 
@@ -77,9 +75,6 @@ to_nnst_tensor_dim(const ml::train::TensorDim &dim) {
   for (unsigned int i = 0; i < ml::train::TensorDim::MAXDIM; ++i) {
     info->dimension[i] = dim.getTensorDim(ml::train::TensorDim::MAXDIM - i - 1);
   }
-  for (unsigned int i = ml::train::TensorDim::MAXDIM; i < NNS_TENSOR_RANK_LIMIT;
-       ++i)
-    info->dimension[i] = 1;
 
   return info;
 }
@@ -90,8 +85,7 @@ static ml::train::TensorDim to_nntr_tensor_dim(const GstTensorInfo *info) {
 }
 
 NNTrainerInference::NNTrainerInference(const std::string &model_config_) :
-  batch_size(1),
-  model_config(model_config_) {
+  batch_size(1), model_config(model_config_) {
   loadModel();
   model->compile();
   model->initialize();
@@ -219,6 +213,7 @@ static int nntrainer_setInputDim(const GstTensorFilterProperties *prop,
                                  GstTensorsInfo *out_info) {
   NNTrainerInference *nntrainer =
     static_cast<NNTrainerInference *>(*private_data);
+  GstTensorInfo *_info;
   g_return_val_if_fail(prop && nntrainer && in_info && out_info, -EINVAL);
 
   auto model_inputs = nntrainer->getInputDimension();
@@ -232,13 +227,14 @@ static int nntrainer_setInputDim(const GstTensorFilterProperties *prop,
     for (unsigned int i = 0u; i < model_inputs.size(); ++i) {
       auto default_dim = model_inputs[i];
       default_dim.batch(1);
-      mutable_in_info->info[i].type = _NNS_FLOAT32;
-      gst_tensor_info_copy(mutable_in_info->info + i,
-                           to_nnst_tensor_dim(default_dim).get());
+      _info = gst_tensors_info_get_nth_info(mutable_in_info, i);
+      _info->type = _NNS_FLOAT32;
+      gst_tensor_info_copy(_info, to_nnst_tensor_dim(default_dim).get());
     }
   }
 
-  auto batch_size = in_info->info[0].dimension[3];
+  _info = gst_tensors_info_get_nth_info((GstTensorsInfo *)in_info, 0);
+  auto batch_size = _info->dimension[3];
 
   /// this does not allocate the memory for the inference, so setting batch here
   /// does not have a large effect on the first inference call as of now.
@@ -253,9 +249,9 @@ static int nntrainer_setInputDim(const GstTensorFilterProperties *prop,
   /// check each in dimension matches
   for (unsigned int i = 0; i < in_info->num_tensors; ++i) {
     model_inputs[i].batch(batch_size);
-    g_return_val_if_fail(in_info->info[i].type == _NNS_FLOAT32, -EINVAL);
-    g_return_val_if_fail(
-      model_inputs[i] == to_nntr_tensor_dim(in_info->info + i), -EINVAL);
+    _info = gst_tensors_info_get_nth_info((GstTensorsInfo *)in_info, i);
+    g_return_val_if_fail(_info->type == _NNS_FLOAT32, -EINVAL);
+    g_return_val_if_fail(model_inputs[i] == to_nntr_tensor_dim(_info), -EINVAL);
   }
 
   auto model_outputs = nntrainer->getOutputDimension();
@@ -264,9 +260,9 @@ static int nntrainer_setInputDim(const GstTensorFilterProperties *prop,
   out_info->num_tensors = model_outputs.size();
   for (unsigned int i = 0; i < out_info->num_tensors; ++i) {
     model_outputs[i].batch(batch_size);
-    out_info->info[i].type = _NNS_FLOAT32;
-    gst_tensor_info_copy(out_info->info + i,
-                         to_nnst_tensor_dim(model_outputs[i]).get());
+    _info = gst_tensors_info_get_nth_info(out_info, i);
+    _info->type = _NNS_FLOAT32;
+    gst_tensor_info_copy(_info, to_nnst_tensor_dim(model_outputs[i]).get());
   }
 
   return 0;
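A minimal sketch of the access pattern this commit adopts, assuming the nnstreamer plugin API as used in the diff: gst_tensors_info_get_nth_info() returns a pointer to the nth GstTensorInfo of a GstTensorsInfo, so callers stop indexing the fixed info[] array directly. The helper fill_float_inputs and its parameters below are hypothetical, for illustration only.

/* Sketch only: fill_float_inputs is a hypothetical helper, not code from
 * this repository; it mirrors the pattern the commit switches to. */
#include <nnstreamer_plugin_api.h>

static void
fill_float_inputs (GstTensorsInfo *tensors_info, guint num_tensors)
{
  guint i;

  tensors_info->num_tensors = num_tensors;
  for (i = 0; i < num_tensors; i++) {
    /* Instead of indexing tensors_info->info[i] directly, resolve the nth
     * entry through the util function and mutate it via the pointer. */
    GstTensorInfo *_info = gst_tensors_info_get_nth_info (tensors_info, i);

    _info->type = _NNS_FLOAT32;
  }
}

As the diff suggests, the util takes a non-const GstTensorsInfo * (hence the (GstTensorsInfo *) casts on in_info), and the returned pointer can be passed straight to gst_tensor_info_copy().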
