ADD: add two dim scaling for interpolate function
T-K-233 committed Jun 10, 2024
1 parent a7bc645 commit 62078c7
Showing 3 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion nn/inc/nn_interpolate.h
@@ -16,7 +16,7 @@
  * @param out: the output tensor
  * @param input: the input tensor
  */
-void NN_interpolate_F32(Tensor *out, Tensor *in, float scale_factor/*const char* mode*/);
+void NN_interpolate_F32(Tensor *out, Tensor *in, float *scale_factor/*const char* mode*/);


 #endif // __NN_INTERPOLATE_H
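
Note on the new signature: scale_factor is now a pointer to two floats rather than a single scalar, with scale_factor[0] scaling the height axis and scale_factor[1] scaling the width axis. A minimal call sketch under that reading (the surrounding tensor setup is assumed, not part of this commit):

// Sketch only: `input` and `output` are assumed to be pre-allocated NCHW
// float32 Tensors, with output->shape[2] == input->shape[2] * 2 and
// output->shape[3] == input->shape[3] * 3.
float scale_factor[2] = {2.0f, 3.0f};  // {height scale, width scale}
NN_interpolate_F32(&output, &input, scale_factor);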
45 changes: 23 additions & 22 deletions nn/src/interpolate/nn_interpolate.c
@@ -1,12 +1,13 @@

#include "nn_interpolate.h"

void NN_interpolate_F32(Tensor *out, Tensor *in, float scale_factor/*const char* mode*/) {
void NN_interpolate_F32(Tensor *out, Tensor *in, float *scale_factor/*const char* mode*/) {
assert(in->ndim == 4);
assert(out->ndim == 4);
assert(in->dtype == DTYPE_F32);
assert(out->dtype == DTYPE_F32);
assert(scale_factor > 0);
assert(scale_factor[0] > 0);
assert(scale_factor[1] > 0);

size_t batch_size = in->shape[0];
size_t channels = in->shape[1];
@@ -16,31 +17,31 @@ void NN_interpolate_F32(Tensor *out, Tensor *in, float scale_factor/*const char* mode*/) {
   size_t output_width = out->shape[3];

   // Ensure output dimensions match the expected dimensions after scaling
-  assert(output_height == (size_t)(input_height * scale_factor));
-  assert(output_width == (size_t)(input_width * scale_factor));
+  assert(output_height == (size_t)(input_height * scale_factor[0]));
+  assert(output_width == (size_t)(input_width * scale_factor[1]));

   // Initialize output tensor to zeros
   memset(out->data, 0, batch_size * channels * output_height * output_width * sizeof(float));

-  for (size_t n = 0; n < batch_size; n++) {
-    for (size_t c = 0; c < channels; c++) {
-      for (size_t oh = 0; oh < output_height; oh++) {
-        for (size_t ow = 0; ow < output_width; ow++) {
-          size_t ih = (size_t)(oh / scale_factor);
-          size_t iw = (size_t)(ow / scale_factor);
-
-          size_t in_idx = n * channels * input_height * input_width
-                          + c * input_height * input_width
-                          + ih * input_width
-                          + iw;
-          size_t out_idx = n * channels * output_height * output_width
-                           + c * output_height * output_width
-                           + oh * output_width
-                           + ow;
-
-          ((float *)out->data)[out_idx] = ((float *)in->data)[in_idx];
-        }
+  for (size_t n = 0; n < batch_size; n += 1) {
+    for (size_t c = 0; c < channels; c += 1) {
+      for (size_t oh = 0; oh < output_height; oh += 1) {
+        for (size_t ow = 0; ow < output_width; ow += 1) {
+          size_t ih = (size_t)(oh / scale_factor[0]);
+          size_t iw = (size_t)(ow / scale_factor[1]);
+
+          size_t in_idx = n * channels * input_height * input_width
+                          + c * input_height * input_width
+                          + ih * input_width
+                          + iw;
+          size_t out_idx = n * channels * output_height * output_width
+                           + c * output_height * output_width
+                           + oh * output_width
+                           + ow;
+
+          ((float *)out->data)[out_idx] = ((float *)in->data)[in_idx];
+        }
       }
     }
   }
 }
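
For reference, a self-contained sketch of the same nearest-neighbor index mapping (ih = oh / scale_factor[0], iw = ow / scale_factor[1]) on plain float arrays, with the batch and channel loops dropped; the 2x2 input and {2, 2} scale are illustrative values, not taken from the commit:

#include <stdio.h>
#include <stddef.h>

int main(void) {
  // 2x2 input upscaled to 4x4 with per-axis scales {2, 2}
  float in[2 * 2] = {1.0f, 2.0f,
                     3.0f, 4.0f};
  float out[4 * 4];
  float scale_factor[2] = {2.0f, 2.0f};  // {height scale, width scale}
  size_t input_width = 2;
  size_t output_height = 4;
  size_t output_width = 4;

  for (size_t oh = 0; oh < output_height; oh += 1) {
    for (size_t ow = 0; ow < output_width; ow += 1) {
      // same source-index computation as the loop body above
      size_t ih = (size_t)(oh / scale_factor[0]);
      size_t iw = (size_t)(ow / scale_factor[1]);
      out[oh * output_width + ow] = in[ih * input_width + iw];
    }
  }

  // each input value should appear repeated in a 2x2 block
  for (size_t oh = 0; oh < output_height; oh += 1) {
    for (size_t ow = 0; ow < output_width; ow += 1) {
      printf("%.0f ", out[oh * output_width + ow]);
    }
    printf("\n");
  }
  return 0;
}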
2 changes: 1 addition & 1 deletion torchconverter/src/converter.py
@@ -131,7 +131,7 @@ def call_function(self, target, args, kwargs):
         elif target == torch.nn.functional.interpolate:
             layer_name = "interpolate_{count}".format(count=count) if count > 0 else "interpolate"
             self.model_forward += INDENT + "// F.{layer_name}\n".format(layer_name=layer_name)
-            self.model_forward += INDENT + "NN_interpolate_F32(&model->{layer_name}, &model->{input_names[0]}, {scale_factor});\n".format(
+            self.model_forward += INDENT + "NN_interpolate_F32(&model->{layer_name}, &model->{input_names[0]}, (float []){{{scale_factor}, {scale_factor}}});\n".format(
                 layer_name=layer_name,
                 input_names=self.node_info[layer_name][0],
                 scale_factor=kwargs.get("scale_factor")
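The converter change wraps the scalar scale_factor kwarg in a compound-literal array so both axes receive the same value. Assuming a traced F.interpolate(x, scale_factor=2) node whose input tensor field is named x (the names here are illustrative), the emitted C line would look roughly like:

// F.interpolate
NN_interpolate_F32(&model->interpolate, &model->x, (float []){2, 2});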
