
DynamicShapeDetector with trie implementation. #7918

Merged — 10 commits merged into master on Sep 4, 2024

Conversation

ysiraichi (Collaborator)

This PR finishes the implementation started in #7817. In this PR, we implement the DynamicShapeDetector using a trie. It is used for detecting different traces (sequences of PyTorch operations called), which we assume are due to dynamic shapes.

How does it work?

  • Wrap a function with a _XLAC._dynamic_shape_detector_start_session call at its beginning and a _XLAC._dynamic_shape_detector_end_session call at its end (see the sketch after this list)
  • When the function is run, the detector starts keeping track of the created IR nodes
  • If, at the N-th call, we trace a sequence of operations different from the previous calls, we increment the number of traces
  • If the number of traces is greater than max_allowed_traces_per_function, we raise an error
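A minimal sketch of that wrapping (not the actual torch_xla.compile code), assuming the _XLAC bindings named above; the exact signatures may differ:

import functools

import torch_xla


def detect_dynamic_shapes(fn, session_id):
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # Start recording this call's trace (the sequence of created IR nodes).
    torch_xla._XLAC._dynamic_shape_detector_start_session(session_id)
    try:
      return fn(*args, **kwargs)
    finally:
      # End the session; exceeding the trace limit raises from inside the
      # traced operations, before we ever get here.
      torch_xla._XLAC._dynamic_shape_detector_end_session()
  return wrapper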

Implementation Details:

  • Build the trie incrementally, with the help of TrieBuilder (the state of the tracing)
  • At every traced operation, we update the TrieBuilder (similar to transitioning states in a DFA)
  • If, at any point, we have to modify the actual TrieNode, it means this is part of a new trace (i.e. a new sequence of operations); a simplified model follows this list
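A simplified, self-contained Python model of the trie-based detection described above (the actual implementation in this PR is C++ in dynamic_shape_detector.cpp and compresses common operation sequences into each node, so the details differ):

class TrieNode:
  def __init__(self):
    self.children = {}  # operation string -> TrieNode


class DynamicShapeDetector:
  def __init__(self, max_allowed_traces):
    self.max_allowed_traces = max_allowed_traces
    self.root = TrieNode()
    self.num_traces = 0
    self.builder = None        # current trie position (DFA-like state)
    self.is_new_trace = False  # did this call diverge from all past calls?

  def start_session(self):
    self.builder = self.root
    self.is_new_trace = False

  def add_node(self, op):
    if op in self.builder.children:
      # Same operation seen in a previous call: follow the existing edge.
      self.builder = self.builder.children[op]
      return
    # We have to modify the trie, so this call is a new trace.
    self.is_new_trace = True
    if self.num_traces + 1 > self.max_allowed_traces:
      raise RuntimeError(
          "Maximum number of different traces allowed per function "
          f"exceeded: {self.max_allowed_traces}")
    child = TrieNode()
    self.builder.children[op] = child
    self.builder = child

  def end_session(self):
    if self.is_new_trace:
      self.num_traces += 1
    self.builder = None

For example, with max_allowed_traces=1, a session that records "aten::add" followed by a session that records "aten::mul" raises the error, since the second call diverges from the only recorded trace.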

cc @miladm @JackCaoG

@ysiraichi (Collaborator, Author)

@JackCaoG When I was writing the tests for this PR, I thought that torch_xla.compile could return a class instance (in a future PR), so that we can track the number of recorded traces for each function. What do you think?
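A rough sketch of that idea (this is not an existing API; the num_recorded_traces query below is hypothetical):

import torch_xla


class CompiledFunction:
  """Hypothetical return value of torch_xla.compile."""

  def __init__(self, fn, session_id):
    self._fn = fn
    self._session_id = session_id

  def __call__(self, *args, **kwargs):
    return self._fn(*args, **kwargs)

  @property
  def num_recorded_traces(self):
    # Hypothetical binding querying the detector for this session.
    return torch_xla._XLAC._dynamic_shape_detector_num_traces(self._session_id)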

// here, so that we can correctly return the builder to the root of the
// trie.
//
// TODO(ysiraichi): we should actually rollback this trace.
ysiraichi (Collaborator, Author)

As is, it should work as we expect. However, a better approach (for a future PR) would be to roll back the changes introduced since the last session start.

Collaborator

If we hit an error, the program will likely just end, so maybe it is fine.

@JackCaoG (Collaborator)

Thanks, I will try to take a look today.

f: Optional[Callable] = None,
full_graph: Optional[bool] = False,
name: Optional[str] = None,
detect_dynamic_shape=False,
Collaborator

I think instead of a boolean value detect_dynamic_shape, we can have it be something like max_dynamic_shape_graph_allowed; then you can map it to _dynamic_shape_detector_set_max_allowed_traces directly.

ysiraichi (Collaborator, Author)

What do you think about allowed_traces (because we can have multiple graphs when tracing a function once)?

@@ -125,6 +129,8 @@ def foo2(x):
elif hasattr(f, '__str__'):
name = f.__str__()

current_id = uuid.uuid4().__str__()
Collaborator

I thought about it a bit, and I find this not very ideal. This way, if we do

def f():
  xxx   

torch_xla.compile(f)
torch_xla.compile(f)

we will get 2 UUIDs. I think we should try to hash the passed-in function pointer so we can dedup.
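One way to implement that dedup (an illustration, not the exact code that landed): derive the id from the function itself instead of a fresh uuid4 per compile call, so compiling the same function twice reuses a single detector session.

import uuid


def make_session_id(f, name=None):
  if f is not None:
    # Same function object -> same id(), so repeated torch_xla.compile(f)
    # calls map to the same session.
    return f"{getattr(f, '__name__', name or 'fn')}_{id(f)}"
  # Decorator usage where the function is not available yet: fall back to a
  # unique id per call.
  return str(uuid.uuid4())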

Collaborator

I think we should use the function pointer if it is not None (currently, when it is used as a decorator with @, the fn will be None, I think).

self._run_and_compare(foo, optfoo, args=(inp1,))

msg = """\
torch_xla/csrc/dynamic_shape_detector.cpp:47 : Maximum number of different traces allowed per function exceeded: 1
Collaborator

cpp:47 is too specific; we can remove the line number and file name from the check to make it more general.
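A self-contained illustration of a more general check: assert only on the stable part of the message and ignore the file name and line number. In the real test, the body of the with-block would be the self._run_and_compare(...) call instead of the explicit raise.

import unittest


class ErrorMessageCheck(unittest.TestCase):

  def test_message_without_file_and_line(self):
    full_message = (
        "torch_xla/csrc/dynamic_shape_detector.cpp:47 : Maximum number of "
        "different traces allowed per function exceeded: 1")
    # Match only the stable suffix, not the file name and line number.
    with self.assertRaisesRegex(
        RuntimeError,
        "Maximum number of different traces allowed per function exceeded: 1"):
      raise RuntimeError(full_message)


if __name__ == "__main__":
  unittest.main()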

ostr << " - " << pair.second->common_sequence_.front().str << std::endl;
}
}

Collaborator

I think we should also dump the current python trace, which would help people figure out what to fix. Check https://github.com/pytorch/xla/blob/master/torch_xla/csrc/debug_util.cpp#L125-L131 for an example.

ysiraichi (Collaborator, Author)

Here's an example running test_trace_limit_exceeded_common_sequence_mismatch:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/usr/local/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/local/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "xla/test/test_dynamic_shapes_detector.py", line 104, in test_trace_limit_exceeded_common_sequence_mismatch
    self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
  File "xla/test/test_dynamic_shapes_detector.py", line 20, in _run_and_compare
    optout = optf(*args)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "xla/test/test_dynamic_shapes_detector.py", line 93, in foo
    return x * 5
RuntimeError: torch_xla/csrc/dynamic_shape_detector.cpp:41 : Maximum number of different traces allowed per function exceeded: 1
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()

We already have the python trace, since this check is being done incrementally, every time a new IR node is created.

Collaborator

Ok, this is nice. I guess we can go a step further: for every trie node, we also store the python stack trace for the current node. This way, when we raise the runtime error, we can also show the python stack we expected versus where we actually hit. It would be easier to debug this way.

We can implement this as a follow-up, though.

@JackCaoG (Collaborator) left a comment

Implementation LGTM; minor comments on the UX.

@ysiraichi (Collaborator, Author)

Summary of the changes:

  • If a function is given, make its current_id dependent on its name and id (see the sketch after this list)
  • Add an API for removing C++ session entries: DynamicShapeDetector::RemoveSessionIfExists
  • Keep track of the compiled functions that are still alive
    • Different local-scoped functions with the same name and id may exist
  • Add allowed_traces optional parameter + documentation
  • Remove file and line information from the expected error messages
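A sketch of the first three bullets (the Python binding for the removal API is hypothetical; the C++ side is DynamicShapeDetector::RemoveSessionIfExists):

import weakref

import torch_xla


def register_session(f, name):
  # current_id depends on the function's name and id(), so compiling the
  # same function object twice maps to a single session.
  current_id = f"{name}_{id(f)}"

  def cleanup():
    # Hypothetical binding for DynamicShapeDetector::RemoveSessionIfExists.
    # Removing the session when the function dies prevents a different
    # local-scoped function that later reuses the same name and id() from
    # inheriting stale traces.
    torch_xla._XLAC._dynamic_shape_detector_remove_session(current_id)

  weakref.finalize(f, cleanup)
  return current_id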

@miladm added the dynamism (Dynamic Shape Features) label on Sep 3, 2024
@ysiraichi (Collaborator, Author)

@JackCaoG This PR is ready for another round of reviews. Could you take a look at it?

f: Optional[Callable] = None,
full_graph: Optional[bool] = False,
name: Optional[str] = None,
allowed_traces: Optional[int] = None,
Collaborator

I think traces is too much of an implementation detail. Since we already have the full_graph above, how about num_different_graph_allowed?

ysiraichi (Collaborator, Author)

Wouldn't that be kind of confusing, leading the user to think of the number of HLO graphs?

Collaborator

It is kind of true, right? Here we are detecting how many different IR graphs are being traced, and it almost always translates to the number of different HLO graphs.

ysiraichi (Collaborator, Author)

Ok. Couldn't think of something better, so went with your suggestion.

ysiraichi (Collaborator, Author)

Hmm. That's not the case if we have a fallback operation in the middle, is it?

Collaborator

Oh, I see. I generally expect the user to also turn on full_graph=True. I guess trace is more correct, but I just find it exposes too many underlying implementation details.

Comment on lines +181 to +183
torch_xla._XLAC._dynamic_shape_detector_set_max_num_different_graphs_allowed(
num_different_graphs_allowed)
torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)
Collaborator

I feel like it is better to register the session with num_different_graphs_allowed outside of _compile; in here we just need to start the session. We can do that in a follow-up.
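A sketch of that follow-up (the registration binding shown here is hypothetical; today the limit is set through _dynamic_shape_detector_set_max_num_different_graphs_allowed):

import torch_xla


def register(current_id, num_different_graphs_allowed):
  # Hypothetical one-time registration that carries the per-session limit,
  # done outside of _compile.
  torch_xla._XLAC._dynamic_shape_detector_register_session(
      current_id, num_different_graphs_allowed)


def on_each_call(current_id):
  # Per-call path inside _compile: just start the session.
  torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)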

ysiraichi merged commit 400bd91 into master on Sep 4, 2024
27 checks passed