DynamicShapeDetector with trie implementation #7918
Conversation
@JackCaoG When I was writing the tests for this PR, I thought that

// here, so that we can correctly return the builder to the root of the
// trie.
//
// TODO(ysiraichi): we should actually rollback this trace.
As is, it should work as we expect. However, a better approach (for a future PR) would be to roll back the changes introduced since the last session start.
If we hit an error, the program will likely just end, so maybe it is fine.
Thanks, I will try to take a look today.
torch_xla/torch_xla.py
Outdated
f: Optional[Callable] = None,
full_graph: Optional[bool] = False,
name: Optional[str] = None,
detect_dynamic_shape=False,
I think instead of a boolean value detect_dynamic_shape, we could have something like max_dynamic_shape_graph_allowed; then you can map it to _dynamic_shape_detector_set_max_allowed_traces directly.
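
A minimal sketch of that suggestion (illustrative only: max_dynamic_shape_graph_allowed is the proposed name here, not an existing parameter):

from typing import Callable, Optional
import torch_xla

def compile(f: Optional[Callable] = None,
            full_graph: Optional[bool] = False,
            name: Optional[str] = None,
            max_dynamic_shape_graph_allowed: Optional[int] = None):
  if max_dynamic_shape_graph_allowed is not None:
    # Map the user-facing count directly onto the detector binding.
    torch_xla._XLAC._dynamic_shape_detector_set_max_allowed_traces(
        max_dynamic_shape_graph_allowed)
  ...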
What do you think about allowed_traces (because we can have multiple graphs when tracing a function once)?
torch_xla/torch_xla.py
Outdated
@@ -125,6 +129,8 @@ def foo2(x):
  elif hasattr(f, '__str__'):
    name = f.__str__()

  current_id = uuid.uuid4().__str__()
I thought about it a bit and found this is not ideal. This way, if we do

def f():
  xxx

torch.compile(f)
torch.compile(f)

we will get 2 UUIDs. I think we should try to hash the passed-in function pointer so we can dedup.
I think we should use the function pointer if it is not None (currently, when it is used as a decorator with @, the fn will be None, I think).
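
A sketch of that dedup idea (illustrative, not the code in this PR):

import uuid

# Use the function object's identity as a stable id when we have one,
# so compiling the same f twice reuses a single detector session; fall
# back to a fresh UUID when f is None (e.g. decorator usage).
current_id = str(id(f)) if f is not None else str(uuid.uuid4())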
test/test_dynamic_shapes_detector.py
Outdated
self._run_and_compare(foo, optfoo, args=(inp1,))

msg = """\
torch_xla/csrc/dynamic_shape_detector.cpp:47 : Maximum number of different traces allowed per function exceeded: 1
cpp:47 is too specific; we can remove the line number and file name from the check to make it more general.
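
For instance, the assertion could match only the stable part of the message (a sketch; the triggering call is hypothetical):

import re

# Drop the file name and line number so the test does not break when
# the C++ source shifts.
msg = "Maximum number of different traces allowed per function exceeded: 1"
with self.assertRaisesRegex(RuntimeError, re.escape(msg)):
  optfoo(inp2)  # hypothetical call that triggers one trace too many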
ostr << " - " << pair.second->common_sequence_.front().str << std::endl;
}
}

I think we should also dump the current Python trace, which would help people figure out where to fix. Check https://github.com/pytorch/xla/blob/master/torch_xla/csrc/debug_util.cpp#L125-L131 for an example.
Here's an example running test_trace_limit_exceeded_common_sequence_mismatch:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
yield
File "/usr/local/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/usr/local/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "xla/test/test_dynamic_shapes_detector.py", line 104, in test_trace_limit_exceeded_common_sequence_mismatch
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
File "xla/test/test_dynamic_shapes_detector.py", line 20, in _run_and_compare
optout = optf(*args)
File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "xla/test/test_dynamic_shapes_detector.py", line 93, in foo
return x * 5
RuntimeError: torch_xla/csrc/dynamic_shape_detector.cpp:41 : Maximum number of different traces allowed per function exceeded: 1
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
We already have the Python trace, since this check is done incrementally, every time a new IR node is created.
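
As a rough illustration of that incremental check, here is a Python sketch of the trie mechanism (the real detector is implemented in C++ in dynamic_shape_detector.cpp, and its details differ):

class TrieNode:
  def __init__(self):
    self.children = {}  # op string -> TrieNode

class DynamicShapeDetector:
  def __init__(self, max_allowed_traces):
    self.root = TrieNode()
    self.max_allowed_traces = max_allowed_traces
    self.num_traces = 0
    self.builder = self.root
    self.on_new_branch = False

  def start_session(self):
    # Every trace of the function walks the trie from the root.
    self.builder = self.root
    self.on_new_branch = False

  def on_new_ir_node(self, op_str):
    # Called incrementally, each time a new IR node is created.
    child = self.builder.children.get(op_str)
    if child is None:
      if not self.on_new_branch:
        # First divergence from all known traces: a new trace begins.
        self.num_traces += 1
        if self.num_traces > self.max_allowed_traces:
          raise RuntimeError(
              "Maximum number of different traces allowed per "
              f"function exceeded: {self.max_allowed_traces}")
        self.on_new_branch = True
      child = TrieNode()
      self.builder.children[op_str] = child
    self.builder = child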
OK, this is nice. I guess we can go a step further: for every trie node, we also store the Python stack trace for that node. This way, when we raise the runtime error, we can also show the Python stack we expected versus where we actually hit. It would be easier to debug this way.
We can implement this as a follow-up, though.
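
A sketch of that follow-up, extending the TrieNode from the sketch above (again, the actual trie lives in C++):

import traceback

class TrieNode:
  def __init__(self):
    self.children = {}
    # Stack captured when this node was first created, so the error
    # could show the expected location alongside the diverging one.
    self.creation_stack = "".join(traceback.format_stack())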
Implementation LGTM; minor comments on the UX.
Summary of the changes:
@JackCaoG This PR is ready for another round of reviews. Could you take a look at it?
torch_xla/torch_xla.py
Outdated
f: Optional[Callable] = None,
full_graph: Optional[bool] = False,
name: Optional[str] = None,
allowed_traces: Optional[int] = None,
I think traces is too much of an implementation detail. Since we already have full_graph above, how about num_different_graph_allowed?
Wouldn't that be kind of confusing, leading the user to think of the number of HLO graphs?
It is kind of true, right? Here we are detecting how many different IR graphs are being traced, which almost always translates to the number of different HLO graphs.
OK. I couldn't think of anything better, so I went with your suggestion.
Hmm. That's not the case if we have a fallback operation in the middle, is it?
Oh, I see. I generally expect the user to also set full_graph to True. I guess trace is more correct, but I find it exposes too much of the underlying implementation details.
torch_xla._XLAC._dynamic_shape_detector_set_max_num_different_graphs_allowed(
    num_different_graphs_allowed)
torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)
I feel like it is better to register the session with num_different_graphs_allowed outside of _compile; in here, we just need to start the session. We can do that in a follow-up.
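
A sketch of that suggested split (note: _dynamic_shape_detector_register_session below is hypothetical; only the two bindings shown in the diff above exist in this PR):

# One-time registration, done outside of _compile.
torch_xla._XLAC._dynamic_shape_detector_register_session(
    current_id, num_different_graphs_allowed)

# Inside _compile, we then only start (and later end) the session.
torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)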
This PR finishes the implementation started in #7817. In this PR, we implement the DynamicShapeDetector using a trie. It is used for detecting different traces (sequences of PyTorch operations called), which we assume are due to dynamic shapes.

How does it work?
- A _XLAC._dynamic_shape_detector_start_session call in the beginning, and a _XLAC._dynamic_shape_detector_end_session call in the end of a function
- If the number of different traces exceeds max_allowed_traces_per_function, we raise an error

Implementation Details:
- TrieBuilder (state of tracing)
- TrieBuilder (similar to states of a DFA)
- Creating a new TrieNode means this is part of a new trace (i.e. sequence of operations)

cc @miladm @JackCaoG