[XLA:CPU] Add benchmarks for 1D strided convolutions #19261

Merged 1 commit on Nov 14, 2024
xla/service/cpu/benchmarks/convolution_benchmark_test.cc: 101 additions, 0 deletions
@@ -81,6 +81,99 @@ static void BM_Conv2D(benchmark::State& state) {
padding_w, "_", padding_w)}}));
}

// Regular strided convolution. Shapes come from an actual use case.
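// With pad=96 on each side, window size=256, and stride=64, the 25600-element
// input yields (25600 + 2*96 - 256) / 64 + 1 = 400 output elements per channel.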
static void BM_Conv1DStrided(benchmark::State& state) {
std::string hlo_module = R"(
HloModule jit_jconvf

ENTRY main.6 {
Arg_0.1 = f32[16,1,25600]{2,1,0} parameter(0)
Arg_1.2 = f32[129,1,256]{2,1,0} parameter(1)
ROOT convolution.3 = f32[16,129,400]{2,1,0} convolution(Arg_0.1, Arg_1.2),
window={size=256 stride=64 pad=96_96}, dim_labels=bf0_oi0->bf0
}
)";

std::minstd_rand0 engine;

// NCW layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 1, 25600});
// OIW layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {129, 1, 256});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
}

// Transposed version (i.e. gradient) of BM_Conv1DStrided. In terms of shapes,
// this operation can be thought of as the reverse of a regular strided
// convolution, which is why the input and output shapes are swapped (so the
// performance of this function can be compared directly with BM_Conv1DStrided).
// Currently, the performance is orders of magnitude worse than the regular
// convolution, even though the two should be comparable.
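// With lhs_dilate=64 the 400-element input expands to (400 - 1) * 64 + 1 = 25537
// elements; pad=159 on each side brings that to 25855, and a size-256 window
// yields 25855 - 256 + 1 = 25600 outputs, i.e. the forward conv's input length.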
static void BM_Conv1DTransposedStrided(benchmark::State& state) {
std::string hlo_module = R"(
HloModule jit_jconvt

ENTRY main.6 {
Arg_0.1 = f32[16,129,400]{2,1,0} parameter(0)
Arg_1.2 = f32[129,1,256]{2,1,0} parameter(1)
ROOT convolution.3 = f32[16,1,25600]{2,1,0} convolution(Arg_0.1, Arg_1.2),
window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=bf0_io0->bf0
}
)";

std::minstd_rand0 engine;

// NCW layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 129, 400});
// IOW layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {129, 1, 256});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
}

// The same shapes as BM_Conv1DTransposedStrided, but with a different layout.
static void BM_Conv1DTransposedStridedNonDefaultLayout(
benchmark::State& state) {
std::string hlo_module = R"(
HloModule jit_jconvt

ENTRY main.6 {
Arg_0.1 = f32[16,400,129]{2,1,0} parameter(0)
Arg_1.2 = f32[256,1,129]{2,1,0} parameter(1)
ROOT convolution.3 = f32[16,25600,1]{2,1,0} convolution(Arg_0.1, Arg_1.2),
window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=b0f_0oi->b0f
}
)";

std::minstd_rand0 engine;

// NWC layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 400, 129});
// WOI layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {256, 1, 129});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
}

static void BM_GroupedConv2D(benchmark::State& state) {
int batch = state.range(0);
int height = state.range(1);
@@ -188,6 +281,14 @@ BENCHMARK(BM_Conv2D&lt;F32&gt;)
->Args({32, 64, 64, 4, 3, 3, 16})
->Args({32, 32, 32, 96, 3, 3, 96});

// -------------------------------------------------------------------------- //
// 1D strided convolutions
// -------------------------------------------------------------------------- //

BENCHMARK(BM_Conv1DStrided)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv1DTransposedStrided)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout)->MeasureProcessCPUTime();

// -------------------------------------------------------------------------- //
// Grouped convolution
// -------------------------------------------------------------------------- //
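Note: the jit_jconvf / jit_jconvt module names suggest these HLO snippets were traced from JAX. As a minimal sketch of where such shapes can come from (the function names, the dimension_numbers, and the use of jax.lax.conv_general_dilated are illustrative assumptions, not taken from the PR), the forward and transposed benchmarks correspond to a strided 1D convolution and the gradient JAX derives for its input:

import jax
import jax.numpy as jnp

def jconvf(x, w):
  # Forward strided 1D convolution: [16,1,25600] * [129,1,256] -> [16,129,400].
  # The dimension_numbers are chosen to match dim_labels=bf0_oi0->bf0.
  return jax.lax.conv_general_dilated(
      x, w,
      window_strides=(64,),
      padding=[(96, 96)],
      dimension_numbers=("NCH", "OIH", "NCH"))

x = jnp.ones((16, 1, 25600), jnp.float32)
w = jnp.ones((129, 1, 256), jnp.float32)
y = jconvf(x, w)  # shape (16, 129, 400), as in BM_Conv1DStrided

# The gradient with respect to the input lowers to the lhs-dilated convolution
# benchmarked by BM_Conv1DTransposedStrided (lhs_dilate=64, pad=159_159).
_, vjp_fn = jax.vjp(lambda x_: jconvf(x_, w), x)
(dx,) = vjp_fn(y)  # shape (16, 1, 25600)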