
Uses torch._inductor.compile instead of torch.compile in benchmark script to avoid segmentation by Dynamo (#1540)

Merged (2 commits, Dec 11, 2024)

Conversation

@kiya00 (Collaborator) commented on Dec 11, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1521. The generated benchmark script now compiles the reproducer with torch._inductor.compile instead of torch.compile, so that TorchDynamo does not re-trace and further segment a graph it has already captured.

The benchmark script:

# NOTE: This script requires `pytest-benchmark==4.0.0` to be installed.
# To execute the script, run `pytest graph0_thunder_0.py --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on`
# To check the peak allocated CUDA memory, use --benchmark-json=json_file_name and look at the "max_allocated_memory_MB" field in the json file
from math import inf
from math import nan
NoneType = type(None)
import torch
from torch import device
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from functools import partial
import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform
from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
import pytest

# NOTE: The reproducer function has already been processed by TorchDynamo.
# If we let it go through TorchDynamo again, it could be segmented further.
# To avoid this, we directly use Inductor here.
def torch_inductor(fn, inputs):
    from torch._inductor import compile as inductor_compile
    from torch.fx import symbolic_trace

    fx_graph = symbolic_trace(fn)
    return inductor_compile(fx_graph, inputs)

bench_executors_dict = {}
bench_executors_dict["thunder"]=partial(thunder.jit, transforms=[thunder.dev_utils.nvtx_profile_transform.NvtxProfileTransform(), thunder.transforms.cudagraph.CUDAGraphTransform()],executors=[thunder.extend.get_executor('nvfuser')],cache='constant values',langctx=None,record_history=False,)
bench_executors_dict["torch_inductor"]=torch_inductor
bench_executors_dict["eager"]=None
bench_executors_dict["thunder_cugraph"]=partial(thunder.jit, transform=CUDAGraphTransform())

executors = list(bench_executors_dict.values())
executor_ids = list(bench_executors_dict.keys())

@pytest.mark.parametrize(
    "executor,",
    executors,
    ids=executor_ids,
)
def test_graph0_thunder_0(benchmark, executor):
    class DynamoModule(torch.nn.Module):
        def forward(self, l_x_: torch.Tensor):
            x = torch.sin(l_x_);  l_x_ = None
            sum_1 = x.sum()
            gt = sum_1 > 0;  sum_1 = None
            return (x, gt)

    inputs = [
        torch.testing.make_tensor((31,), dtype=torch.int64,  device='cuda:0', requires_grad=False, low=3, high=9,).as_strided((4, 4), (8, 2)),

    ]

    mod = DynamoModule()
    if executor is None:
        compiled = mod
    elif executor == torch_inductor:
        compiled = executor(mod, inputs)
    else:
        compiled = executor(mod)
    from thunder.benchmarks import record_peak_allocated_memory

    with record_peak_allocated_memory(benchmark):
        benchmark(compiled, *inputs)
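
For context, here is a minimal sketch (not part of the PR; the helper names are illustrative) contrasting the two compilation paths. Passing the reproducer back to torch.compile would send it through TorchDynamo a second time, and any graph break could segment it into several subgraphs; tracing it with torch.fx.symbolic_trace and handing the whole FX graph to torch._inductor.compile, as the generated script above does, keeps it as a single compiled unit.

import torch
from torch.fx import symbolic_trace
from torch._inductor import compile as inductor_compile

def compile_via_torch_compile(fn):
    # Re-traces `fn` with TorchDynamo; graph breaks can split the
    # already-extracted reproducer into multiple subgraphs.
    return torch.compile(fn)

def compile_via_inductor(fn, example_inputs):
    # symbolic_trace builds a torch.fx.GraphModule without invoking Dynamo,
    # and Inductor lowers that graph as one unit.
    fx_graph = symbolic_trace(fn)
    return inductor_compile(fx_graph, example_inputs)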

@kshitij12345 (Collaborator) left a comment:

LGTM, thanks @kiya00.

Just to check my understanding, I assume this patch is tested by the following tests, is that right?

def test_dynamo_reproducer_2graph(executor, device: str, dtype: dtypes.dtype, use_pytest_benchmark, tmp_path):

def test_dynamo_reproducer_submodules(use_pytest_benchmark, tmp_path):

A review comment on thunder/dynamo/utils.py was resolved.
@mruberry (Collaborator) left a comment.

@mruberry enabled auto-merge (squash) on December 11, 2024 at 19:36.
@mruberry merged commit f89ceca into main on December 11, 2024; 41 checks passed.
@mruberry deleted the fix1521 branch on December 11, 2024 at 19:36.

Linked issue (may be closed by merging this pull request): Repro function saved from FX graph is segmented again when passed back to torch.compile (#1521)