Skip to content

Commit

Permalink
fix: Repair Citrinet-1024 compilation issues [Duplicate of PR #1488 for Release 1.3] (#1489)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Nov 30, 2022
1 parent 8d7cd50 commit 8dc1a06
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion core/conversion/converters/impl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ auto element_wise_registrations TORCHTRT_UNUSED =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
} else if (rounding_mode == "trunc") {
// trunc = floor(abs(div)) * sign(div)
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
auto tmp_div = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_tmp_div");
auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");

// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
Expand Down
8 changes: 8 additions & 0 deletions core/conversion/converters/impl/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ auto reduce_registrations TORCHTRT_UNUSED =
LOG_DEBUG("Keep dims: " << keepdim);

LOG_WARNING("Sum converter disregards dtype");

if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
LOG_DEBUG(
"Found type " << in_tensor->getType() << " in aten::sum, casting to "
<< nvinfer1::DataType::kINT32 << " for compatibility.");
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
}

auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);

TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
Expand Down
14 changes: 14 additions & 0 deletions tests/core/conversion/converters/test_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/torch.h"

namespace {
std::string gen_basic_graph(const std::string& op) {
Expand Down Expand Up @@ -162,6 +163,19 @@ TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
test_body(graph, in);
}

TEST(Converters, ATenSumDimNegOneIndexKeepDimsBoolTensorConvertsCorrectly) {
  // IR: aten::sum over the last dimension (index -1) with keepdim=true and
  // dtype=None. Feeding a kBool tensor exercises the converter's
  // bool -> int32 cast path (TensorRT's reduce layer rejects kBOOL inputs).
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=-1]()
        %2 : int[] = prim::ListConstruct(%1)
        %3 : bool = prim::Constant[value=1]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::sum(%0, %2, %3, %4)
        return (%5))IR";
  // Random 0/1 integers cast to bool give a mix of true/false values.
  auto bool_input = at::randint(0, 2, {4, 4, 4}, at::kCUDA).to(torch::kBool);
  test_body(graph, bool_input);
}

TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
Expand Down

0 comments on commit 8dc1a06

Please sign in to comment.