Skip to content

Commit 75f99c5

Browse files
committed
wasi_nn_tensorflowlite.cpp: fix get_output return size
It should be the byte size, not the number of (fp32) values. I'm ambivalent about how to deal with compatibility for the legacy WAMR-specific "wasi_nn". For now, I avoided changing it, so that the existing tests using the legacy ABI — namely test_tensorflow.c and test_tensorflow_quantized.c — pass as they are. If we have any users who still want to use the legacy ABI, I suppose they consider compatibility more important than consistency with the other backends. cf. #4376
1 parent a29f394 commit 75f99c5

File tree

1 file changed

+57
-16
lines changed

1 file changed

+57
-16
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
389389
return too_large;
390390
}
391391

392-
uint32_t model_tensor_size = 1;
393-
for (int i = 0; i < (int)tensor->dims->size; ++i)
394-
model_tensor_size *= (uint32_t)tensor->dims->data[i];
395-
396-
if (*output_tensor_size < model_tensor_size) {
397-
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
398-
return too_large;
399-
}
400-
401392
if (tensor->quantization.type == kTfLiteNoQuantization) {
402393
NN_DBG_PRINTF("No quantization information");
403-
float *ot =
404-
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
405-
index);
406-
407-
int size = model_tensor_size * sizeof(float);
408-
bh_memcpy_s(output_tensor, size, ot, size);
394+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
395+
if (*output_tensor_size < tensor->bytes) {
396+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
397+
return too_large;
398+
}
399+
#else
400+
/*
401+
* for now, maintain the bug-to-bug compatibility with the old abi,
402+
* where the size here is the number of fp32, not bytes.
403+
*/
404+
if (*output_tensor_size < tensor->bytes / sizeof(float)) {
405+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
406+
return too_large;
407+
}
408+
#endif
409+
bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
410+
tensor->bytes);
411+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
412+
*output_tensor_size = tensor->bytes;
413+
#else
414+
/*
415+
* for now, maintain the bug-to-bug compatibility with the old abi,
416+
* where the size here is the number of fp32, not bytes.
417+
*/
418+
*output_tensor_size = tensor->bytes / sizeof(float);
419+
#endif
409420
}
410421
else { // TODO: Assuming uint8 quantized networks.
411422
TfLiteAffineQuantization *quant_info =
@@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
414425
NN_ERR_PRINTF("Quantization per channel is not supported");
415426
return runtime_error;
416427
}
428+
429+
uint32_t model_tensor_size = 1;
430+
for (int i = 0; i < (int)tensor->dims->size; ++i)
431+
model_tensor_size *= (uint32_t)tensor->dims->data[i];
432+
433+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
434+
if (*output_tensor_size / sizeof(float) < model_tensor_size) {
435+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
436+
return too_large;
437+
}
438+
#else
439+
/*
440+
* for now, maintain the bug-to-bug compatibility with the old abi,
441+
* where the size here is the number of fp32, not bytes.
442+
*/
443+
if (*output_tensor_size < model_tensor_size) {
444+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
445+
return too_large;
446+
}
447+
#endif
448+
417449
uint8_t *ot = tfl_ctx->interpreters[ctx]
418450
.interpreter->typed_output_tensor<uint8_t>(index);
419451

@@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
426458
for (uint32_t i = 0; i < model_tensor_size; ++i) {
427459
output_tensor_f[i] = (ot[i] - zero_point) * scale;
428460
}
461+
462+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
463+
*output_tensor_size = model_tensor_size * sizeof(float);
464+
#else
465+
/*
466+
* for now, maintain the bug-to-bug compatibility with the old abi,
467+
* where the size here is the number of fp32, not bytes.
468+
*/
469+
*output_tensor_size = model_tensor_size;
470+
#endif
429471
}
430472

431-
*output_tensor_size = model_tensor_size;
432473
return success;
433474
}
434475

0 commit comments

Comments
 (0)