Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nvFuser linear fusion leads to notebook example timeout #1490

Open
IvanYashchuk opened this issue Nov 29, 2024 · 2 comments
Open

nvFuser linear fusion leads to notebook example timeout #1490

IvanYashchuk opened this issue Nov 29, 2024 · 2 comments
Assignees
Labels

Comments

@IvanYashchuk
Copy link
Collaborator

IvanYashchuk commented Nov 29, 2024

Originally posted by @kiya00 in #1371 (comment)

Env:
pjnl-20241127 and this PR
nvfuser 0.2.23+git4c2ea06

When using 30 or more layers of linear+relu, the nvFusion0 region becomes slow, so the notebook runs out of time
Here is the repro script using Thunder (with nv_enable_linear=True it takes about 219 s; with nv_enable_linear=False it takes about 8 s)

import torch
import thunder

class MySimpleModel(torch.nn.Module):
    """A stack of identically shaped Linear(16, 16) layers, each followed by ReLU.

    Used as a minimal repro for the nvFuser linear-fusion slowdown: the
    number of layers is the knob that triggers the issue.
    """

    def __init__(self, n_layers=10):
        super().__init__()
        # ModuleList (rather than Sequential) so layers are applied explicitly
        # in forward(), matching the original repro's structure.
        layers = [torch.nn.Linear(16, 16) for _ in range(n_layers)]
        self.fcs = torch.nn.ModuleList(layers)

    def forward(self, x):
        """Apply each linear layer followed by ReLU, in order."""
        for layer in self.fcs:
            x = torch.relu(layer(x))
        return x

def get_model_and_args():
    """Build the 30-layer CUDA model and a sample (args, kwargs) pair.

    Returns:
        (model, args, kwargs) ready to be passed as model(*args, **kwargs).
    """
    device = 'cuda'
    net = MySimpleModel(n_layers=30).to(device)
    sample = torch.randn(128, 16, device=device)
    return net, (sample,), {}

# Build the repro model and inputs, then time the first (compile + run) call
# of the thunder-jitted model with nvFuser linear fusion enabled.
model, args, kwargs = get_model_and_args()

# Check against the vanilla `thunder.jit` model
# nv_enable_linear=True routes torch.nn.functional.linear into nvFuser,
# which is the configuration that triggers the slowdown.
jfun = thunder.jit(model, nv_enable_linear=True)
import time
st=time.time()
expected = jfun(*args, **kwargs)
# NOTE(review): this timing includes compilation, not just execution.
print("time:", time.time()-st)

Here is the nvFuser repro script, which I saved from

return fd.execute(args, **kwargs)

by print(fd.repro_script_for(args))

# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git6d4cd3e
# cuda version: 12.8
# nvfuser version: 0.2.23+git4c2ea06
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    """Define the 30-layer linear+relu fusion from the repro.

    Issues the exact same sequence of FusionDefinition API calls as the
    auto-generated repro script: one (128, 16) input tensor, then 30
    (weight, bias) tensor pairs, then per layer a linear -> gt(0) -> where
    chain (relu expressed as a mask + select), and finally the mask and
    relu output of every layer registered as fusion outputs.
    """
    n_layers = 30

    # Activation input: T0 in the generated script.
    x = fd.define_tensor(
        shape=[128, 16], contiguity=[True, True], dtype=DataType.Float,
        is_cpu=False, stride_order=[1, 0],
    )

    # Per-layer parameters: T1..T60 in the generated script, declared
    # before any ops, in (weight, bias) order.
    params = []
    for _ in range(n_layers):
        w = fd.define_tensor(
            shape=[16, 16], contiguity=[True, True], dtype=DataType.Float,
            is_cpu=False, stride_order=[1, 0],
        )
        b = fd.define_tensor(
            shape=[16], contiguity=[True], dtype=DataType.Float,
            is_cpu=False, stride_order=[0],
        )
        params.append((w, b))

    # Per-layer compute: linear, then relu decomposed as gt + where, with a
    # freshly defined zero scalar for each comparison and each select,
    # matching the generated call order exactly.
    outputs = []
    for w, b in params:
        pre = fd.ops.linear(x, w, b)
        zero_cmp = fd.define_scalar(0.00000, dtype=DataType.Double)
        mask = fd.ops.gt(pre, zero_cmp)
        zero_sel = fd.define_scalar(0.00000, dtype=DataType.Double)
        x = fd.ops.where(mask, pre, zero_sel)
        outputs.append(mask)
        outputs.append(x)

    # All outputs are registered after all ops, as in the generated script.
    for t in outputs:
        fd.add_output(t)

# Materialize the fusion: the context manager finalizes the definition so
# fd.execute() can be called below.
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# Fusion inputs, in definition order: the (128, 16) activation first, then
# one (16, 16) weight and one (16,) bias per layer for all 30 layers.
inputs = [
    torch.testing.make_tensor((128, 16), dtype=torch.float32, device='cuda:0'),
]
for _ in range(30):
    inputs.append(
        torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0')
    )
    inputs.append(
        torch.testing.make_tensor((16,), dtype=torch.float32, device='cuda:0')
    )
import time
# Time the fusion execution; on first call this includes segmentation and
# compilation, which is where the reported slowdown occurs.
st=time.time()
fd.execute(inputs)
print(time.time()-st)

cc: @IvanYashchuk @wujingyue

cc @tfogal

@kevinstephano
Copy link
Collaborator

There are a couple of issues happening here:

  1. The segmenter's time increases, for some unknown reason, from 300 ms at 20 layers to 300 s at 30 layers. [The Largest Issue]
  2. The NVRTC compilation is happening more than once for the same activation kernel.

Screenshot 2024-12-02 at 11 19 03

@kevinstephano
Copy link
Collaborator

I created an nvFuser Issue to investigate making segmentation faster.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

3 participants