Allow linear to be consumed by nvFuser by default #1371

Open
wants to merge 12 commits into
base: main

IvanYashchuk
Collaborator

This change lowers the peak memory usage of LitGPT implementations that use the mlp_class_name="GptNeoxMLP" configuration (#1175, #1233, #246).

| config (mlp_class_name) | Before | This PR |
|---|---|---|
| stablecode-completion-alpha-3b (GptNeoxMLP) | 889.68 ms / 77.02 GB | 892.37 ms / 74.33 GB |
| Llama-2-7b-hf (LLaMAMLP) | 336.06 ms / 64.22 GB | 340.86 ms / 64.18 GB |

Better memory usage comes from simplifying the setup for Thunder's fusion rematerialization. With this change, there are fewer "producer" fusions.
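To put the benchmark figures in relative terms, here is a small sketch that computes the percentage change per configuration. The numbers are copied from the table; the dictionary layout and the `percent_change` helper are only illustrative and are not part of Thunder.

```python
# Figures copied from the benchmark table in this PR description.
results = {
    "stablecode-completion-alpha-3b (GptNeoxMLP)": {
        "before": (889.68, 77.02),   # (time in ms, memory in GB)
        "this_pr": (892.37, 74.33),
    },
    "Llama-2-7b-hf (LLaMAMLP)": {
        "before": (336.06, 64.22),
        "this_pr": (340.86, 64.18),
    },
}


def percent_change(old: float, new: float) -> float:
    """Signed change from old to new, as a percentage of old."""
    return (new - old) / old * 100.0


for name, row in results.items():
    (t0, m0), (t1, m1) = row["before"], row["this_pr"]
    print(f"{name}: time {percent_change(t0, t1):+.2f}%, memory {percent_change(m0, m1):+.2f}%")
```

On these numbers the GptNeoxMLP configuration uses about 3.5% less memory for roughly 0.3% more time, while the LLaMAMLP configuration is nearly unchanged in memory and about 1.4% slower.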

cc @Priya2698, @wujingyue

@wujingyue
Collaborator

There are multiple concerning CI failures:

  1. FAILED thunder/tests/test_nvfuser.py::test_cse_rematerialization_nvfuser_cuda_None - assert 3 == 11
     This is likely due to a behavior change in rematerialization that should also be reflected in the test.
  2. FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like] - RuntimeError: !detect_exception_in_thread_pool.load() INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/kernel_cache.cpp":1234, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.
     Error from segmentation group 4: producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/device_lower/analysis/sync_information.cpp":699, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV30 (T30_l[ iblockIdx.x284{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * 16 ), 128) ), 1) ), 1) )}, iUS285{1}, iS283{1}, ithreadIdx.x281{128} ]) and TV25(T25_l[ iblockIdx.x232{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * ( 1 * 16 ) ) ), 128) ), 1) ), 1) )}, iUS233{1}, iS231{1}, ithreadIdx.x229{128} ] ca_pos( 4 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.x)
     This sounds like a bug in nvFuser.

@Priya2698
Collaborator

We did not run Thunder benchmarks using nvfuser linear. Should we run other benchmarks as well before enabling it by default?

Additionally, @wujingyue needed to remove support for 1D weights to facilitate DID-aware execution. We might have to add a check on the Thunder side or use unsqueeze-squeeze operators to support them. I believe we currently have no cases exercising 1D weights in Thunder, so this should not break anything right now.
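As a rough illustration of the unsqueeze-squeeze idea mentioned above, the sketch below routes a 1D weight through the ordinary 2D-weight path using plain PyTorch on CPU. The helper name `linear_with_1d_weight` is hypothetical, and this is not how Thunder or nvFuser implement it; it only shows that the two shapes of computation agree.

```python
import torch
import torch.nn.functional as F


def linear_with_1d_weight(x, weight, bias=None):
    """Hypothetical helper: express a 1D-weight linear via a 2D weight."""
    if weight.ndim != 1:
        # Regular case: nothing to work around.
        return F.linear(x, weight, bias)
    # (in_features,) -> (1, in_features): a linear with a single output feature.
    weight_2d = weight.unsqueeze(0)
    # A scalar bias becomes a one-element vector to match the single output feature.
    bias_1d = None if bias is None else bias.reshape(1)
    out = F.linear(x, weight_2d, bias_1d)
    # (..., 1) -> (...): drop the dimension introduced by the unsqueeze.
    return out.squeeze(-1)


x = torch.randn(128, 16)
w = torch.randn(16)
b = torch.randn(())
result = linear_with_1d_weight(x, w, b)
print(result.shape)
```

The same pattern could in principle be applied before handing the operation to a fusion executor, at the cost of two extra view operations.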

Collaborator

@crcrpar crcrpar left a comment

orthogonal to this PR but related to nvfuser knobs: are there enable_foobars other than linear and matmul?

@IvanYashchuk
Collaborator Author

are there enable_foobars other than linear and matmul?

Yes, there's also nv_enable_sdpa. I didn't enable matmul in this PR because it leads to excessive memory usage in my tests. I haven't tried nv_enable_sdpa; I think we want to continue using the one from cuDNN for performance.

@wujingyue
Collaborator

wujingyue commented Nov 16, 2024

CI with torch-nightly is now passing with NVIDIA/Fuser#3369 fixed. CI with older versions of torch (and therefore older versions of nvFuser) still fails, because we can't fix a past version. @IvanYashchuk, can you bump nvFuser's version and enable linear only for that?

@IvanYashchuk
Collaborator Author

Yes, I will do that.
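A minimal sketch of what gating the default on the nvFuser version could look like. The threshold below is a placeholder, since the thread does not state which nvFuser release first contains the fix, and both function names are hypothetical rather than Thunder's actual API.

```python
import re

# Placeholder only: the thread does not say which nvFuser release first contains the fix.
MIN_NVFUSER_VERSION_FOR_LINEAR = (0, 2, 23)


def parse_version(version: str) -> tuple:
    """Turn a string such as '0.2.23+git4c2ea06' into (0, 2, 23).

    Everything after '+' (the local build suffix) is ignored.
    """
    release = version.split("+", 1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", release))


def linear_enabled_by_default(nvfuser_version: str) -> bool:
    """Hypothetical check: enable linear only for new enough nvFuser builds."""
    return parse_version(nvfuser_version) >= MIN_NVFUSER_VERSION_FOR_LINEAR


print(linear_enabled_by_default("0.2.23+git4c2ea06"))
print(linear_enabled_by_default("0.2.10"))
```

Comparing integer tuples avoids the pitfall of comparing version strings lexicographically, where "0.2.9" would sort after "0.2.23".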

@wujingyue
Collaborator

I'm unsure about the lightning-thunder (ipynb) test. Is that an infra failure or indeed a regression?

@IvanYashchuk
Collaborator Author

For some reason, there was a timeout in the cell execution. Let's try rerunning the failed job.

@IvanYashchuk
Collaborator Author

@kiya00, could you please check what's going on with the Notebooks CI job?

@kiya00
Collaborator

kiya00 commented Nov 28, 2024

Env:
pjnl-20241127 and this PR
nvfuser 0.2.23+git4c2ea06

With 30 or more layers of linear+relu, nvFusion0 becomes slow, so the notebook times out.
Here is a repro script using Thunder (with nv_enable_linear=True it takes about 219 s; with nv_enable_linear=False it takes about 8 s):

import torch
import thunder

class MySimpleModel(torch.nn.Module):
    def __init__(self, n_layers=10):
        super().__init__()
        self.fcs = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(n_layers)])

    def forward(self, x):
        for fc in self.fcs:
            x = torch.nn.functional.relu(fc(x))
        
        return x

def get_model_and_args():
    device = 'cuda'
    model = MySimpleModel(n_layers=30).to(device)
    args = (torch.randn(128, 16, device=device),)
    kwargs = {}
    return model, args, kwargs

model, args, kwargs = get_model_and_args()

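# Time the first call of the jitted model, which includes tracing and nvFuser compilation
import time
jfun = thunder.jit(model, nv_enable_linear=True)
st = time.perf_counter()
out = jfun(*args, **kwargs)
torch.cuda.synchronize()  # wait for the GPU work to finish before stopping the timer
print("time:", time.perf_counter() - st)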

Here is the nvFuser repro script, which I saved at the line

return fd.execute(args, **kwargs)

by calling print(fd.repro_script_for(args)):

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git6d4cd3e
# cuda version: 12.8
# nvfuser version: 0.2.23+git4c2ea06
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[128, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T9 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T10 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T11 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T12 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T13 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T14 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T15 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T16 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T17 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T18 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T19 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T20 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T21 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T22 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T23 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T24 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T25 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T26 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T27 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T28 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T29 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T30 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T31 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T32 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T33 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T34 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T35 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T36 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T37 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T38 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T39 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T40 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T41 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T42 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T43 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T44 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T45 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T46 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T47 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T48 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T49 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T50 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T51 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T52 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T53 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T54 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T55 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T56 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T57 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T58 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T59 = fd.define_tensor(shape=[16, 16], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T60 = fd.define_tensor(shape=[16], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T61 = fd.ops.linear(T0, T1, T2)
    S62 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T63 = fd.ops.gt(T61, S62)
    S64 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T65 = fd.ops.where(T63, T61, S64)
    T66 = fd.ops.linear(T65, T3, T4)
    S67 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T68 = fd.ops.gt(T66, S67)
    S69 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T70 = fd.ops.where(T68, T66, S69)
    T71 = fd.ops.linear(T70, T5, T6)
    S72 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T73 = fd.ops.gt(T71, S72)
    S74 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T75 = fd.ops.where(T73, T71, S74)
    T76 = fd.ops.linear(T75, T7, T8)
    S77 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T78 = fd.ops.gt(T76, S77)
    S79 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T80 = fd.ops.where(T78, T76, S79)
    T81 = fd.ops.linear(T80, T9, T10)
    S82 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T83 = fd.ops.gt(T81, S82)
    S84 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T85 = fd.ops.where(T83, T81, S84)
    T86 = fd.ops.linear(T85, T11, T12)
    S87 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T88 = fd.ops.gt(T86, S87)
    S89 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T90 = fd.ops.where(T88, T86, S89)
    T91 = fd.ops.linear(T90, T13, T14)
    S92 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T93 = fd.ops.gt(T91, S92)
    S94 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T95 = fd.ops.where(T93, T91, S94)
    T96 = fd.ops.linear(T95, T15, T16)
    S97 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T98 = fd.ops.gt(T96, S97)
    S99 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T100 = fd.ops.where(T98, T96, S99)
    T101 = fd.ops.linear(T100, T17, T18)
    S102 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T103 = fd.ops.gt(T101, S102)
    S104 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T105 = fd.ops.where(T103, T101, S104)
    T106 = fd.ops.linear(T105, T19, T20)
    S107 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T108 = fd.ops.gt(T106, S107)
    S109 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T110 = fd.ops.where(T108, T106, S109)
    T111 = fd.ops.linear(T110, T21, T22)
    S112 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T113 = fd.ops.gt(T111, S112)
    S114 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T115 = fd.ops.where(T113, T111, S114)
    T116 = fd.ops.linear(T115, T23, T24)
    S117 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T118 = fd.ops.gt(T116, S117)
    S119 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T120 = fd.ops.where(T118, T116, S119)
    T121 = fd.ops.linear(T120, T25, T26)
    S122 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T123 = fd.ops.gt(T121, S122)
    S124 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T125 = fd.ops.where(T123, T121, S124)
    T126 = fd.ops.linear(T125, T27, T28)
    S127 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T128 = fd.ops.gt(T126, S127)
    S129 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T130 = fd.ops.where(T128, T126, S129)
    T131 = fd.ops.linear(T130, T29, T30)
    S132 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T133 = fd.ops.gt(T131, S132)
    S134 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T135 = fd.ops.where(T133, T131, S134)
    T136 = fd.ops.linear(T135, T31, T32)
    S137 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T138 = fd.ops.gt(T136, S137)
    S139 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T140 = fd.ops.where(T138, T136, S139)
    T141 = fd.ops.linear(T140, T33, T34)
    S142 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T143 = fd.ops.gt(T141, S142)
    S144 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T145 = fd.ops.where(T143, T141, S144)
    T146 = fd.ops.linear(T145, T35, T36)
    S147 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T148 = fd.ops.gt(T146, S147)
    S149 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T150 = fd.ops.where(T148, T146, S149)
    T151 = fd.ops.linear(T150, T37, T38)
    S152 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T153 = fd.ops.gt(T151, S152)
    S154 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T155 = fd.ops.where(T153, T151, S154)
    T156 = fd.ops.linear(T155, T39, T40)
    S157 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T158 = fd.ops.gt(T156, S157)
    S159 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T160 = fd.ops.where(T158, T156, S159)
    T161 = fd.ops.linear(T160, T41, T42)
    S162 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T163 = fd.ops.gt(T161, S162)
    S164 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T165 = fd.ops.where(T163, T161, S164)
    T166 = fd.ops.linear(T165, T43, T44)
    S167 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T168 = fd.ops.gt(T166, S167)
    S169 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T170 = fd.ops.where(T168, T166, S169)
    T171 = fd.ops.linear(T170, T45, T46)
    S172 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T173 = fd.ops.gt(T171, S172)
    S174 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T175 = fd.ops.where(T173, T171, S174)
    T176 = fd.ops.linear(T175, T47, T48)
    S177 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T178 = fd.ops.gt(T176, S177)
    S179 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T180 = fd.ops.where(T178, T176, S179)
    T181 = fd.ops.linear(T180, T49, T50)
    S182 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T183 = fd.ops.gt(T181, S182)
    S184 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T185 = fd.ops.where(T183, T181, S184)
    T186 = fd.ops.linear(T185, T51, T52)
    S187 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T188 = fd.ops.gt(T186, S187)
    S189 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T190 = fd.ops.where(T188, T186, S189)
    T191 = fd.ops.linear(T190, T53, T54)
    S192 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T193 = fd.ops.gt(T191, S192)
    S194 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T195 = fd.ops.where(T193, T191, S194)
    T196 = fd.ops.linear(T195, T55, T56)
    S197 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T198 = fd.ops.gt(T196, S197)
    S199 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T200 = fd.ops.where(T198, T196, S199)
    T201 = fd.ops.linear(T200, T57, T58)
    S202 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T203 = fd.ops.gt(T201, S202)
    S204 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T205 = fd.ops.where(T203, T201, S204)
    T206 = fd.ops.linear(T205, T59, T60)
    S207 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T208 = fd.ops.gt(T206, S207)
    S209 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T210 = fd.ops.where(T208, T206, S209)
    fd.add_output(T63)
    fd.add_output(T65)
    fd.add_output(T68)
    fd.add_output(T70)
    fd.add_output(T73)
    fd.add_output(T75)
    fd.add_output(T78)
    fd.add_output(T80)
    fd.add_output(T83)
    fd.add_output(T85)
    fd.add_output(T88)
    fd.add_output(T90)
    fd.add_output(T93)
    fd.add_output(T95)
    fd.add_output(T98)
    fd.add_output(T100)
    fd.add_output(T103)
    fd.add_output(T105)
    fd.add_output(T108)
    fd.add_output(T110)
    fd.add_output(T113)
    fd.add_output(T115)
    fd.add_output(T118)
    fd.add_output(T120)
    fd.add_output(T123)
    fd.add_output(T125)
    fd.add_output(T128)
    fd.add_output(T130)
    fd.add_output(T133)
    fd.add_output(T135)
    fd.add_output(T138)
    fd.add_output(T140)
    fd.add_output(T143)
    fd.add_output(T145)
    fd.add_output(T148)
    fd.add_output(T150)
    fd.add_output(T153)
    fd.add_output(T155)
    fd.add_output(T158)
    fd.add_output(T160)
    fd.add_output(T163)
    fd.add_output(T165)
    fd.add_output(T168)
    fd.add_output(T170)
    fd.add_output(T173)
    fd.add_output(T175)
    fd.add_output(T178)
    fd.add_output(T180)
    fd.add_output(T183)
    fd.add_output(T185)
    fd.add_output(T188)
    fd.add_output(T190)
    fd.add_output(T193)
    fd.add_output(T195)
    fd.add_output(T198)
    fd.add_output(T200)
    fd.add_output(T203)
    fd.add_output(T205)
    fd.add_output(T208)
    fd.add_output(T210)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((128, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0'),
]
import time
st=time.time()
fd.execute(inputs)
print(time.time()-st)

cc: @IvanYashchuk @wujingyue

Collaborator

@t-vi t-vi left a comment

Thank you for this PR, but if this means needing to add correctness exceptions to half a dozen models, and the savings on Llama are marginal, we should not do this by default and instead use recipes to enable it where desired.

@mruberry
Collaborator

mruberry commented Dec 3, 2024

Thank you for this PR, but if this means needing to add correctness exceptions to half a dozen models, and the savings on Llama are marginal, we should not do this by default and instead use recipes to enable it where desired.

Which correctness exceptions are you referring to, @t-vi? I agree we can't enable something by default if it causes timeouts in some models, but that's just a speed motivation.

@kiya00 Is there a bug for the slowness issue? @kevinstephano, do you think it's worth investigating?

@kshitij12345, @IvanYashchuk, @jjsjann123, should we consider requiring that nvFuser fusions have enough nodes to be valuable (2+?) but not so many nodes that they might be slow (< 30?)? Maybe we could develop better heuristics in the future.
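The size bounds floated above could be sketched as a simple predicate. The thresholds are the tentative values from this comment (at least 2 nodes, fewer than 30); the function name and the candidate regions are hypothetical and do not correspond to Thunder's partitioner.

```python
def fusion_size_ok(num_nodes: int, min_nodes: int = 2, max_nodes: int = 30) -> bool:
    """Return True if a candidate fusion has at least min_nodes and fewer than max_nodes nodes."""
    return min_nodes <= num_nodes < max_nodes


# Hypothetical candidate fusion regions, each listed by the names of its operations.
candidates = {
    "single_op": ["linear"],
    "small_chain": ["linear", "gt", "where"],
    "long_chain": ["linear", "gt", "where"] * 30,
}
accepted = [name for name, ops in candidates.items() if fusion_size_ok(len(ops))]
print(accepted)  # ['small_chain']
```

A region that fails the check would presumably fall back to another executor or be split into smaller regions; which of the two is preferable is left open here.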

@kiya00
Collaborator

kiya00 commented Dec 3, 2024

@kiya00 Is there a bug for the slowness issue? @kevinstephano, do you think it's worth investigating?

Yes, this is the issue Kevin created for nvfuser #1490 (comment)

@kevinstephano
Collaborator

kevinstephano commented Dec 3, 2024

@kiya00 Is there a bug for the slowness issue? @kevinstephano, do you think it's worth investigating?

While there is a real issue, where 20 to 30 matmuls in a fusion lead to an explosion in segmentation time, I am not sure how often we are going to see so many matmuls in one fusion. It would be interesting to see if this comes up in actual models.
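One low-effort way to check how often this comes up would be to count matmul-like operations in saved nvFuser repro scripts, such as the one printed by fd.repro_script_for earlier in this thread. The sketch below assumes that counting `fd.ops.linear(` and `fd.ops.matmul(` calls in the script text is a good enough proxy; the helper name is hypothetical.

```python
import re

# Matches calls such as "fd.ops.linear(" or "fd.ops.matmul(" in a repro script.
MATMUL_LIKE_OP = re.compile(r"\bfd\.ops\.(?:linear|matmul)\(")


def count_matmul_like_ops(repro_script: str) -> int:
    """Count linear/matmul calls in the text of an nvFuser repro script."""
    return len(MATMUL_LIKE_OP.findall(repro_script))


# A short excerpt in the style of the repro script above.
sample = """
T61 = fd.ops.linear(T0, T1, T2)
T63 = fd.ops.gt(T61, S62)
T65 = fd.ops.where(T63, T61, S64)
T66 = fd.ops.linear(T65, T3, T4)
"""
print(count_matmul_like_ops(sample))  # 2
```

Running this over repro scripts collected from real models would show whether fusions with 20 or more such operations occur in practice.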

9 participants