Update on "[rfc][dynamo] "skip_guard_eval" stance for power users"

# Motivation

We have spent quite some time this year improving guard performance and soundness. Nevertheless, guards STILL take time. We have seen multiple requests and evidence from power users who want close to 0% guard overhead. First, we saw this in vLLM, where even 1% overhead is bad. More recently we saw it in hqq (low-precision LLM
generation) - #138386. To put some numbers in perspective, low-precision LLM inference reaches around 250 tokens/second, i.e., each token takes a mere 4 milliseconds. Even a guard overhead of 200 us is therefore a 5% overhead in total.

Here, users ask: "we can guarantee that there will be no more recompilations in the steady state, so give us the lowest possible guard overhead".

# Design

A must-have consideration is to support fast inference even when the model has recompiled, i.e., when a code object has multiple cache entries (because of dynamism, or just a tensor dtype change in the case of hqq). So we still have to run some guards to figure out which compiled graph to run.

What we need is the "minimal set of differentiating guards", i.e., the minimal set of guards we can run to choose the right compiled graph. For example, if two cache entries differ only in a tensor's dtype (as in hqq), running just the dtype guard is enough to pick between them. Note that this works ONLY under the assumption that users really guarantee no more recompilation scenarios (no more mutations, no more dynamism after the model has been warmed up). If the user violates this assumption and the violation is not covered by the diff guard set, we may choose the wrong compiled graph to run.

When we designed C++ guards, Ed and Voz suggested using a trie structure to directly represent this "diff guard set". But due to complexity, we went with a tree structure and relied on a GuardManager state - "fail_count" - to fail fast. I realized that we can rely on this "fail_count" to find the diff guard set.

If we recompile, it means that the guard-eval check_fns of all existing cache lines have failed. Whenever a guard check_fn fails, we increment the counter on the failing node (and propagate it up to the root node) so that we fail faster the next time. If we want to run the "diff guard set", we only have to run the nodes in the tree with "fail_count > 0".
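As a rough sketch of the idea (hypothetical names, not the actual GuardManager API), the steady-state check could walk the guard tree and skip every subtree whose fail_count is still zero:

```python
# Illustrative sketch only: GuardNode, check, and fail_count are hypothetical
# stand-ins and do not mirror the real C++ GuardManager implementation.
class GuardNode:
    def __init__(self, check, children=()):
        self.check = check            # callable(frame) -> bool
        self.children = list(children)
        self.fail_count = 0           # bumped (and propagated to the root) on every failure

def run_diff_guards(node, frame):
    """Run only the guards that have ever failed (fail_count > 0)."""
    if node.fail_count == 0:
        # This subtree never caused a recompile, so under the "no new
        # recompilations" assumption it cannot differentiate cache entries.
        return True
    if not node.check(frame):
        return False
    return all(run_diff_guards(child, frame) for child in node.children)
```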

This PR relies on this observation to introduce a new stance - "skip_guard_eval". The idea is that the user warms up their model with torch.compile and then runs the steady state with this stance. The stance goes through the existing cache lines for the intercepted code object but only runs the diff guard set, which dramatically reduces the guard overhead. If all guards fail, we fall back to eager (however, if this happens the user is violating the assumption, so we should perhaps hard error; I need to fix a silly issue with _dynamo.disable before we can hard error here).
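A minimal usage sketch, assuming the stance is exposed through torch.compiler.set_stance and that MyModel, warmup_batches, and inference_batches are user-defined placeholders (the exact entry point and stance name may change as this RFC evolves):

```python
import torch

model = torch.compile(MyModel().cuda())

# Warm-up phase: let torch.compile trace (and possibly recompile) until the
# cache holds entries for every shape/dtype combination we expect to see.
for batch in warmup_batches:
    model(batch)

# Steady state: the user promises no new recompilations, so only the minimal
# differentiating guards run when selecting a cache entry.
with torch.compiler.set_stance("skip_guard_eval"):
    for batch in inference_batches:
        model(batch)
```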

A bonus is that this "theoretically" works with graph breaks as well, but I need more testing to convince myself of that.

# Evaluation

I tried the hqq model in #138386. With very small changes in the user code ([hqq PR](mobiusml/hqq#127)), I see the throughput increase from **160 tokens/sec to 174 tokens/sec**.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
anijain2305 committed Oct 28, 2024
2 parents 6128b69 + 05cddd3 commit 37fec86
Showing 520 changed files with 7,029 additions and 4,723 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/executorch.txt
@@ -1 +1 @@
ca4783992ed7602a39528ba304d61f00396b2a5a
16b633b4daa7f3d3442be62a3589bd60b2f7fdc7
5 changes: 5 additions & 0 deletions .ci/docker/libtorch/Dockerfile
@@ -66,6 +66,11 @@ RUN bash ./install_cuda.sh 12.4
RUN bash ./install_magma.sh 12.4
RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda

FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
RUN bash ./install_magma.sh 12.6
RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda

FROM cpu as rocm
ARG PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
4 changes: 2 additions & 2 deletions .ci/docker/requirements-ci.txt
@@ -257,7 +257,7 @@ tb-nightly==2.13.0a20230426
#test that import:

# needed by torchgen utils
typing-extensions
typing-extensions>=4.10.0
#Description: type hints for python
#Pinned versions:
#test that import:
@@ -331,7 +331,7 @@ sympy==1.13.1 ; python_version >= "3.9"
#Pinned versions:
#test that import:

onnx==1.16.1
onnx==1.17.0
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/audio.txt
@@ -1 +1 @@
79047bf6bdec9e32c4cffd0f9835b347781fefbf
fa44bdab1fe49bab58389e7b6a33061ffced9bc7
2 changes: 1 addition & 1 deletion .github/workflows/build-libtorch-images.yml
@@ -44,7 +44,7 @@ jobs:
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4", "12.1", "11.8"]
cuda_version: ["12.6", "12.4", "12.1", "11.8"]
env:
GPU_ARCH_TYPE: cuda
GPU_ARCH_VERSION: ${{ matrix.cuda_version }}
2 changes: 1 addition & 1 deletion .github/workflows/build-manywheel-images.yml
@@ -48,7 +48,7 @@ jobs:
runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
strategy:
matrix:
cuda_version: ["12.4", "12.1", "11.8"]
cuda_version: ["12.6", "12.4", "12.1", "11.8"]
env:
GPU_ARCH_TYPE: cuda
GPU_ARCH_VERSION: ${{ matrix.cuda_version }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-cu124.yml
@@ -21,6 +21,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-micro-benchmark-x86.yml
@@ -17,6 +17,7 @@ permissions: read-all

jobs:
linux-jammy-cpu-py3_9-gcc11-inductor-build:
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
name: linux-jammy-cpu-py3.9-gcc11-inductor
uses: ./.github/workflows/_linux-build.yml
with:
1 change: 1 addition & 0 deletions .github/workflows/inductor-micro-benchmark.yml
@@ -19,6 +19,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-perf-compare.yml
@@ -25,6 +25,7 @@ jobs:
get-test-label-type:
name: get-test-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-perf-test-nightly-a10g.yml
@@ -71,6 +71,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-perf-test-nightly-aarch64.yml
@@ -51,6 +51,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-perf-test-nightly-x86.yml
@@ -51,6 +51,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-perf-test-nightly.yml
@@ -69,6 +69,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-periodic.yml
@@ -21,6 +21,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor-rocm.yml
@@ -25,6 +25,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/inductor.yml
@@ -21,6 +21,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/nightly.yml
@@ -20,6 +20,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/periodic.yml
@@ -41,6 +41,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
37 changes: 21 additions & 16 deletions .github/workflows/pull.yml
@@ -38,6 +38,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
@@ -53,10 +54,11 @@ jobs:
docker-image-name: pytorch-linux-jammy-py3.9-gcc11
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -185,10 +187,11 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.9-clang10
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -217,10 +220,11 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.11-clang10
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -251,10 +255,11 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.12-clang10
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
1 change: 1 addition & 0 deletions .github/workflows/rocm.yml
@@ -26,6 +26,7 @@ jobs:
contents: read

linux-focal-rocm6_2-py3_10-build:
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
1 change: 1 addition & 0 deletions .github/workflows/slow.yml
@@ -39,6 +39,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
1 change: 1 addition & 0 deletions .github/workflows/trunk.yml
@@ -37,6 +37,7 @@ jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
@@ -286,6 +286,11 @@ The following packages should be installed with either `conda` or `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
pip install -r requirements
```
will install these dependencies for you.

All PyTorch test suites are located in the `test` folder and start with
`test_`. Run the entire test
2 changes: 1 addition & 1 deletion aten/src/ATen/CMakeLists.txt
@@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
endif()
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})

file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec128/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")
4 changes: 2 additions & 2 deletions aten/src/ATen/Parallel.h
@@ -133,7 +133,7 @@ TORCH_API std::string get_parallel_info();
TORCH_API void set_num_interop_threads(int);

// Returns the number of threads used for inter-op parallelism
TORCH_API int get_num_interop_threads();
TORCH_API size_t get_num_interop_threads();

// Launches inter-op parallel task
TORCH_API void launch(std::function<void()> func);
@@ -142,7 +142,7 @@ void launch_no_thread_state(std::function<void()> fn);
} // namespace internal

// Launches intra-op parallel task
TORCH_API void intraop_launch(std::function<void()> func);
TORCH_API void intraop_launch(const std::function<void()>& func);

// Returns number of intra-op threads used by default
TORCH_API int intraop_default_num_threads();
2 changes: 1 addition & 1 deletion aten/src/ATen/ParallelFuture.h
@@ -8,6 +8,6 @@ namespace at {

// Launches intra-op parallel task, returns a future
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
std::function<void()> func);
const std::function<void()>& func);

} // namespace at
6 changes: 3 additions & 3 deletions aten/src/ATen/ParallelNative.cpp
@@ -273,10 +273,10 @@ bool in_parallel_region() {
#endif // C10_MOBILE
}

void intraop_launch(std::function<void()> func) {
void intraop_launch(const std::function<void()>& func) {
#ifndef C10_MOBILE
if (!in_parallel_region() && get_num_threads() > 1) {
_get_intraop_pool().run(std::move(func));
_get_intraop_pool().run(func);
} else {
// execute inline if we're in parallel region
func();
@@ -289,7 +289,7 @@ void intraop_launch(std::function<void()> func) {
}

c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
std::function<void()> func) {
const std::function<void()>& func) {
#ifndef C10_MOBILE
auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
if (!in_parallel_region() && get_num_threads() > 1) {
9 changes: 5 additions & 4 deletions aten/src/ATen/ParallelOpenMP.cpp
@@ -14,9 +14,10 @@

namespace at {
#if AT_MKLDNN_ENABLED()
namespace native { namespace mkldnn {
namespace native::mkldnn {
// NOLINTNEXTLINE(misc-use-internal-linkage)
void clear_computation_cache();
}} // namespace native::mkldnn
} // namespace native::mkldnn
#endif

namespace {
@@ -100,13 +101,13 @@ bool in_parallel_region() {
#endif
}

void intraop_launch(std::function<void()> func) {
void intraop_launch(const std::function<void()>& func) {
// execute inline in openmp case
func();
}

c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
std::function<void()> func) {
const std::function<void()>& func) {
func();
auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
future->markCompleted();
4 changes: 2 additions & 2 deletions aten/src/ATen/ParallelThreadPoolNative.cpp
@@ -56,7 +56,7 @@ void set_num_interop_threads(int nthreads) {
"has started or set_num_interop_threads called");
}

int get_num_interop_threads() {
size_t get_num_interop_threads() {
at::internal::lazy_init_num_threads();
int nthreads = num_interop_threads.load();
if (nthreads > 0) {
@@ -82,7 +82,7 @@ void launch_no_thread_state(std::function<void()> fn) {
void launch(std::function<void()> func) {
// NOLINTNEXTLINE(modernize-avoid-bind)
internal::launch_no_thread_state(std::bind([](
std::function<void()> f, ThreadLocalState thread_locals) {
const std::function<void()>& f, const ThreadLocalState& thread_locals) {
ThreadLocalStateGuard guard(thread_locals);
f();
},
2 changes: 0 additions & 2 deletions aten/src/ATen/TensorIterator.cpp
@@ -1483,8 +1483,6 @@ FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorCo
return FastSetupType::NONE;
}

TensorIteratorBase::TensorIteratorBase() = default;

void TensorIteratorBase::build(TensorIteratorConfig& config) {
// populate some persistent configuration fields
is_reduction_ = config.is_reduction_;