Skip to content

Commit

Permalink
kwargs for torch.compile in BaseHandler (#2796)
Browse files Browse the repository at this point in the history
* arbitrary kwarg for torch.compile in BaseHandler

* add documentation

* add testing

* remove type annotations and lint

---------

Co-authored-by: Mark Saroufim <[email protected]>
  • Loading branch information
eballesteros and msaroufim authored Nov 16, 2023
1 parent a00972b commit a8ca657
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 27 deletions.
12 changes: 9 additions & 3 deletions examples/pt2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@ pip install torchserve torch-model-archiver

PyTorch 2.0 supports several compiler backends and you pick which one you want by passing in an optional file `model_config.yaml` during your model packaging

`pt2: "inductor"`
```yaml
pt2: "inductor"
```
You can also pass a dictionary with compile options if you need more control over torch.compile:
```yaml
pt2 : {backend: inductor, mode: reduce-overhead}
```
As an example let's expand our getting started guide with the only difference being passing in the extra `model_config.yaml` file

Expand Down Expand Up @@ -99,5 +107,3 @@ print(extra_files['foo.txt'])
# from inference()
print(ep(torch.randn(5)))
```


1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/pt2_dict.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pt2 : {backend: inductor, mode: reduce-overhead}
64 changes: 45 additions & 19 deletions test/pytest/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py")
HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py")
YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml")
YAML_CONFIG_STR = os.path.join(TEST_DATA_DIR, "pt2.yaml") # backend as string
YAML_CONFIG_DICT = os.path.join(TEST_DATA_DIR, "pt2_dict.yaml") # arbitrary kwargs dict


SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt")
Expand All @@ -41,19 +42,32 @@ def teardown_class(self):

def test_archive_model_artifacts(self):
assert len(glob.glob(MODEL_FILE)) == 1
assert len(glob.glob(YAML_CONFIG)) == 1
assert len(glob.glob(YAML_CONFIG_STR)) == 1
assert len(glob.glob(YAML_CONFIG_DICT)) == 1
subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True)
subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)

# register 2 models, one with the backend as str config, the other with the kwargs as dict config
subprocess.run(
f"torch-model-archiver --model-name {MODEL_NAME}_str --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG_STR} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
shell=True,
check=True,
)
subprocess.run(
f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
f"torch-model-archiver --model-name {MODEL_NAME}_dict --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG_DICT} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
shell=True,
check=True,
)
assert len(glob.glob(SERIALIZED_FILE)) == 1
assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1
assert (
len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}_str.mar"))) == 1
)
assert (
len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}_dict.mar"))) == 1
)

def test_start_torchserve(self):
cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}"
cmd = f"torchserve --start --ncs --models {MODEL_NAME}_str.mar,{MODEL_NAME}_dict.mar --model-store {MODEL_STORE_DIR}"
subprocess.run(
cmd,
shell=True,
Expand Down Expand Up @@ -90,9 +104,16 @@ def test_registered_model(self):
capture_output=True,
check=True,
)
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
expected_registered_model = json.loads(expected_registered_model_str)
assert json.loads(result.stdout) == expected_registered_model

def _response_to_tuples(response_str):
models = json.loads(response_str)["models"]
return {(k, v) for d in models for k, v in d.items()}

# transform to set of tuples so order won't cause inequality
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two_str", "modelUrl": "half_plus_two_str.mar"}, {"modelName": "half_plus_two_dict", "modelUrl": "half_plus_two_dict.mar"}]}'
assert _response_to_tuples(result.stdout) == _response_to_tuples(
expected_registered_model_str
)

@pytest.mark.skipif(
os.environ.get("TS_RUN_IN_DOCKER", False),
Expand All @@ -103,20 +124,25 @@ def test_serve_inference(self):
request_data = {"instances": [[1.0], [2.0], [3.0]]}
request_json = json.dumps(request_data)

result = subprocess.run(
f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/half_plus_two -d '{request_json}'",
shell=True,
capture_output=True,
check=True,
)
for model_name in [f"{MODEL_NAME}_str", f"{MODEL_NAME}_dict"]:
result = subprocess.run(
f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/{model_name} -d '{request_json}'",
shell=True,
capture_output=True,
check=True,
)

string_result = result.stdout.decode("utf-8")
float_result = float(string_result)
expected_result = 3.5
string_result = result.stdout.decode("utf-8")
float_result = float(string_result)
expected_result = 3.5

assert float_result == expected_result
assert float_result == expected_result

model_log_path = glob.glob("logs/model_log.log")[0]
with open(model_log_path, "rt") as model_log_file:
model_log = model_log_file.read()
assert "Compiled model with backend inductor" in model_log
assert "Compiled model with backend inductor\n" in model_log
assert (
"Compiled model with backend inductor, mode reduce-overhead"
in model_log
)
27 changes: 22 additions & 5 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,23 +184,40 @@ def initialize(self, context):
raise RuntimeError("No model weights could be loaded")

if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
pt2_backend = self.model_yaml_config["pt2"]
valid_backend = check_valid_pt2_backend(pt2_backend)
pt2_value = self.model_yaml_config["pt2"]

# pt2_value can be the backend, passed as a str, or arbitrary kwargs, passed as a dict
if isinstance(pt2_value, str):
compile_options = dict(backend=pt2_value)
elif isinstance(pt2_value, dict):
compile_options = pt2_value
else:
raise ValueError("pt2 should be str or dict")

# if backend is not provided, compile will use its default, which is valid
valid_backend = (
check_valid_pt2_backend(compile_options["backend"])
if "backend" in compile_options
else True
)
else:
valid_backend = False

# PT 2.0 support is opt in
if PT2_AVAILABLE and valid_backend:
compile_options_str = ", ".join(
[f"{k} {v}" for k, v in compile_options.items()]
)
# Compilation will delay your model initialization
try:
self.model = torch.compile(
self.model,
backend=pt2_backend,
**compile_options,
)
logger.info(f"Compiled model with backend {pt2_backend}")
logger.info(f"Compiled model with {compile_options_str}")
except Exception as e:
logger.warning(
f"Compiling model model with backend {pt2_backend} has failed \n Proceeding without compilation"
f"Compiling model model with {compile_options_str} has failed \n Proceeding without compilation"
)
logger.warning(e)

Expand Down

0 comments on commit a8ca657

Please sign in to comment.