Skip to content

Commit

Permalink
PR #15252: [XLA:CPU][oneDNN] Fix oneDNN matmul test timeout
Browse files Browse the repository at this point in the history
Imported from GitHub PR #15252

This PR addresses the test timeout observed in the oneDNN matmul test file.

In particular, this PR:

1. Replaces the test titled ConsecutiveBinaryAdd with a smaller test that still exercises the targeted failure case.
2. Shards the test file.
3. Replaces all instances of the old proto definitions with the new ones.
Copybara import of the project:

--
03e22ab by Akhil Goel <[email protected]>:

Fix test timeout

--
c3e08a6 by Akhil Goel <[email protected]>:

Address review comments

Merging this change closes #15252

COPYBARA_INTEGRATE_REVIEW=#15252 from Intel-tensorflow:akhil/fix_mm_timeout c3e08a6
PiperOrigin-RevId: 658441635
  • Loading branch information
akhilgoe authored and copybara-github committed Aug 1, 2024
1 parent 4b3c657 commit 336c9a9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 39 deletions.
1 change: 1 addition & 0 deletions xla/service/cpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ xla_cc_test(
name = "onednn_matmul_test",
srcs = ["onednn_matmul_test.cc"],
copts = tsl_copts(),
shard_count = 4,
tags = [
"no_oss",
"notap",
Expand Down
54 changes: 15 additions & 39 deletions xla/service/cpu/tests/onednn_matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,9 @@ TEST_F(MatmulTest, TestNonScalarConstantEltwiseLinearF32) {
; CHECK: backend_config={
; CHECK-DAG: "outer_dimension_partitions":[],
; CHECK-DAG: "onednn_matmul_config":{
; CHECK-NOT: "fused_ops":["LINEAR"]
; CHECK-NOT: "fusions":{
; CHECK-NOT: "ops":["LINEAR"]
; CHECK-NOT: }
; CHECK-DAG: }
; CHECK: }
)");
Expand Down Expand Up @@ -1502,44 +1504,18 @@ TEST_F(MatmulTest, WeightsPrepackAndScratch) {
TEST_F(MatmulTest, ConsecutiveBinaryAdd) {
const char* matmul_module_str = R"(
HloModule matmul.test.f32
region_0.22 {
Arg_0.23 = f32[] parameter(0)
Arg_1.24 = f32[] parameter(1)
ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24)
}
region_1.29 {
Arg_0.30 = f32[] parameter(0)
Arg_1.31 = f32[] parameter(1)
ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31)
}
ENTRY main {
constant.2 = f32[] constant(1e-06)
broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
constant.7 = f32[] constant(1)
broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={}
Arg_0.1 = f32[3] parameter(0)
reshape.10 = f32[1,3] reshape(Arg_0.1)
broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1}
reshape.12 = f32[3] reshape(broadcast.11)
broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1}
subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
constant.4 = f32[] constant(0)
broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}
dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1}
add.19 = f32[1000000,3] add(dot.16, dot.18)
constant.9 = f32[3] constant({1, 2, 3})
dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={}
add.21 = f32[1000000,3] add(add.19, dot.20)
constant.6 = f32[] constant(0)
reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22
reshape.27 = f32[1,3] reshape(reduce.26)
negate.28 = f32[1,3] negate(reshape.27)
ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29
ENTRY matmul.test.f32 {
arg0.1 = f32[128,32,4,4] parameter(0)
arg0.2 = f32[128,32,4,4] parameter(1)
dot.7 = f32[128,32,4,4] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
const.0 = f32[128,32] constant({...})
bcast.1 = f32[128,32,4,4] broadcast(const.0), dimensions={0,1}
add.0 = f32[128,32,4,4] add(dot.7,bcast.1)
const.1 = f32[4] constant({1,2,3,4})
bcast.2 = f32[128,32,4,4] broadcast(const.1), dimensions={3}
add.1 = f32[128,32,4,4] add(add.0, bcast.2)
tuple.12 = (f32[128,32,4,4]) tuple(add.1)
ROOT get-tuple-element.13 = f32[128,32,4,4] get-tuple-element(tuple.12), index=0
})";

EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4}));
Expand Down

0 comments on commit 336c9a9

Please sign in to comment.