diff --git a/test/run_tests.sh b/test/run_tests.sh index 142ca53f2ae2..409cdd6dd070 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -228,6 +228,7 @@ function run_xla_op_tests2 { run_test "$CDIR/test_jax_interop.py" run_test "$CDIR/test_assume_pure.py" run_test "$CDIR/test_assume_pure_spmd.py" + TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=dynamic_shape_detector=5 run_test "$CDIR/test_dynamic_shapes_detector.py" -v } # All the new xla op tests should go to run_xla_op_tests3 diff --git a/test/test_dynamic_shapes_detector.py b/test/test_dynamic_shapes_detector.py index be2a6b4042da..99a8423e45f5 100644 --- a/test/test_dynamic_shapes_detector.py +++ b/test/test_dynamic_shapes_detector.py @@ -1,19 +1,42 @@ +import re +import textwrap import torch import torch_xla import test_utils import unittest +# Processes a string, so that it can be used as the expected error regex. +# Specifically, it does 3 things: +# +# 1. s[1:]: assumes the first character of the string is a new-line, and +# removes it. +# +# 2. textwrap.dedent(): strips the leading space in the string, allowing us +# to write more readable multi-line strings. +# +# 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis, +# brackets, and braces, so as to allow us to write more +# readable strings. +# +# Note that because of (3), we lose the "regex" part, not being able to use +# regex wildcards, such as "*". +ESCAPE_RE = re.compile(r"([\[\](){}])") + + +def escape(s): + return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:])) + class TestDynamicShapeDetector(test_utils.XlaTestCase): - def _run_and_compare(self, f, args=None, allowed_traces=None): + def _run_and_compare(self, f, args=None, max_different_graphs=None): """Run f and its torch_xla.compile wrapped version, comparing the equality of their results. If no optf is provided, we create a new one by wrapping it with torch_xla.compile ourselves. """ - optf = torch_xla.compile(f, allowed_traces=allowed_traces) + optf = torch_xla.compile(f, max_different_graphs=max_different_graphs) args = args or [] out = f(*args) @@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None): self.assertEqual(out, optout) def test_single(self): - # Test: trace a function once, when only one trace is allowed. + # Test: trace a function once, when only one graph is allowed. def foo(x): return x + x inp = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp,), allowed_traces=1) + self._run_and_compare(foo, args=(inp,), max_different_graphs=1) - def test_many_traces(self): - # Test: multiple traces of a function. + def test_many_graphs(self): + # Test: multiple graphs of a function. # - # Steps 0~2 and 5: create new traces. + # Steps 0~2 and 5: create new graphs. # Steps 3 and 4: ensure we have already traced these paths. def foo(x, step): @@ -50,41 +73,44 @@ def foo(x, step): inp = torch.rand(10, device=torch_xla.device()) for i in range(6): - self._run_and_compare(foo, args=(inp, i), allowed_traces=4) + self._run_and_compare(foo, args=(inp, i), max_different_graphs=4) - def test_trace_limit_exceeded_different_input_shape(self): - # Test: catch trace limit exceeded error when running the function with a + def test_graph_limit_exceeded_different_input_shape(self): + # Test: catch graph limit exceeded error when running the function with a # function with different shape. - allowed_traces = 1 + max_different_graphs = 1 def foo(x): return x + x inp1 = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp1,), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp1,), max_different_graphs=max_different_graphs) - msg = """\ -.* Maximum number of different traces allowed per function exceeded: 1 -Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10) -Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()""" + expected_error_msg = escape(r""" + Maximum number of different graphs allowed per function exceeded: 1 + Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: () + Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + """) - with self.assertRaises(RuntimeError, msg=msg): + with self.assertRaisesRegex(RuntimeError, expected_error_msg): inp2 = torch.rand(5, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp2,), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp2,), max_different_graphs=max_different_graphs) - def test_trace_limit_exceeded_common_sequence_mismatch(self): - # Test: catch trace limit exceeded error when the common sequence (i.e. compressed + def test_graph_limit_exceeded_common_sequence_mismatch(self): + # Test: catch graph limit exceeded error when the common sequence (i.e. compressed # path) of the trie node mismatches. # - # Step 0: creates a trace with one node containing the add operation + # Step 0: creates a graph with one node containing the add operation # # Step 1: tries to create 2 child nodes with: - # (i) add operation (previous trace); and + # (i) add operation (previous graph); and # (ii) mul operation. # However, it fails since we have reached the limit. - allowed_traces = 1 + max_different_graphs = 1 def foo(x, step): if step == 0: @@ -93,24 +119,27 @@ def foo(x, step): return x * 5 inp = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp, 0), max_different_graphs=max_different_graphs) - msg = """\ -.* Maximum number of different traces allowed per function exceeded: 1 -Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () -Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()""" + expected_error_msg = escape(r""" + Maximum number of different graphs allowed per function exceeded: 1 + Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () + Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + """) - with self.assertRaises(RuntimeError, msg=msg): - self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces) + with self.assertRaisesRegex(RuntimeError, expected_error_msg): + self._run_and_compare( + foo, args=(inp, 2), max_different_graphs=max_different_graphs) - def test_trace_limit_exceeded_children_mismatch(self): - # Test: catch trace limit exceeded error when the expected child of the trie + def test_graph_limit_exceeded_children_mismatch(self): + # Test: catch graph limit exceeded error when the expected child of the trie # node mismatches. # - # Step 0: creates a trace with one node containing 3 operations, the last + # Step 0: creates a graph with one node containing 3 operations, the last # being a mul operation. # - # Step 1: creates another trace by splitting the node, creating 2 other child + # Step 1: creates another graph by splitting the node, creating 2 other child # nodes containing the different operations in the end: # (i) mul operation; and # (ii) add operation. @@ -118,7 +147,7 @@ def test_trace_limit_exceeded_children_mismatch(self): # Step 2: tries to create a 3rd child node: div operation. However, we can't # do it, since we have reached the limit. - allowed_traces = 2 + max_different_graphs = 2 def foo(x, step): r = x + x @@ -129,30 +158,34 @@ def foo(x, step): return r / 3 inp = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces) - self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces) - - msg = """\ -.* Maximum number of different traces allowed per function exceeded: 2 -Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10) -Expected either of: - - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () - - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()""" - - with self.assertRaises(RuntimeError, msg=msg): - self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp, 0), max_different_graphs=max_different_graphs) + self._run_and_compare( + foo, args=(inp, 1), max_different_graphs=max_different_graphs) + + expected_error_msg = escape(r""" + Maximum number of different graphs allowed per function exceeded: 2 + Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: () + Expected either of: + - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () + - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + """) + + with self.assertRaisesRegex(RuntimeError, expected_error_msg): + self._run_and_compare( + foo, args=(inp, 2), max_different_graphs=max_different_graphs) - def test_trace_limit_exceeded_common_sequence_early_stop(self): - # Test: catch trace limit exceeded error when the trace ends unexpectedly in + def test_graph_limit_exceeded_common_sequence_early_stop(self): + # Test: catch graph limit exceeded error when the graph ends unexpectedly in # the common sequence. # - # Step 0: creates a trace with one node containing 3 operations. + # Step 0: creates a graph with one node containing 3 operations. # - # Step 1: at the end of this trace, it tries to create a new node containing - # the remaining operations of the previous trace, i.e. mul operation. However, + # Step 1: at the end of this graph, it tries to create a new node containing + # the remaining operations of the previous graph, i.e. mul operation. However, # it fails because we have reached the limit. - allowed_traces = 1 + max_different_graphs = 1 def foo(x, mul=False): r = x + x @@ -162,31 +195,33 @@ def foo(x, mul=False): return r inp = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp, True), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp, True), max_different_graphs=max_different_graphs) - msg = """\ -.* Maximum number of different traces allowed per function exceeded: 1 -Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () -Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()""" + expected_error_msg = escape(r""" + Maximum number of different graphs allowed per function exceeded: 1 + Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () + """) - with self.assertRaises(RuntimeError, msg=msg): + with self.assertRaisesRegex(RuntimeError, expected_error_msg): self._run_and_compare( - foo, args=(inp, False), allowed_traces=allowed_traces) + foo, args=(inp, False), max_different_graphs=max_different_graphs) - def test_trace_limit_exceeded_children_early_stop(self): - # Test: catch trace limit exceeded error when the trace ends unexpectedly at + def test_graph_limit_exceeded_children_early_stop(self): + # Test: catch graph limit exceeded error when the graph ends unexpectedly at # a fork point (i.e. next operation would jump to anothe trie node). # - # Step 0: creates a trace with one node containing 3 operations. + # Step 0: creates a graph with one node containing 3 operations. # # Step 1: splits the node, creating 2 child nodes containing: - # (i) the differring operations from the last trace, i.e. mul operation + # (i) the differring operations from the last graph, i.e. mul operation # (ii) the current last operation, i.e. add operation # - # Step 3: at the end of this trace, it tries to turn the current trie node - # into a new trace. However, it fails since we have reached the limit. + # Step 3: at the end of this graph, it tries to turn the current trie node + # into a new graph. However, it fails since we have reached the limit. - allowed_traces = 2 + max_different_graphs = 2 def foo(x, step): r = x + x @@ -197,18 +232,22 @@ def foo(x, step): return r inp = torch.rand(10, device=torch_xla.device()) - self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces) - self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces) - - msg = """\ -.* Maximum number of different traces allowed per function exceeded: 2 -Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () -Expected either of: - - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () - - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()""" - - with self.assertRaises(RuntimeError, msg=msg): - self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces) + self._run_and_compare( + foo, args=(inp, 0), max_different_graphs=max_different_graphs) + self._run_and_compare( + foo, args=(inp, 1), max_different_graphs=max_different_graphs) + + expected_error_msg = escape(r""" + Maximum number of different graphs allowed per function exceeded: 2 + Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + Expected either of: + - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: () + - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: () + """) + + with self.assertRaisesRegex(RuntimeError, expected_error_msg): + self._run_and_compare( + foo, args=(inp, 2), max_different_graphs=max_different_graphs) if __name__ == "__main__": diff --git a/torch_xla/csrc/dynamic_shape_detector.cpp b/torch_xla/csrc/dynamic_shape_detector.cpp index 1a7518d7a5e8..cc4526b2ce36 100644 --- a/torch_xla/csrc/dynamic_shape_detector.cpp +++ b/torch_xla/csrc/dynamic_shape_detector.cpp @@ -6,21 +6,27 @@ namespace torch_xla { -// Maximum number of allowed traces per function (i.e. session). -static std::size_t max_allowed_traces_per_function = 1; +// Maximum number of allowed graphs per function (i.e. session). +static std::size_t max_different_graphs = 1; + +TrieNode::TrieNode(const TrieValue& value, bool is_graph_boundary) + : TrieNode() { + common_sequence_.push_back(value); + is_graph_boundary_ = is_graph_boundary; +} TrieNode::TrieNode(absl::Span common_sequence, - bool is_trace_boundary) + bool is_graph_boundary) : common_sequence_(common_sequence.begin(), common_sequence.end()), - is_trace_boundary_(is_trace_boundary) {} + is_graph_boundary_(is_graph_boundary) {} bool TrieNode::IsLeaf() const { return children_.empty(); } -void TrieNode::NewTraceNotAllowedError(std::optional value, +void TrieNode::NewGraphNotAllowedError(std::optional value, std::size_t matched) { std::ostringstream ostr; - ostr << "Maximum number of different traces allowed per function exceeded: " - << max_allowed_traces_per_function << std::endl; + ostr << "Maximum number of different graphs allowed per function exceeded: " + << max_different_graphs << std::endl; if (value.has_value()) { ostr << "Got: " << value->str << std::endl; @@ -41,22 +47,22 @@ void TrieNode::NewTraceNotAllowedError(std::optional value, XLA_ERROR() << ostr.str(); } -bool TrieNode::MarkTraceBoundary(std::size_t matched, bool allow_new_trace) { +bool TrieNode::MarkGraphBoundary(std::size_t matched, bool allow_new_graph) { // No need to do anything here, iff: // // 1. nothing was matched, yet // // 2. we matched everything in this node, and this node is already marked as - // a trace boundary. + // a graph boundary. if (matched == 0 || - (common_sequence_.size() == matched && is_trace_boundary_)) { + (common_sequence_.size() == matched && is_graph_boundary_)) { return false; } - // From this point, we will create a new trace. - if (!allow_new_trace) { - // Raise an error if we have reached the maximum number of traces. - NewTraceNotAllowedError(std::nullopt, matched); + // From this point, we will create a new graph. + if (!allow_new_graph) { + // Raise an error if we have reached the maximum number of graphs. + NewGraphNotAllowedError(std::nullopt, matched); } // If we haven't matched everything in this node, we will have to split this @@ -67,20 +73,20 @@ bool TrieNode::MarkTraceBoundary(std::size_t matched, bool allow_new_trace) { MaybeSplitAt(matched); } - // Finally, mark this node as a trace boundary. - is_trace_boundary_ = true; + // Finally, mark this node as a graph boundary. + is_graph_boundary_ = true; return true; } TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched, - bool allow_new_trace) { + bool allow_new_graph) { TF_VLOG(5) << "Adding value: " << value.str << " (" << value.hash << ")"; - // If this node has no children and is not marked as a trace boundary, it + // If this node has no children and is not marked as a graph boundary, it // means that TrieBuilder created this node and is incrementally adding // TrieValue to it. Therefore, we just need to keep doing it. - if (IsLeaf() && !is_trace_boundary_) { + if (IsLeaf() && !is_graph_boundary_) { common_sequence_.push_back(value); return {this, matched + 1}; } @@ -101,10 +107,10 @@ TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched, return {children_[value.hash].get(), 1}; } - // Otherwise, we will have to create a new trace. So, first, check whether we + // Otherwise, we will have to create a new graph. So, first, check whether we // are allowed to do so. - if (!allow_new_trace) { - NewTraceNotAllowedError(value, matched); + if (!allow_new_graph) { + NewGraphNotAllowedError(value, matched); } // Maybe split the current node into: prefix (before matched) and suffix @@ -112,8 +118,7 @@ TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched, bool did_split = MaybeSplitAt(matched); // Create a new node that contains only the given value. - std::unique_ptr node = - std::make_unique(absl::Span{value}); + std::unique_ptr node = std::make_unique(value); // Associate the given value with the created node in the children's map. children_[value.hash] = std::move(node); @@ -121,11 +126,11 @@ TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched, TF_VLOG(5) << "Created new node " << children_[value.hash].get() << " for value: " << value.str << " (" << value.hash << ")"; - // Unmark this node as trace boundary iff we actually split this node (i.e. - // suffix actually had something). Otherwise, this should still be a trace + // Unmark this node as graph boundary iff we actually split this node (i.e. + // suffix actually had something). Otherwise, this should still be a graph // boundary. if (did_split) { - is_trace_boundary_ = false; + is_graph_boundary_ = false; } return {children_[value.hash].get(), 1}; @@ -143,7 +148,7 @@ bool TrieNode::MaybeSplitAt(std::size_t matched) { // A split only occurs if suffix is not empty. if (!suffix.empty()) { std::unique_ptr suffix_node = - std::make_unique(suffix, is_trace_boundary_); + std::make_unique(suffix, is_graph_boundary_); // The suffix node's children should be what this node's children was before // the split. Therefore, we swap those. @@ -177,38 +182,38 @@ void DynamicShapeDetector::StartSession(const std::string& name) { RootBuilder(); } -void DynamicShapeDetector::SetMaxAllowedTraces(std::size_t value) { - max_allowed_traces_per_function = value; +void DynamicShapeDetector::SetMaxDifferentGraphs(std::size_t value) { + max_different_graphs = value; } -std::size_t DynamicShapeDetector::GetMaxAllowedTraces() { - return max_allowed_traces_per_function; +std::size_t DynamicShapeDetector::GetMaxDifferentGraphs() { + return max_different_graphs; } bool DynamicShapeDetector::IsSessionActive() { return current_session_ != nullptr; } -bool DynamicShapeDetector::AllowNewTrace() { +bool DynamicShapeDetector::AllowNewGraph() { XLA_CHECK(IsSessionActive()); - return current_session_->traces_ < max_allowed_traces_per_function; + return current_session_->graphs_ < max_different_graphs; } void DynamicShapeDetector::EndSession() { XLA_CHECK(IsSessionActive()); try { - // Mark the current builder_ node as trace boundary. - // If we did create a new trace, increment the session's trace number. - if (builder_.MarkTraceBoundary(AllowNewTrace())) { - current_session_->traces_++; - TF_VLOG(5) << "Created new trace."; + // Mark the current builder_ node as graph boundary. + // If we did create a new graph, increment the session's graph number. + if (builder_.MarkGraphBoundary(AllowNewGraph())) { + current_session_->graphs_++; + TF_VLOG(5) << "Created new graph."; } - ResetSession(); TF_VLOG(5) << "Ended session: " << current_session_->name_; + ResetSession(); } catch (const std::exception& e) { - // MarkTraceBoundary might raise an exception if AllowNewTrace() is false. + // MarkGraphBoundary might raise an exception if AllowNewGraph() is false. // Catch it here, so that we can correctly end the session. ResetSession(); throw; @@ -229,13 +234,13 @@ void DynamicShapeDetector::AddNodeInfo(torch::lazy::hash_t hash, XLA_CHECK(current_session_ != nullptr); try { - builder_.AddValue({hash, str}, AllowNewTrace()); + builder_.AddValue({hash, str}, AllowNewGraph()); } catch (const std::exception& e) { - // AddValue might raise an exception if AllowNewTrace() is false. Catch it + // AddValue might raise an exception if AllowNewGraph() is false. Catch it // here, so that we can correctly return the builder to the root of the // trie. // - // TODO(ysiraichi): we should actually rollback this trace. + // TODO(ysiraichi): we should actually rollback this graph. RootBuilder(); throw; } @@ -250,12 +255,12 @@ void DynamicShapeDetector::RemoveSessionIfExists(const std::string& name) { TrieBuilder SessionInfo::NewBuilder() { return {root_.get(), 0}; } -void TrieBuilder::AddValue(TrieValue value, bool allow_new_trace) { - *this = node_->AddValue(value, matched_, allow_new_trace); +void TrieBuilder::AddValue(TrieValue value, bool allow_new_graph) { + *this = node_->AddValue(value, matched_, allow_new_graph); } -bool TrieBuilder::MarkTraceBoundary(bool allow_new_trace) { - return node_->MarkTraceBoundary(matched_, allow_new_trace); +bool TrieBuilder::MarkGraphBoundary(bool allow_new_graph) { + return node_->MarkGraphBoundary(matched_, allow_new_graph); } } // namespace torch_xla diff --git a/torch_xla/csrc/dynamic_shape_detector.h b/torch_xla/csrc/dynamic_shape_detector.h index 623814fb33c5..5afeed85a1cc 100644 --- a/torch_xla/csrc/dynamic_shape_detector.h +++ b/torch_xla/csrc/dynamic_shape_detector.h @@ -39,8 +39,8 @@ struct TrieValue { // 2. matched_ will be 0 only in the beginning for the root node. struct TrieBuilder { // Wrappers to the currently pointed to TrieNode methods. - void AddValue(TrieValue value, bool allow_new_trace); - bool MarkTraceBoundary(bool allow_new_trace); + void AddValue(TrieValue value, bool allow_new_graph); + bool MarkGraphBoundary(bool allow_new_graph); // Current TrieNode. TrieNode* node_; @@ -56,8 +56,8 @@ struct TrieBuilder { // // The main interface to interact with TrieNode is TrieBuilder. We start from // the root, incrementally accepting new TrieValue by calling AddValue. Said -// function will incrementally build the trie. Finally, MarkTraceBoundary will -// set is_trace_boundary_ and maybe split the current node (if we haven't +// function will incrementally build the trie. Finally, MarkGraphBoundary will +// set is_graph_boundary_ and maybe split the current node (if we haven't // matched everything in this node's common_sequence_). // // Main assumption for every TrieNode @@ -92,21 +92,25 @@ struct TrieBuilder { // (ii) a node containing the given TrieValue. The returned TrieBuilder will be // {node (ii), 1}. // -// 5. Consider the TrieBuilder {root, 20}. If MarkTraceBoundary is called, and -// root is a leaf (i.e. no children), then root.is_trace_boundary_ is set to +// 5. Consider the TrieBuilder {root, 20}. If MarkGraphBoundary is called, and +// root is a leaf (i.e. no children), then root.is_graph_boundary_ is set to // true. struct TrieNode { using ChildrenMap = std::map>; + // Create a TrieNode with one value as its common_sequence_. + TrieNode(const TrieValue& value, bool is_graph_boundary = false); + + // Create a TrieNode with a specific common_sequence_. TrieNode(absl::Span common_sequence = {}, - bool is_trace_boundary = false); + bool is_graph_boundary = false); // May add TrieValue to this TrieNode. // - // This function is used to iteratively construct the trace. It does 2 things. + // This function is used to iteratively construct the graph. It does 2 things. // // First, it checks whether the given value actually matches the values - // already inside this node, i.e. this trace was seen before. For example, the + // already inside this node, i.e. this graph was seen before. For example, the // given value may match the value inside common_sequence_ (after `matched` // elements) or one of children (if `matched` equals the size of // common_sequence_). @@ -118,22 +122,22 @@ struct TrieNode { // 3. splitting this node, creating 2 new nodes containing: (i) rest of the // unmatched common_sequence_; and (ii) the given value. TrieBuilder AddValue(TrieValue value, std::size_t matched, - bool allow_new_trace); + bool allow_new_graph); - // Marks this node as trace boundary. + // Marks this node as graph boundary. // // Given the number of `matched` elements in the common_sequence_, this - // function sets `is_trace_boundary_` and possibly moves the rest of the + // function sets `is_graph_boundary_` and possibly moves the rest of the // unmatched common_sequence_ to a new node. // - // Returns whether a new trace was created. - bool MarkTraceBoundary(std::size_t matched, bool allow_new_trace); + // Returns whether a new graph was created. + bool MarkGraphBoundary(std::size_t matched, bool allow_new_graph); - // Issue an error indicating a new trace is not allowed. + // Issue an error indicating a new graph is not allowed. // // This function will correctly inspect the TrieNode, building an informative // error message. - void NewTraceNotAllowedError(std::optional value, + void NewGraphNotAllowedError(std::optional value, std::size_t matched); // Maybe split this node into 2, containing, respectively: (i) @@ -156,9 +160,9 @@ struct TrieNode { // Sequence of values all children_ in this node share. std::vector common_sequence_; - // Flag indicating whether the current node is a trace boundary. i.e. - // whether there is a trace that ends with common_sequence_. - bool is_trace_boundary_; + // Flag indicating whether the current node is a graph boundary. i.e. + // whether there is a graph that ends with common_sequence_. + bool is_graph_boundary_; // Children, i.e. forking points, of this node. ChildrenMap children_; @@ -171,17 +175,17 @@ struct SessionInfo { // Name of this session. std::string name_; - // Root of the trie that stores trace information for this session. + // Root of the trie that stores graph information for this session. std::unique_ptr root_; - // Number of recorded traces for this session. - std::size_t traces_; + // Number of recorded graphs for this session. + std::size_t graphs_; }; // Surface class for detecting dynamic shapes. // // Manages the information related to each session as well as the active -// session, i.e. the one that we are recording traces for. +// session, i.e. the one that we are recording graphs for. class DynamicShapeDetector { public: static DynamicShapeDetector* Get(); @@ -192,17 +196,17 @@ class DynamicShapeDetector { // Stops recording the created IR nodes for the active session. // - // Before doing that, we commit the current trace, turning the current - // TrieNode being visited into a trace boundary. + // Before doing that, we commit the current graph, turning the current + // TrieNode being visited into a graph boundary. // // This function may raise an exception if we aren't allowed to create - // more traces. + // more graphs. void EndSession(); // Records a newly created IR node (its metadata). // // This function may raise an exception if: - // 1. we aren't allowed to create more traces; and + // 1. we aren't allowed to create more graphs; and // 2. we have to create a new TrieNode because this IR node wasn't expected // in the trie. void AddNodeInfo(torch::lazy::hash_t hash, const std::string& str); @@ -213,13 +217,13 @@ class DynamicShapeDetector { // Maybe removes the session entry. void RemoveSessionIfExists(const std::string& name); - // API for setting the maximum number of traces allowed to be recorded. - static void SetMaxAllowedTraces(std::size_t value); - static std::size_t GetMaxAllowedTraces(); + // API for setting the maximum number of graphs allowed to be recorded. + static void SetMaxDifferentGraphs(std::size_t value); + static std::size_t GetMaxDifferentGraphs(); private: - // Whether the current session allows new traces, i.e. new graph compilations. - bool AllowNewTrace(); + // Whether the current session allows new graphs, i.e. new graph compilations. + bool AllowNewGraph(); // Move the TrieBuilder to the root node of this session. void RootBuilder(); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d1b38efe00ea..5d9630464a63 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2614,14 +2614,12 @@ void InitXlaModuleBindings(py::module m) { [](const std::string& session) { DynamicShapeDetector::Get()->RemoveSessionIfExists(session); }); - m.def("_dynamic_shape_detector_set_max_allowed_traces", - [](int64_t max_allowed_traces) { - DynamicShapeDetector::SetMaxAllowedTraces(max_allowed_traces); - }); - m.def("_dynamic_shape_detector_get_max_allowed_traces", - [](int64_t max_allowed_traces) { - return DynamicShapeDetector::GetMaxAllowedTraces(); + m.def("_dynamic_shape_detector_set_max_different_graphs", + [](int64_t max_different_graphs) { + DynamicShapeDetector::SetMaxDifferentGraphs(max_different_graphs); }); + m.def("_dynamic_shape_detector_get_max_different_graphs", + []() { return DynamicShapeDetector::GetMaxDifferentGraphs(); }); m.def("_replace_xla_tensor", [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { return XLANativeFunctions::set_(self, source); diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 7009e6aa5041..70b91e648dc7 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -94,7 +94,7 @@ def compile( f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, - num_different_graphs_allowed: Optional[int] = None, + max_different_graphs: Optional[int] = None, ): """ Optimizes given model/function using torch_xla's LazyTensor tracing mode. @@ -112,7 +112,7 @@ def compile( name (Optional[name]): Name of the compiled program. The name of the function `f` will be used if not specified. This name will be used in the `PT_XLA_DEBUG` messages as well as HLO/IR dump file. - num_different_graphs_allowed (Optional[int]): number of different traced graphs of the given + max_different_graphs (Optional[int]): number of different traced graphs of the given model/function that we are allowed to have. An error will be raised in case this limit is exceeded. @@ -177,16 +177,16 @@ def _compile(): # if full_graph sets to true execution can not happen before the sync below torch_xla._XLAC._set_allow_execution(not full_graph) - if num_different_graphs_allowed is not None: - torch_xla._XLAC._dynamic_shape_detector_set_max_num_different_graphs_allowed( - num_different_graphs_allowed) + if max_different_graphs is not None: + torch_xla._XLAC._dynamic_shape_detector_set_max_different_graphs( + max_different_graphs) torch_xla._XLAC._dynamic_shape_detector_start_session(current_id) try: yield finally: torch_xla._XLAC._set_allow_execution(saved_allow_execution) - if num_different_graphs_allowed is not None: + if max_different_graphs is not None: torch_xla._XLAC._dynamic_shape_detector_end_session() # Collect the traced graph after running the target function and # execute the graph.