fix: TS test_scaled_dot_product_attention (#3117)
zewenli98 authored Aug 23, 2024
1 parent ad1ae8a commit 5a04839
Showing 2 changed files with 27 additions and 9 deletions.
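For context, the aten::scaled_dot_product_attention schema recently gained a trailing enable_gqa (grouped-query attention) argument, which is why the lowering patterns below need an eighth graph input. A minimal LibTorch sketch of the call with the new flag, assuming a build whose ATen signature already carries enable_gqa (shapes are illustrative only):

#include <iostream>
#include <ATen/ATen.h>

int main() {
  // Illustrative shapes: batch=2, heads=8, seq_len=128, head_dim=64.
  auto query = at::rand({2, 8, 128, 64});
  auto key = at::rand({2, 8, 128, 64});
  auto value = at::rand({2, 8, 128, 64});

  // Argument order mirrors the schema used by the lowering pattern:
  // (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa).
  auto out = at::scaled_dot_product_attention(
      query,
      key,
      value,
      /*attn_mask=*/{},
      /*dropout_p=*/0.0,
      /*is_causal=*/false,
      /*scale=*/c10::nullopt,
      /*enable_gqa=*/false);

  std::cout << out.sizes() << std::endl; // [2, 8, 128, 64]
  return 0;
}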
23 changes: 19 additions & 4 deletions core/lowering/passes/unpack_scaled_dot_product_attention.cpp
@@ -12,12 +12,12 @@ namespace passes {
 // https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
   std::string sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
-      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale)
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
+      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa)
       return (%out))IR";
 
   std::string unpacked_sdpa_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
       %none : NoneType = prim::Constant()
       %1 : int = prim::Constant[value=-1]()
       %2 : int = prim::Constant[value=-2]()
@@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
       return(%out))IR";
 
   std::string unpacked_sdpa_attn_biased_pattern = R"IR(
-    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
       %none : NoneType = prim::Constant()
       %0 : int = prim::Constant[value=1]()
       %1 : int = prim::Constant[value=-1]()
@@ -69,6 +69,16 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
         if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
           return false;
         }
+        auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+        if (enable_gqa_node->kind() != at::prim::Constant) {
+          LOG_WARNING(
+              "Could not unpack scaled_dot_product_attention with non constant enable_gqa: " << *enable_gqa_node);
+          return false;
+        }
+        if (enable_gqa_node->i(at::attr::value) == 1) {
+          LOG_WARNING("Could not unpack scaled_dot_product_attention with enable_gqa = True: " << *enable_gqa_node);
+          return false;
+        }
         return true;
       });

@@ -83,6 +93,11 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
           // messages already written in first pass, do not write again
           return false;
         }
+        auto enable_gqa_node = match.anchor->inputs().at(7)->node();
+        if (enable_gqa_node->kind() != at::prim::Constant || enable_gqa_node->i(at::attr::value) == 1) {
+          // messages already written in first pass, do not write again
+          return false;
+        }
         return true;
       });
   LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
13 changes: 8 additions & 5 deletions (second changed file)
@@ -11,8 +11,9 @@ TEST(Converters, ATenScaledDotProductAttentionConvertsCorrectly) {
       %none : NoneType = prim::Constant()
       %0 : float = prim::Constant[value=0.]()
       %scale : NoneType = prim::Constant()
+      %enable_gqa : bool = prim::Constant[value=0]()
       %false : bool = prim::Constant[value=0]()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale)
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %none, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -38,7 +39,8 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
       %0 : float = prim::Constant[value=0.]()
       %false : bool = prim::Constant[value=0]()
       %scale : NoneType = prim::Constant()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+      %enable_gqa : bool = prim::Constant[value=0]()
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -59,13 +61,14 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }
 
-TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
+TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
       %0 : float = prim::Constant[value=0.]()
       %false : bool = prim::Constant[value=0]()
       %scale : NoneType = prim::Constant()
-      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale)
+      %enable_gqa : bool = prim::Constant[value=0]()
+      %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
       return (%3))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -74,7 +77,7 @@ TEST(Converters, ATenScaledDotProductAttnMaskBoolConvertsCorrectly) {
   auto query = at::rand({32, 8, 128, 64}, {at::kCUDA});
   auto key = at::rand({32, 8, 128, 64}, {at::kCUDA});
   auto value = at::rand({32, 8, 128, 64}, {at::kCUDA});
-  auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, at::kCUDA).to(at::kBool);
+  auto attn_mask = at::randint(0, 2, {32, 8, 128, 128}, {at::kCUDA});
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});

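A side note on the renamed test: at::randint produces a 64-bit integer tensor by default, so dropping the .to(at::kBool) cast is what turns the former Bool-mask case into an integer-mask case. A small standalone sketch of the two mask variants (shapes copied from the test; CPU tensors for brevity):

#include <iostream>
#include <ATen/ATen.h>

int main() {
  // at::randint defaults to kLong; the old test cast the mask to kBool.
  auto int_mask = at::randint(0, 2, {32, 8, 128, 128});
  auto bool_mask = at::randint(0, 2, {32, 8, 128, 128}).to(at::kBool);
  std::cout << int_mask.dtype() << " vs " << bool_mask.dtype() << std::endl;
  return 0;
}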