[aoti] Add cpp loader changes
angelayi committed Sep 13, 2024
1 parent dd34475 commit adb33fb
Showing 6 changed files with 46 additions and 59 deletions.
21 changes: 7 additions & 14 deletions README.md
@@ -182,7 +182,7 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
[skip default]: end

### Server
This mode exposes a REST API for interacting with a model.
This mode exposes a REST API for interacting with a model.
The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
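For illustration, a chat completion request might look like the minimal Python sketch below; the host, port, and model name are assumptions about a local setup and may differ from the server's actual defaults.

```python
# Hypothetical client for the OpenAI-compatible chat endpoint exposed by the server.
# The URL and model name are assumptions; adjust them to match your running server.
import requests

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Write a haiku about compilers."}],
    "max_tokens": 64,
}
response = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload)
print(response.json())
```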
@@ -255,14 +255,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener
## Desktop/Server Execution

### AOTI (AOT Inductor)
[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents
that is then loaded for inference. This can be done in both Python and C++ environments.

The following example exports and executes the Llama3.1 8B Instruct
model. The first command compiles and performs the actual export.

```bash
python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts
python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2
```
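Under the hood, this commit switches the export path to produce a packaged PT2. The sketch below illustrates that flow with a toy module, mirroring the new code in `torchchat/export.py` further down this page; the module and file paths are placeholders.

```python
# Minimal sketch of the export-and-package flow (toy model and paths are illustrative).
import torch
import torch._inductor
from torch._inductor.package import package_aoti


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


model = TinyModel()
example_inputs = (torch.randn(4, 8),)

# 1. Capture the model graph with torch.export.
ep = torch.export.export(model, example_inputs)

# 2. AOT-compile the exported program with packaging enabled; this yields the
#    AOTInductor-generated artifacts rather than a single .so path.
artifacts = torch._inductor.aot_compile(
    ep.module(), example_inputs, options={"aot_inductor.package": True}
)

# 3. Bundle the artifacts into a single zipped PT2 file.
pt2_path = package_aoti("exportedModels/tiny_model.pt2", artifacts)
print(f"Packaged model written to: {pt2_path}")
```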

> [!NOTE]
@@ -274,12 +274,11 @@ case visit our [customization guide](docs/model_customization.md).

### Run in a Python Environment

To run in a Python environment, use the generate subcommand like before, but include the DSO file.
To run in a Python environment, use the generate subcommand like before, but include the PT2 file.

```bash
python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is"
```
**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.
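The PT2 can also be loaded directly from Python. The sketch below mirrors what `torchchat/cli/builder.py` does in this commit and assumes the export step above already produced the file.

```python
# Load a packaged AOTI model directly (assumes the PT2 was exported as shown above).
from torch._inductor.package import load_package

compiled = load_package("exportedModels/llama3_1_artifacts.pt2")

# The package records the device it was compiled for in its metadata.
device = compiled.get_metadata()["AOTI_DEVICE_KEY"]
print(f"Package was compiled for: {device}")
```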


### Run using our C++ Runner
@@ -289,17 +288,11 @@ To run in a C++ environment, we need to build the runner binary.
torchchat/utils/scripts/build_native.sh aoti
```

To compile the AOTI generated artifacts into a `.so`:
Then run the compiled executable with the PT2 file.
```bash
make -C exportedModels/llama3_1_artifacts
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
```

Then run the compiled executable, with the compiled DSO.
```bash
cmake-out/aoti_run exportedModels/llama3_1_artifacts/llama3_1_artifacts.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
```
**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.

## Mobile Execution

[ExecuTorch](https://github.com/pytorch/executorch) enables you to optimize your model for execution on a
6 changes: 3 additions & 3 deletions install/install_requirements.sh
@@ -47,10 +47,10 @@ fi
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
PYTORCH_NIGHTLY_VERSION=dev20240814
PYTORCH_NIGHTLY_VERSION=dev20240913

# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20240814
VISION_NIGHTLY_VERSION=dev20240913

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20240910
@@ -74,7 +74,7 @@ fi

# pip packages needed by exir.
REQUIREMENTS_TO_INSTALL=(
torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
)
19 changes: 4 additions & 15 deletions runner/run.cpp
@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
#endif

#ifdef __AOTI_MODEL__
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
torch::Device aoti_device(torch::kCPU);

#else // __ET_MODEL__
@@ -93,7 +90,7 @@ typedef struct {
RunState state; // buffers for the "wave" of activations in the forward pass

#ifdef __AOTI_MODEL__
torch::inductor::AOTIModelContainerRunner* runner;
torch::inductor::AOTIModelPackageLoader* runner;
#else // __ET_MODEL__
Module* runner;
#endif
@@ -143,16 +140,8 @@ void build_transformer(
malloc_run_state(&t->state, &t->config);

#ifdef __AOTI_MODEL__
#ifdef USE_CUDA
if (aoti_device.type() == torch::kCUDA) {
t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
aoti_device = torch::Device(torch::kCUDA);
} else {
#else
{
#endif
t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
}
t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
#else //__ET_MODEL__
t->runner = new Module(
/* path to PTE model */ model_path,
17 changes: 10 additions & 7 deletions torchchat/cli/builder.py
@@ -77,22 +77,22 @@ def __post_init__(self):
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if self.pte_path and self.aoti_package_path:
if self.aoti_package_path and self.pte_path:
raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")

if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
if self.checkpoint_path and (self.aoti_package_path or self.pte_path):
print(
"Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
)
if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
if self.checkpoint_dir and (self.aoti_package_path or self.pte_path):
print(
"Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
)
if self.gguf_path and (self.pte_path or self.aoti_package_path):
if self.gguf_path and (self.aoti_package_path or self.pte_path):
print(
"Warning: GGUF path ignored because an exported AOTI or PTE path specified"
)
if not (self.dso_path) and not (self.aoti_package_path):
if not (self.aoti_package_path) and not (self.pte_path):
self.prefill_possible = True

@classmethod
@@ -533,9 +533,12 @@ def _initialize_model(
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
from torch._inductor.package import load_package
model.forward = load_package(
str(builder_args.aoti_package_path.absolute()), builder_args.device
aoti_compiled_model = load_package(
str(builder_args.aoti_package_path.absolute())
)
model.forward = aoti_compiled_model
metadata = aoti_compiled_model.get_metadata()
builder_args.device = metadata["AOTI_DEVICE_KEY"]
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")

28 changes: 19 additions & 9 deletions torchchat/export.py
@@ -11,6 +11,7 @@
import torch.nn as nn

from torch.export import Dim
import torch._inductor

from torchchat.cli.builder import (
_initialize_model,
@@ -38,7 +39,6 @@ def export_for_server(
output_path: str = "model.pt2",
dynamic_shapes: bool = False,
package: bool = True,
model_key: str = "",
) -> str:
"""
Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,16 +67,27 @@
dynamic_shapes = None

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
path = torch._export.aot_compile(
metadata = {} # TODO: put more metadata here
options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
if not package:
options = {"aot_inductor.output_path": output_path}

ep = torch.export.export(
model,
args=input,
options={
"aot_inductor.output_path": output_path,
"aot_inductor.package": package,
},
input,
dynamic_shapes=dynamic_shapes,
)
print(f"The generated DSO model can be found at: {path}")
path = torch._inductor.aot_compile(
ep.module(),
input,
options=options,
)

if package:
from torch._inductor.package import package_aoti
path = package_aoti(output_path, path)

print(f"The generated packaged model can be found at: {path}")
return path


@@ -439,5 +450,4 @@ def main(args):
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
model_key=builder_args.params_table,
)
14 changes: 3 additions & 11 deletions torchchat/generate.py
@@ -135,6 +135,8 @@ def validate_build(
reason = "model compilation"
if builder_args.aoti_package_path:
model_type = "PT2"
if builder_args.dso_path:
model_type = "DSO"
if builder_args.pte_path:
model_type = "PTE"
if model_type and reason:
@@ -148,7 +150,7 @@ def from_args(cls, args):
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)
sequential_prefill = (
args.sequential_prefill or bool(aoti_package_path) or bool(pte_path)
args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path)
)

return cls(
@@ -951,13 +953,3 @@ def main(args):
torch.cuda.reset_peak_memory_stats()
for _ in gen.chat(generator_args):
pass


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="torchchat generate CLI")
verb = "generate"
add_arguments_for_verb(parser, verb)
args = parser.parse_args()
check_args(args, verb)
args = arg_init(args)
main(args)
