Commit bef3b63

example for using TorchInductor caching with torch.compile (#2925)
* example for using torchinductor caching
* update README
* review comments
* updated readme
* verified with 4 workers
* added additional links for debugging
1 parent 33d87e3 commit bef3b63

File tree

6 files changed: +271 -3 lines changed

examples/pt2/README.md

+6 -3

@@ -51,14 +51,17 @@ torchserve takes care of 4 and 5 for you while the remaining steps are your resp
 ### Note

-`torch.compile()` is a JIT compiler and JIT compilers generally have a startup cost. If that's an issue for you make sure to populate these two environment variables to improve your warm starts.
+`torch.compile()` is a JIT compiler and JIT compilers generally have a startup cost. To reduce the warm-up time, `TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.
+
+To persist this cache and/or to make use of the additional experimental caching feature, set the following:

 ```
 import os

-os.environ["TORCHINDUCTOR_CACHE_DIR"] = "1"
-os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "/path/to/directory" # replace with your desired path
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/directory" # replace with your desired path
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
 ```
+An example of how to use these with TorchServe is shown [here](./torch_inductor_caching/)

 ## torch.export.export
examples/pt2/torch_inductor_caching/README.md

+184 (new file)

# TorchInductor Caching with TorchServe inference of densenet161 model

`torch.compile()` is a JIT compiler, and JIT compilers generally have a startup cost. To handle this, `TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.
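For instance, after a model has been compiled once you can peek at that default cache directory yourself; a minimal sketch (the exact contents of the directory are a TorchInductor implementation detail):

```python
import getpass
import os

# Default TorchInductor cache location noted above: /tmp/torchinductor_<user id>
cache_dir = os.environ.get(
    "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}"
)
print(cache_dir)
if os.path.isdir(cache_dir):
    print(os.listdir(cache_dir))  # cached compilation artifacts, if any
else:
    print("cache directory not created yet")
```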
## TorchInductor FX Graph Cache

There is an experimental feature to cache the FX graph as well. It is not enabled by default and can be turned on with the following config:

```
import os
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
```

This needs to be set before you `import torch`

or

```
import torch

torch._inductor.config.fx_graph_cache = True
```

To see the effect of caching on `torch.compile` execution times, we need a multi-worker setup. In this example, we use 4 workers. Workers 2, 3 and 4 will see the benefit of caching when they execute `torch.compile`.

We show below how this can be used with TorchServe.
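Before the TorchServe walkthrough, if you want to confirm cache hits yourself, a minimal standalone sketch is to read the same Dynamo counters that the example handler logs; the tiny model below is only illustrative:

```python
import os

os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # must be set before importing torch

import torch
from torch._dynamo.utils import counters

# Any compiled module exercises the FX graph cache; a small Linear layer is enough here.
model = torch.nn.Linear(16, 4)
compiled = torch.compile(model, backend="inductor")
compiled(torch.randn(2, 16))  # first call triggers compilation (a miss on a cold cache)

print("fxgraph_cache_hit:", counters["inductor"]["fxgraph_cache_hit"])
print("fxgraph_cache_miss:", counters["inductor"]["fxgraph_cache_miss"])
```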
### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory, e.g. `cd examples/pt2/torch_inductor_caching`

### torch.compile config

`torch.compile` supports a variety of configs, and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)

In this example, we use the following config:

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
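For reference, assuming the image-classifier base handler passes these `pt2` options straight through to `torch.compile`, the config above roughly amounts to:

```python
import torch

eager_model = torch.nn.Linear(16, 4)  # stand-in for the densenet161 model loaded by the handler
compiled_model = torch.compile(eager_model, backend="inductor", mode="max-autotune")
```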
### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-fx-cache.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```
## TorchInductor Cache Directory

`TorchInductor` already makes use of caching in `/tmp/torchinductor_USERID` on your machine.

Since the default directory is in `/tmp`, the cache is deleted on restart.

`torch.compile` provides a config to change the cache directory for `TorchInductor`:

```
import os

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/path/to/directory" # replace with your desired path
```

We show below how this can be used with TorchServe.
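In this example the directory is not hard-coded in the handler; it comes from the model config YAML. A minimal sketch of that wiring, mirroring what `caching_handler.py` does with the parsed config (the dict literal below stands in for the YAML that TorchServe hands to the handler):

```python
import os

# Stand-in for the parsed model-config-cache-dir.yaml passed to the handler by TorchServe
model_yaml_config = {
    "handler": {
        "torch_inductor_caching": {
            "torch_inductor_cache_dir": "/home/ubuntu/serve/examples/pt2/torch_inductor_caching/cache"
        }
    }
}

caching_cfg = model_yaml_config["handler"]["torch_inductor_caching"]
if "torch_inductor_cache_dir" in caching_cfg:
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = caching_cfg["torch_inductor_cache_dir"]
```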
### Pre-requisites

- `PyTorch >= 2.2`

Change directory to the examples directory, e.g. `cd examples/pt2/torch_inductor_caching`

### torch.compile config

`torch.compile` supports a variety of configs, and the performance you get can vary based on the config. You can find the various options [here](https://pytorch.org/docs/stable/generated/torch.compile.html)

In this example, we use the following config:

```yaml
pt2 : {backend: inductor, mode: max-autotune}
```
### Create model archive

```
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ../../image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ../../image_classifier/index_to_name.json --handler ./caching_handler.py --config-file model-config-cache-dir.yaml -f
```

#### Start TorchServe

```
torchserve --start --ncs --model-store model_store --models densenet161.mar
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg && curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
```

produces the output

```
{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}{
  "tabby": 0.4664836823940277,
  "tiger_cat": 0.4645617604255676,
  "Egyptian_cat": 0.06619937717914581,
  "lynx": 0.0012969186063855886,
  "plastic_bag": 0.00022856894065625966
}
```
## Additional links for improving `torch.compile` performance and debugging

- [Compile Threads](https://pytorch.org/blog/training-production-ai-models/#34-controlling-just-in-time-compilation-time)
- [Profiling torch.compile](https://pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html)
examples/pt2/torch_inductor_caching/caching_handler.py

+65 (new file)

import logging
import os

import torch
from torch._dynamo.utils import counters

from ts.torch_handler.image_classifier import ImageClassifier

logger = logging.getLogger(__name__)


class TorchInductorCacheHandler(ImageClassifier):
    """
    Image classifier handler that configures TorchInductor caching
    (FX graph cache and/or cache directory) from the model config.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the model is loaded and initialized here.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.context = ctx
        self.manifest = ctx.manifest
        properties = ctx.system_properties

        # Apply the caching options from model-config-*.yaml before the model is compiled
        if (
            "handler" in ctx.model_yaml_config
            and "torch_inductor_caching" in ctx.model_yaml_config["handler"]
        ):
            if ctx.model_yaml_config["handler"]["torch_inductor_caching"].get(
                "torch_inductor_fx_graph_cache", False
            ):
                torch._inductor.config.fx_graph_cache = True
            if (
                "torch_inductor_cache_dir"
                in ctx.model_yaml_config["handler"]["torch_inductor_caching"]
            ):
                os.environ["TORCHINDUCTOR_CACHE_DIR"] = ctx.model_yaml_config[
                    "handler"
                ]["torch_inductor_caching"]["torch_inductor_cache_dir"]

        super().initialize(ctx)
        self.initialized = True

    def inference(self, data, *args, **kwargs):
        with torch.inference_mode():
            marshalled_data = data.to(self.device)
            results = self.model(marshalled_data, *args, **kwargs)

        # Debug logging for FX Graph cache hits/misses
        if torch._inductor.config.fx_graph_cache:
            fx_graph_cache_hit, fx_graph_cache_miss = (
                counters["inductor"]["fxgraph_cache_hit"],
                counters["inductor"]["fxgraph_cache_miss"],
            )
            logger.info(
                f"TorchInductor FX Graph cache hit {fx_graph_cache_hit}, "
                f"FX Graph cache miss {fx_graph_cache_miss}"
            )
        return results
examples/pt2/torch_inductor_caching/model-config-cache-dir.yaml

+7 (new file)

minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_cache_dir: "/home/ubuntu/serve/examples/pt2/torch_inductor_caching/cache"
examples/pt2/torch_inductor_caching/model-config-fx-cache.yaml

+7 (new file)

minWorkers: 4
maxWorkers: 4
responseTimeout: 600
pt2 : {backend: inductor, mode: max-autotune}
handler:
  torch_inductor_caching:
    torch_inductor_fx_graph_cache: true

ts_scripts/spellcheck_conf/wordlist.txt

+2

@@ -1175,6 +1175,8 @@ BabyLlamaHandler
 CMakeLists
 TorchScriptHandler
 libllamacpp
+USERID
+torchinductor
 libtorch
 Andrej
 Karpathy's
