Skip to content

Commit 6fba107

Browse files
authored
refactor: Use safer backend APIs (#149)
1 parent e3244bc commit 6fba107

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/libtorch.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -2105,7 +2105,11 @@ ModelInstanceState::SetInputTensors(
21052105
input, nullptr, nullptr, &input_shape,
21062106
&input_dims_count, nullptr, nullptr));
21072107

2108-
batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
2108+
int64_t element_cnt = 0;
2109+
RESPOND_AND_SET_NULL_IF_ERROR(
2110+
&((*responses)[idx]),
2111+
GetElementCount(input_shape, input_dims_count, &element_cnt));
2112+
batchn_shape[0] += element_cnt;
21092113
}
21102114
} else {
21112115
batchn_shape =
@@ -2160,7 +2164,10 @@ ModelInstanceState::SetInputTensors(
21602164
input, HostPolicyName().c_str(), nullptr, nullptr, &shape,
21612165
&dims_count, nullptr, &buffer_count));
21622166

2163-
const int64_t batch_element_cnt = GetElementCount(shape, dims_count);
2167+
int64_t batch_element_cnt = 0;
2168+
RESPOND_AND_SET_NULL_IF_ERROR(
2169+
&((*responses)[idx]),
2170+
GetElementCount(shape, dims_count, &batch_element_cnt));
21642171

21652172
*cuda_copy |= SetStringInputTensor(
21662173
&input_list, input, input_name, buffer_count, batch_element_cnt,
@@ -2347,7 +2354,8 @@ ModelInstanceState::ReadOutputTensors(
23472354
batchn_shape[0] = shape[0];
23482355
}
23492356

2350-
const size_t tensor_element_cnt = GetElementCount(batchn_shape);
2357+
int64_t tensor_element_cnt = 0;
2358+
RETURN_IF_ERROR(GetElementCount(batchn_shape, &tensor_element_cnt));
23512359

23522360
// Only need an response tensor for requested outputs.
23532361
if (response != nullptr) {

0 commit comments

Comments
 (0)