Please open an issue if you see any of the codes in the table above, except for
`FAIL`. If you see `FAIL` but did not introduce any change to the code, please open an
issue as well. Otherwise, it likely means that your change introduced some incorrect behavior.

### Profiling a Benchmark

To identify opportunities for performance improvement, the benchmarking scripts
provide a few command-line arguments. For more information, check out [the optional
metrics](#optional-metrics) and [the available dumpable
information](#dumping-benchmark-specific-data).

A good starting point is the `--dump-pytorch-profiles` flag (see
[this section](#dumping-benchmark-specific-data) for more details), which makes use of
[PyTorch's built-in profiler][11].

Suppose, for example, that we profile [the `speech_transformer` benchmark][12] by passing
that flag to the _experiment_runner.py_ script. We would end up with the following
directory hierarchy:

```
- output
|- results.jsonl
|- speech_transformer
|- 0
|- pytorch-profile-...-speech_transformer.txt
|- trace-...-speech_transformer.json
|- 1
|- pytorch-profile-...-speech_transformer.txt
|- trace-...-speech_transformer.json
|- 2
|- pytorch-profile-...-speech_transformer.txt
|- trace-...-speech_transformer.json
...
|- pytorch-profile-...-speech_transformer.txt
```

- **Numbered Directories:** one directory per repetition of the benchmark

- **Repetition-Specific _pytorch-profile-...txt_:** a table with the time spent in each
operation inside PyTorch

```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
loop_convert_fusion 0.00% 0.000us 0.00% 0.000us 0.000us 1.037ms 27.24% 1.037ms 26.581us 39
ampere_bf16_s1688gemm_bf16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 603.560us 15.86% 603.560us 33.531us 18
ampere_bf16_s1688gemm_bf16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 407.618us 10.71% 407.618us 16.305us 25
ampere_bf16_s16816gemm_bf16_128x128_ldg8_f2f_stages_... 0.00% 0.000us 0.00% 0.000us 0.000us 324.994us 8.54% 324.994us 54.166us 6
Torch-Compiled Region 29.63% 74.028ms 498.57% 1.246s 83.052ms 0.000us 0.00% 206.090us 13.739us 15
void cutlass::Kernel<cutlass_80_wmma_tensorop_bf16_s... 0.00% 0.000us 0.00% 0.000us 0.000us 143.426us 3.77% 143.426us 11.952us 12
void cutlass::Kernel<cutlass_80_wmma_tensorop_bf16_s... 0.00% 0.000us 0.00% 0.000us 0.000us 135.970us 3.57% 135.970us 5.665us 24
...
aten::_to_copy 0.56% 1.400ms 0.62% 1.537ms 384.260us 0.000us 0.00% 0.000us 0.000us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 249.871ms
Self CUDA time total: 3.806ms
```

- **Repetition-Specific _trace-...json_:** a JSON file containing the timeline of the
  profiled events, which can be loaded in a Chrome browser at `chrome://tracing` (see
  [this page][13] for more details)

![Chrome tracing](_static/img/torchbench-chrome-tracing.png)

- **Outer _pytorch-profile-...txt_:** the concatenation of the tables from all
  repetitions
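
As a reference, here is a minimal sketch of how artifacts like these can be produced with
[PyTorch's built-in profiler][11] directly, outside the benchmarking scripts (the model
and input below are placeholders, and a CUDA-capable GPU is assumed):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder workload; the benchmarking scripts profile the actual benchmark instead.
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)

# Per-operation summary, analogous to the pytorch-profile-...txt tables above.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Timeline loadable at chrome://tracing, analogous to trace-...json.
prof.export_chrome_trace("trace.json")
```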

### Tweaking a Benchmark

In general, the benchmarking scripts were written so as to reproduce the PyTorch HUD
Inductor results while comparing them with PyTorch/XLA performance. Therefore, much of
the execution setup was mirrored from the [PyTorch HUD benchmarking scripts][2]. That
said, we do allow some room for customization by exposing a few command-line arguments:

- `--batch-size`: defaults to the benchmark-specific value from the PyTorch Torchbench
  configuration file

- [`--matmul-precision`][14]: set to `high` by default (see the sketch below for the
  PyTorch knob this corresponds to)
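
A minimal sketch of what `--matmul-precision` tunes under the hood: PyTorch's global
float32 matmul precision knob (that the flag maps directly onto this API is our
assumption; see [14] for the motivating discussion):

```python
import torch

# Trade float32 matmul accuracy for speed; accepted values are
# "highest", "high", and "medium".
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # -> "high"
```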

### Intermediate Representations (IRs)

PyTorch/XLA mainly manipulates two IRs:

1. **PyTorch lazy IR:** the directed acyclic graph (DAG) created by the tracing phase; and
2. **HLO Graph:** the DAG that results from lowering the lazy IR, [used by XLA][15] in the backend
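
Both can also be inspected interactively with PyTorch/XLA's debug helpers. A minimal
sketch (note that `torch_xla._XLAC` is an internal module, so these helpers may change
between releases):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Build a small pending (lazy) computation on the XLA device.
t = torch.randn(2, 2, device=xm.xla_device()) + 1

# Lazy IR of the pending computation (the `text` format below).
print(torch_xla._XLAC._get_xla_tensors_text([t]))

# HLO lowering of the same computation (the `hlo` format below).
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))
```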

In order to inspect them (i.e. the compiled graphs), our benchmarking scripts expose
[the `--save-ir-format` command-line argument](#dumping-benchmark-specific-data). When
this flag is specified, every compiled graph gets dumped to a file named
_dump-...<format>_, where _"<format>"_ can be any of:

- `text` (for lazy IR)

```
IR {
%0 = bf16[] xla::device_data(), xla_shape=bf16[]
%1 = bf16[1,27,84,84]{3,2,1,0} xla::device_data(), xla_shape=bf16[1,27,84,84]{3,2,1,0}
%2 = bf16[1,27,84,84]{3,2,1,0} aten::div(%1, %0), xla_shape=bf16[1,27,84,84]{3,2,1,0}, ROOT=0
%3 = bf16[32]{0} xla::device_data(), xla_shape=bf16[32]{0}
%4 = bf16[32,27,3,3]{3,2,1,0} xla::device_data(), xla_shape=bf16[32,27,3,3]{3,2,1,0}
%5 = bf16[1,32,41,41]{3,2,1,0} aten::convolution_overrideable(%2, %4, %3), xla_shape=bf16[1,32,41,41]{3,2,1,0}
%6 = bf16[1,32,41,41]{3,2,1,0} aten::relu(%5), xla_shape=bf16[1,32,41,41]{3,2,1,0}, ROOT=1
%7 = bf16[32]{0} xla::device_data(), xla_shape=bf16[32]{0}
%8 = bf16[32,32,3,3]{3,2,1,0} xla::device_data(), xla_shape=bf16[32,32,3,3]{3,2,1,0}
%9 = bf16[1,32,39,39]{3,2,1,0} aten::convolution_overrideable(%6, %8, %7), xla_shape=bf16[1,32,39,39]{3,2,1,0}
...
%66 = bf16[1,1]{1,0} aten::exp(%65), xla_shape=bf16[1,1]{1,0}, ROOT=7
}
```

- `hlo` (for HLO IR)

```
HloModule IrToHlo.129, entry_computation_layout={(bf16[], bf16[1,27,84,84]{3,2,1,0}, bf16[32]{0}, bf16[32,27,3,3]{3,2,1,0}, bf16[32]{0}, /*index=5*/bf16[32,32,3,3]{3,2,1,0}, bf16[32]{0}, bf16[32,32,3,3]{3,2,1,0}, bf16[32]{0}, bf16[32,32,3,3]{3,2,1,0}, /*index=10*/bf16[50]{0}, bf16[50]{0}, bf16[50,39200]{1,0}, bf16[50]{0}, bf16[2]{0}, /*index=15*/bf16[2,1024]{1,0}, bf16[1024]{0}, bf16[1024,1024]{1,0}, bf16[1024]{0}, bf16[1024,50]{1,0}, /*index=20*/bf16[], f32[])->(bf16[1,27,84,84]{3,2,1,0}, bf16[1,32,41,41]{3,2,1,0}, bf16[1,32,39,39]{3,2,1,0}, bf16[1,32,37,37]{3,2,1,0}, bf16[1,32,35,35]{3,2,1,0}, /*index=5*/bf16[1,50]{1,0}, bf16[1,1]{1,0}, bf16[1,1]{1,0})}
ENTRY %IrToHlo.129 (p0.1: bf16[], p1.2: bf16[1,27,84,84], p2.5: bf16[32], p3.6: bf16[32,27,3,3], p4.14: bf16[32], p5.15: bf16[32,32,3,3], p6.23: bf16[32], p7.24: bf16[32,32,3,3], p8.32: bf16[32], p9.33: bf16[32,32,3,3], p10.42: bf16[50], p11.55: bf16[50], p12.56: bf16[50,39200], p13.73: bf16[50], p14.82: bf16[2], p15.83: bf16[2,1024], p16.85: bf16[1024], p17.86: bf16[1024,1024], p18.88: bf16[1024], p19.89: bf16[1024,50], p20.112: bf16[], p21.113: f32[]) -> (bf16[1,27,84,84], bf16[1,32,41,41], bf16[1,32,39,39], bf16[1,32,37,37], bf16[1,32,35,35], /*index=5*/bf16[1,50], bf16[1,1], bf16[1,1]) {
%constant.43 = bf16[] constant(0)
%reshape.44 = bf16[1]{0} reshape(bf16[] %constant.43)
%broadcast.45 = bf16[1]{0} broadcast(bf16[1]{0} %reshape.44), dimensions={0}
%constant.46 = bf16[] constant(0)
%reshape.47 = bf16[1]{0} reshape(bf16[] %constant.46)
%broadcast.48 = bf16[1]{0} broadcast(bf16[1]{0} %reshape.47), dimensions={0}
%p1.2 = bf16[1,27,84,84]{3,2,1,0} parameter(1)
%p0.1 = bf16[] parameter(0)
%broadcast.3 = bf16[1,27,84,84]{3,2,1,0} broadcast(bf16[] %p0.1), dimensions={}
%divide.4 = bf16[1,27,84,84]{3,2,1,0} divide(bf16[1,27,84,84]{3,2,1,0} %p1.2, bf16[1,27,84,84]{3,2,1,0} %broadcast.3)
%p3.6 = bf16[32,27,3,3]{3,2,1,0} parameter(3)
%convolution.7 = bf16[1,32,41,41]{3,2,1,0} convolution(bf16[1,27,84,84]{3,2,1,0} %divide.4, bf16[32,27,3,3]{3,2,1,0} %p3.6), window={size=3x3 stride=2x2}, dim_labels=bf01_oi01->bf01, operand_precision={high,high}
%p2.5 = bf16[32]{0} parameter(2)
%broadcast.8 = bf16[1,41,41,32]{3,2,1,0} broadcast(bf16[32]{0} %p2.5), dimensions={3}
%transpose.9 = bf16[1,32,41,41]{1,3,2,0} transpose(bf16[1,41,41,32]{3,2,1,0} %broadcast.8), dimensions={0,3,1,2}
%add.10 = bf16[1,32,41,41]{3,2,1,0} add(bf16[1,32,41,41]{3,2,1,0} %convolution.7, bf16[1,32,41,41]{1,3,2,0} %transpose.9)
...
ROOT %tuple.128 = (bf16[1,27,84,84]{3,2,1,0}, bf16[1,32,41,41]{3,2,1,0}, bf16[1,32,39,39]{3,2,1,0}, bf16[1,32,37,37]{3,2,1,0}, bf16[1,32,35,35]{3,2,1,0}, /*index=5*/bf16[1,50]{1,0}, bf16[1,1]{1,0}, bf16[1,1]{1,0}) tuple(bf16[1,27,84,84]{3,2,1,0} %divide.4, bf16[1,32,41,41]{3,2,1,0} %maximum.13, bf16[1,32,39,39]{3,2,1,0} %maximum.22, bf16[1,32,37,37]{3,2,1,0} %maximum.31, bf16[1,32,35,35]{3,2,1,0} %maximum.40, /*index=5*/bf16[1,50]{1,0} %tanh.81, bf16[1,1]{1,0} %slice.109, bf16[1,1]{1,0} %exponential.127)
}
```

- `stablehlo` (for StableHLO IR)

```
module @IrToHlo.129 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<bf16>, %arg1: tensor<1x27x84x84xbf16>, %arg2: tensor<32xbf16>, %arg3: tensor<32x27x3x3xbf16>, %arg4: tensor<32xbf16>, %arg5: tensor<32x32x3x3xbf16>, %arg6: tensor<32xbf16>, %arg7: tensor<32x32x3x3xbf16>, %arg8: tensor<32xbf16>, %arg9: tensor<32x32x3x3xbf16>, %arg10: tensor<50xbf16>, %arg11: tensor<50xbf16>, %arg12: tensor<50x39200xbf16>, %arg13: tensor<50xbf16>, %arg14: tensor<2xbf16>, %arg15: tensor<2x1024xbf16>, %arg16: tensor<1024xbf16>, %arg17: tensor<1024x1024xbf16>, %arg18: tensor<1024xbf16>, %arg19: tensor<1024x50xbf16>, %arg20: tensor<bf16>, %arg21: tensor<f32>) -> (tensor<1x27x84x84xbf16>, tensor<1x32x41x41xbf16>, tensor<1x32x39x39xbf16>, tensor<1x32x37x37xbf16>, tensor<1x32x35x35xbf16>, tensor<1x50xbf16>, tensor<1x1xbf16>, tensor<1x1xbf16>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<1x1xbf16>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<1x1024xbf16>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<1xbf16>
%cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<1xbf16>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<1x32x35x35xbf16>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<1x32x37x37xbf16>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<1x32x39x39xbf16>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<1x32x41x41xbf16>
%0 = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<bf16>) -> tensor<1x27x84x84xbf16>
%1 = stablehlo.divide %arg1, %0 : tensor<1x27x84x84xbf16>
%2 = stablehlo.convolution(%1, %arg3) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGH>]} : (tensor<1x27x84x84xbf16>, tensor<32x27x3x3xbf16>) -> tensor<1x32x41x41xbf16>
%3 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<32xbf16>) -> tensor<1x32x41x41xbf16>
%4 = stablehlo.add %2, %3 : tensor<1x32x41x41xbf16>
%5 = stablehlo.maximum %4, %cst_6 : tensor<1x32x41x41xbf16>
%6 = stablehlo.convolution(%5, %arg5) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGH>]} : (tensor<1x32x41x41xbf16>, tensor<32x32x3x3xbf16>) -> tensor<1x32x39x39xbf16>
%7 = stablehlo.broadcast_in_dim %arg4, dims = [1] : (tensor<32xbf16>) -> tensor<1x32x39x39xbf16>
%8 = stablehlo.add %6, %7 : tensor<1x32x39x39xbf16>
%9 = stablehlo.maximum %8, %cst_5 : tensor<1x32x39x39xbf16>
%10 = stablehlo.convolution(%9, %arg7) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGH>]} : (tensor<1x32x39x39xbf16>, tensor<32x32x3x3xbf16>) -> tensor<1x32x37x37xbf16>
...
return %1, %5, %9, %13, %17, %29, %44, %54 : tensor<1x27x84x84xbf16>, tensor<1x32x41x41xbf16>, tensor<1x32x39x39xbf16>, tensor<1x32x37x37xbf16>, tensor<1x32x35x35xbf16>, tensor<1x50xbf16>, tensor<1x1xbf16>, tensor<1x1xbf16>
}
}
```

## Experiment Results

The benchmarking scripts store the resulting artifacts in a directory called `output` by
default, or in the directory specified by the `--output-dirname` parameter.
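
For instance, the top-level _results.jsonl_ shown earlier follows, per its extension, the
[JSON Lines][8] format. A minimal sketch for loading it (the exact fields in each record
depend on which experiments were run):

```python
import json

# Each non-empty line of the file is one self-contained JSON record.
with open("output/results.jsonl") as f:
    results = [json.loads(line) for line in f if line.strip()]

for record in results:
    print(record)
```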
[8]: https://jsonlines.org/
[9]: https://pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html
[10]: https://github.com/pytorch/pytorch/blob/a4e9a1c90b74572b48f2eedf1e931c18713c1781/torch/_dynamo/utils.py#L1616
[11]: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
[12]: https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/speech_transformer
[13]: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality
[14]: https://github.com/pytorch/pytorch/issues/76440
[15]: https://openxla.org/xla/architecture
