Skip to content

Commit

Permalink
Enable segmentation support in user-scheduling
Browse files Browse the repository at this point in the history
Create _execute_segments to orchestrate segments in original fusion.

Add supports_segmentation flag; Enabled by default
  • Loading branch information
rdspring1 committed Nov 3, 2024
1 parent e0a4538 commit f278350
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 7 deletions.
3 changes: 3 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,9 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of segmentation
inst::Trace::instance()->endEvent(nullptr);
})
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
.def(
"__repr__",
[](FusionDefinition& self) {
Expand Down
79 changes: 79 additions & 0 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,82 @@ def __exit__(self, type, value, traceback):
def definition(self):
raise NotImplementedError("definition() should be implemented by child class!")

def _execute_segments(self, input_arguments, *, device=None, profile=False):
    """
    Run the sequence of FusionDefinition segments to generate the results
    of this FusionDefinition.

    This FusionDefinition acts as an argument manager. It gathers input
    arguments for the segments and stores their output results. After
    running a segment, any redundant intermediate values, which are
    unnecessary for any other segments, are deleted to save memory.

    Args:
        input_arguments (List[Union[Tensor, Scalar]]): A list of inputs to
            the fusion.

    Kwargs:
        device (Optional[Union[int, str, torch.device]]): This is a hint to run
            the Fusion on the given CUDA device. This is not typically
            necessary, as the device is usually inferred from the locations
            of input tensors. However, for some fusion definitions, no
            tensors will be input (for example when all tensors are
            generated with `full` or `uniform` ops). In these cases, we
            must either tell NVFuser where to run the resulting kernel, or
            let it default to 0. Note that passing this option while
            providing input tensors that lie on another device is an error.
            The value is forwarded unchanged to every segment.
        profile (bool): Captures a CUPTI based profile of a fusion. The
            value is forwarded unchanged to every segment.

    Returns:
        List[Tensor]: The output results for this FusionDefinition.
    """
    # Internal invariants: segmentation must have produced one state map
    # per segment before this method is called.
    assert len(self.segments) > 0, "No segments to execute."
    assert len(self.segments) == len(
        self.segment_maps
    ), "Each segment requires a segment-to-original state map."

    # The extents of every tensor argument are appended after the regular
    # arguments, mirroring the `self.inputs() + self.extents()` key order.
    # isinstance (rather than an exact type check) keeps Tensor subclasses
    # such as torch.nn.Parameter from silently dropping their extents,
    # which would misalign every following extent in the zip below.
    input_arguments_with_extents = [*input_arguments]
    for a in input_arguments:
        if isinstance(a, torch.Tensor):
            input_arguments_with_extents.extend(a.size())

    # Map inputs arguments to original fid
    map_original_fid_to_value = dict(
        zip(self.inputs() + self.extents(), input_arguments_with_extents)
    )

    # Queried once; reused for the liveness check and the final gather.
    original_outputs = self.outputs()

    # Run all segments in correct order
    for idx, (segment, segment_to_original_map) in enumerate(
        zip(self.segments, self.segment_maps)
    ):
        segment_inputs = segment.inputs()

        # Gather segment input arguments
        segment_arguments = [
            map_original_fid_to_value[segment_to_original_map[fd_state]]
            for fd_state in segment_inputs
        ]

        # Run segment
        segment_outputs = segment.execute(
            segment_arguments, device=device, profile=profile
        )

        # Update original fusion definition indices to outputs
        for fd_state, output in zip(segment.outputs(), segment_outputs):
            map_original_fid_to_value[segment_to_original_map[fd_state]] = output

        # Destroy any arguments that are not used by future segments
        for segment_input in segment_inputs:
            original_input = segment_to_original_map[segment_input]
            if (
                original_input not in original_outputs
                and self.last_used_segment[original_input] == idx
            ):
                # pop() instead of del: several segment inputs may map to
                # the same original state, which is then already released.
                map_original_fid_to_value.pop(original_input, None)

    # Map output fid to actual results
    return [map_original_fid_to_value[fd_state] for fd_state in original_outputs]

def execute(
self,
inputs,
Expand Down Expand Up @@ -212,6 +288,9 @@ def execute(
fake_mode = FakeTensorMode()
self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

if hasattr(self, "segments") and len(self.segments) > 0:
return self._execute_segments(inputs, device=device, profile=profile)

results = None
try:
if print_repro:
Expand Down
10 changes: 5 additions & 5 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3211,7 +3211,7 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T54)
fd.add_output(T30)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)
# self.assertEqual(nvf_out[0], t24)

# Test that symbolic IterDomains can be concatenated
Expand Down Expand Up @@ -3743,7 +3743,7 @@ def fusion_func(fd: FusionDefinition) -> None:
fd.add_output(T57)
fd.add_output(T101)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# A simple pointwise fusion, but passed misaligned input
def test_misaligned_add(self):
Expand Down Expand Up @@ -3909,7 +3909,7 @@ def fusion_func(fd: FusionDefinition) -> None:

fd.add_output(T88)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2275
@pytest.mark.skipif(
Expand Down Expand Up @@ -3955,7 +3955,7 @@ def fusion_func(fd: FusionDefinition) -> None:
T101 = fd.ops.cat([T7, T100], dim=-1)
fd.add_output(T101)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)

# See https://github.com/NVIDIA/Fuser/issues/2317
@pytest.mark.skipif(
Expand Down Expand Up @@ -4736,4 +4736,4 @@ def fusion_func(fd: FusionDefinition) -> None:
T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
fd.add_output(T223)

nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False)
19 changes: 17 additions & 2 deletions tests/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,24 @@ def check_captured_python_definition(reference_outputs, fd, inputs, device=None)

# Run original FusionDefinition
# Clone FusionDefinition
# Apply segmentation if it is supported for this FusionDefinition
# Run cloned python definition
# Check that the result of cloned python definition matches original results
def check_cpp_translation(reference_outputs, fd, inputs, device=None):
def check_cpp_translation(
reference_outputs, fd, inputs, supports_segmentation, device=None
):
try:
torch.manual_seed(0)

# Clone
cloned_fd = FusionDefinition()
clone(fd, cloned_fd)

# Segment
if supports_segmentation:
cloned_fd.segment(inputs)

# Run
cloned_outputs = cloned_fd.execute(inputs, device=device)

# Make sure the results of original and cloned definitions match.
Expand All @@ -268,6 +279,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None):
print(
"(A failure here suggests a mismatch in functionality between the original and cloned definitions.)"
)
print("Does FusionDefinition supports segmentation?\t", supports_segmentation)
print(fd.getReproErrorString("executing", inputs))
raise err

Expand Down Expand Up @@ -419,6 +431,7 @@ def exec_nvfuser(
new_fusion_expected=True,
device=None,
is_clonable=True,
supports_segmentation=True,
):
fc = FusionCache.get()
before_fusions = fc.num_fusions()
Expand All @@ -441,5 +454,7 @@ def exec_nvfuser(
self.assertEqual(fc.num_fusions() - before_fusions, int(new_fusion_expected))

if is_clonable:
self.assertTrue(check_cpp_translation(out, fd, inputs_cloned))
self.assertTrue(
check_cpp_translation(out, fd, inputs_cloned, supports_segmentation)
)
return out, fd

0 comments on commit f278350

Please sign in to comment.