Add user guide on how to use the compiler (#114)
parthchadha authored Aug 17, 2024
1 parent 7705c34 commit 9e3b302
Showing 3 changed files with 163 additions and 16 deletions.
138 changes: 138 additions & 0 deletions tripy/docs/pre0_user_guides/02-compiler.md
@@ -0,0 +1,138 @@
# Using the Compiler

```{contents} Table of Contents
:depth: 3
```

## Walk through model compilation and deployment

Let's walk through a simple example of a [GEGLU](https://arxiv.org/abs/2002.05202v1) module defined below:

```py
# doc: no-print-locals
import tripy as tp

class GEGLU(tp.Module):
    def __init__(self, dim_in, dim_out):
        self.proj = tp.Linear(dim_in, dim_out * 2)
        self.dim_out = dim_out

    def __call__(self, x):
        proj = self.proj(x)
        x, gate = tp.split(proj, 2, proj.rank - 1)
        return x * tp.gelu(gate)
```
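
As a cross-check on what the module computes, here is a minimal pure-Python sketch of the GEGLU arithmetic for a single row of the projection, independent of Tripy. It assumes the exact (erf-based) GELU; whether `tp.gelu` uses this form or the tanh approximation is not stated in this guide. The names `gelu` and `geglu_reference` are made up for this illustration.

```python
import math


def gelu(v):
    """Exact GELU: v * Phi(v), where Phi is the standard normal CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def geglu_reference(proj_row):
    """Split one projected row in half and gate the first half with GELU of the second.

    Mirrors `x, gate = split(proj, 2, last_dim)` followed by `x * gelu(gate)`.
    """
    half = len(proj_row) // 2
    x, gate = proj_row[:half], proj_row[half:]
    return [xi * gelu(gi) for xi, gi in zip(x, gate)]


# A zero gate suppresses the value; a large positive gate scales it by
# (approximately) the gate itself.
print(geglu_reference([2.0, 3.0, 0.0, 100.0]))
```

The input row has length `2 * dim_out`, which is why `GEGLU` builds its `Linear` layer with `dim_out * 2` output features.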

To run `GEGLU` in eager mode:

```py
# doc: no-print-locals
layer = GEGLU(2, 8)
inp = tp.ones((1, 2))
out = layer(inp)
```

Now, let's try to optimize this model for inference using Tripy's {class}`tripy.Compiler`.

First, let's initialize the compiler with the module we want to compile, `layer`, which lets the compiler know its properties, like the function signature.

```py
# doc: no-print-locals
compiler = tp.Compiler(layer)
```

Next, we need to describe each input using {class}`tripy.InputInfo`. The first argument, `shape`, gives either static or dynamic shape information for each dimension; in the example below, we assume the shape of `inp` is static (`(1, 2)`). The second argument, `dtype`, specifies the input's data type:

```py
# doc: no-print-locals
inp_info = tp.InputInfo(shape=(1, 2), dtype=tp.float32)
```

Now, we can call the `compile` function to obtain a compiled function and use it for inference:

```py
# doc: no-print-locals
fast_geglu = compiler.compile(inp_info)
fast_geglu(inp).eval()
```

### Optimization profiles

In the example above, we assumed `inp` has a static shape of `(1, 2)`. Now, let’s assume that the shape of `inp` can vary from `(1, 2)` to `(16, 2)`, with `(8, 2)` being the shape we'd like to optimize for. To express this constraint, we can give `InputInfo` a range for the first dimension using `shape=((1, 8, 16), 2)`. The triple `(1, 8, 16)` tells the compiler that the first dimension can vary from 1 to 16 and that it should optimize for a size of 8; the second dimension stays fixed at 2.

```py
# doc: print-locals out out_change_shape
inp_info = tp.InputInfo(shape=((1, 8, 16), 2), dtype=tp.float32)
fast_geglu = compiler.compile(inp_info)
out = fast_geglu(inp)

# Let's change the shape of input to (2, 2)
inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
out_change_shape = fast_geglu(inp)
```

### Errors for out-of-bounds inference

If we provide an input that does not comply with the dynamic shape constraints given to the compiler, Tripy will produce an error with relevant information:

<!-- Tripy: TEST: IGNORE Start -->
```py
# doc: allow-exception
inp = tp.ones((32, 2), dtype=tp.float32)
print(fast_geglu(inp))
```
<!-- Tripy: TEST: IGNORE End -->
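
To make the constraint concrete, the following is a rough pure-Python sketch of how a runtime shape can be checked against a profile such as `((1, 8, 16), 2)`. It is an illustration only and not Tripy's implementation; `normalize_dim` and `shape_in_bounds` are hypothetical helper names.

```python
def normalize_dim(dim):
    """Return (min, opt, max) for a dimension given as an int or a 3-tuple."""
    if isinstance(dim, int):
        return (dim, dim, dim)
    low, opt, high = dim
    if not low <= opt <= high:
        raise ValueError(f"Expected min <= opt <= max, got {dim}")
    return (low, opt, high)


def shape_in_bounds(shape, profile):
    """Check that every runtime dimension lies within its [min, max] range."""
    if len(shape) != len(profile):
        return False
    for size, dim in zip(shape, profile):
        low, _, high = normalize_dim(dim)
        if not low <= size <= high:
            return False
    return True


profile = ((1, 8, 16), 2)
print(shape_in_bounds((2, 2), profile))   # True: 2 lies within [1, 16]
print(shape_in_bounds((32, 2), profile))  # False: 32 exceeds the maximum of 16
```

Note that the optimization target (8 here) plays no part in the check; it only tells the compiler which size to tune for.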

### Serializing the executable to disk

A compiled executable can be serialized to disk and then used for deployment.

Saving an executable to disk:

```py
# doc: no-print-locals
import tempfile, os
temp_dir = tempfile.mkdtemp()
executable_file_path = os.path.join(temp_dir, "executable.json")
fast_geglu.save(executable_file_path)
```

Reading an executable and running inference:

```py
# doc: no-print-locals
inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
loaded_fast_geglu = tp.Executable.load(executable_file_path)
out = loaded_fast_geglu(inp)
os.remove(executable_file_path)
```
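
The snippet above removes the file but leaves the temporary directory behind. One way to clean up both is `tempfile.TemporaryDirectory`, sketched below with a plain JSON payload standing in for the executable. The `save` and `load` calls from this guide are deliberately not used, so the sketch runs without Tripy; `round_trip` is a hypothetical helper name.

```python
import json
import os
import tempfile


def round_trip(payload):
    """Write a payload to a temporary file, read it back, and clean up."""
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, "executable.json")
        with open(path, "w") as f:
            json.dump(payload, f)  # stand-in for saving the executable
        with open(path) as f:
            loaded = json.load(f)  # stand-in for loading it again
    # On leaving the `with` block, the directory and its contents are removed.
    return loaded, os.path.exists(temp_dir)


loaded, still_exists = round_trip({"name": "fast_geglu"})
print(loaded, still_exists)
```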

### Querying properties of the executable

You can also query the input and output information of the executable:

```py
# doc: print-locals input_info output_info
input_info = loaded_fast_geglu.get_input_info()
output_info = loaded_fast_geglu.get_output_info()
```

## Common Pitfalls

### Expectations for pure functions

Tripy expects the functions it compiles to be pure, that is, free of side effects. Consider this example:

```py
# doc: print-locals out
def add_times_two(a, b):
    c = a + b
    print(f"c : {c}")
    return c + a + b

compiler = tp.Compiler(add_times_two)
inp_info = tp.InputInfo(shape=(1, 2), dtype=tp.float32)
fast_myadd = compiler.compile(inp_info, inp_info)
a = tp.Tensor([[1.0, 2.0]], dtype=tp.float32)
b = tp.Tensor([[2.0, 3.0]], dtype=tp.float32)

out = fast_myadd(a, b)
```

You might expect the output `out` to be `[[6.0, 10.0]]` and the printed value of `c` to be `[[3.0, 5.0]]`. That is not what happens.

This unexpected behavior occurs because Tripy uses dummy tensors filled with 1's during the compilation process. When the `print` statement is encountered during compilation, it outputs the result of adding two tensors filled with 1's, which is `[[2.0, 2.0]]`. The compiled function then uses this dummy result in place of the actual value of `c`, leading to incorrect calculations in the final output.
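
The effect can be reproduced with a toy tracer in plain Python. This is only a model of the behavior described above, not how Tripy is implemented: `Traced`, `toy_compile`, and `evaluate` are hypothetical names, scalars stand in for tensors, and `evaluate` stands in for the evaluation that `print` triggers.

```python
class Traced:
    """Toy symbolic value: a deferred computation plus a dummy value."""

    def __init__(self, compute, dummy):
        self.compute = compute  # maps the real inputs to a real value
        self.dummy = dummy      # placeholder value used while tracing

    def __add__(self, other):
        return Traced(
            lambda inputs: self.compute(inputs) + other.compute(inputs),
            self.dummy + other.dummy,
        )

    def evaluate(self):
        """Side effect during tracing: freeze this node at its dummy value."""
        frozen = self.dummy
        self.compute = lambda inputs: frozen
        return frozen


def toy_compile(fn, num_inputs):
    # Trace `fn` once with placeholders whose dummy value is 1.0.
    placeholders = [Traced(lambda inputs, i=i: inputs[i], 1.0) for i in range(num_inputs)]
    result = fn(*placeholders)
    return lambda *inputs: result.compute(inputs)


printed = []


def impure(a, b):
    c = a + b
    printed.append(c.evaluate())  # stands in for print(c)
    return c + a + b


def pure(a, b):
    c = a + b
    return c + a + b


fast_impure = toy_compile(impure, 2)  # the side effect fires here, once
fast_pure = toy_compile(pure, 2)

print(printed)                 # [2.0]: the dummy sum 1.0 + 1.0
print(fast_impure(1.0, 2.0))   # 5.0: the frozen 2.0 replaces the real c (3.0)
print(fast_pure(1.0, 2.0))     # 6.0: the expected result
```

Removing the side effect, as in `pure`, is what restores the expected result in this model.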
23 changes: 15 additions & 8 deletions tripy/tests/helper.py
@@ -455,13 +455,15 @@ def process_code_block_for_outputs_and_locals(
     NO_EVAL = "# doc: no-eval"
     NO_PRINT_LOCALS = "# doc: no-print-locals"
     PRINT_LOCALS = "# doc: print-locals"
-    REMOVE_TAGS = [NO_PRINT_LOCALS, PRINT_LOCALS, NO_EVAL]
+    ALLOW_EXCEPTION = "# doc: allow-exception"
+    REMOVE_TAGS = [NO_PRINT_LOCALS, PRINT_LOCALS, NO_EVAL, ALLOW_EXCEPTION]
     if strip_assertions:
         REMOVE_TAGS.append("assert ")
     OMIT_COMMENT = "# doc: omit"

     should_append_locals = True
     should_eval = True
+    allow_exception = False

     # By default, we print all local variables. If `print_vars` it not empty,
     # then we'll only print those that appear in it.
@@ -486,6 +488,9 @@ def process_code_block_for_outputs_and_locals(
         if block_line.strip() == NO_EVAL:
             should_eval = False

+        if block_line.strip() == ALLOW_EXCEPTION:
+            allow_exception = True
+
         if block_line.strip().startswith(PRINT_LOCALS):
             _, _, names = block_line.strip().partition(PRINT_LOCALS)
             print_vars.update(names.strip().split(" "))
@@ -499,14 +504,16 @@ def process_code_block_for_outputs_and_locals(
         return code_block_lines, local_var_lines, output_lines, local_vars

     code = dedent(code)
-    try:
-        with capture_output() as outfile:
+
+    with capture_output() as outfile:
+        try:
             code_locals = exec_code(code, local_vars)
-    except:
-        print(err_msg)
-        print(f"Note: Code example was:\n{code}")
-        print(outfile.read())
-        raise
+        except Exception as e:
+            if allow_exception:
+                print(f"Exception occurred: {str(e)}")
+                code_locals = local_vars
+            else:
+                raise

     new_locals = {
         key: value for key, value in code_locals.items() if key not in local_vars or value is not local_vars[key]
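
The `allow-exception` handling in this hunk can be sketched as a standalone function. This is a simplified stand-in, assuming plain `exec` and `contextlib.redirect_stdout` in place of the repository's `exec_code` and `capture_output` helpers; `run_block` is a hypothetical name.

```python
import contextlib
import io


def run_block(code, local_vars, allow_exception=False):
    """Execute `code`, returning the resulting namespace and captured stdout.

    If `allow_exception` is set, an exception is reported in the captured
    output instead of propagating, and the incoming variables are returned
    unchanged.
    """
    namespace = dict(local_vars)
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        try:
            exec(code, namespace)
        except Exception as e:
            if not allow_exception:
                raise
            print(f"Exception occurred: {e}")
            namespace = dict(local_vars)
    namespace.pop("__builtins__", None)  # added by exec; not a user variable
    return namespace, captured.getvalue()


namespace, output = run_block("x = 1 / 0", {"y": 1}, allow_exception=True)
print(namespace)  # {'y': 1}
print(output)     # Exception occurred: division by zero
```

Keeping the `try` inside the capture context is what lets the exception message appear in the captured output, so it can be rendered alongside the documentation example.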
18 changes: 10 additions & 8 deletions tripy/tests/test_internal_docs.py
@@ -53,17 +53,19 @@ def test_python_code_snippets(code_blocks):
         not any(block.has_marker("test: use_pytest") for block in code_blocks) or all_pytest
     ), f"This test does not currently support mixing blocks meant to be run with PyTest with blocks meant to be run by themselves!"
-
-    # We concatenate all the code together because most documentation includes code
-    # that is continued from previous code blocks.
-    # TODO: We can instead run the blocks individually and propagate the evaluated local variables like `generate_rsts.py` does.
-    code = "\n\n".join(map(str, code_blocks))
-    print(f"Checking code:\n{code}")
-
     if all_pytest:
+        code = "\n\n".join(map(str, code_blocks))
         f = tempfile.NamedTemporaryFile(mode="w+", suffix=".py")
         f.write(code)
         f.flush()

         assert pytest.main([f.name, "-vv", "-s"]) == 0
     else:
-        helper.exec_code(code)
+        code_locals = {}
+        for block in code_blocks:
+            print(f"Checking code block:\n{str(block)}")
+            try:
+                new_locals = helper.exec_code(str(block), code_locals)
+                # Update code_locals with new variables
+                code_locals.update(new_locals)
+            except Exception as e:
+                raise AssertionError(f"Error while executing code block: {str(e)}") from e
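
The per-block loop in the `else` branch can be sketched on its own as follows. This is a simplified illustration that assumes plain `exec` with a shared namespace rather than the repository's `helper.exec_code`; `run_blocks` is a hypothetical name, and the block index in the error message is an addition for the sketch.

```python
def run_blocks(blocks):
    """Run code blocks in order, carrying variables from one block to the next."""
    namespace = {}
    for index, block in enumerate(blocks):
        try:
            # Each block sees the variables defined by earlier blocks.
            exec(block, namespace)
        except Exception as e:
            raise AssertionError(f"Error while executing code block {index}: {e}") from e
    namespace.pop("__builtins__", None)  # added by exec; not a user variable
    return namespace


print(run_blocks(["x = 2", "y = x * 3"]))  # {'x': 2, 'y': 6}
```

Compared with concatenating all blocks into one string, running them one at a time means a failure can be attributed to a specific block.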
