forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_anomaly_mode.cpp
139 lines (125 loc) · 4.09 KB
/
python_anomaly_mode.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>
namespace torch::autograd {
void PyAnomalyMetadata::store_stack() {
pybind11::gil_scoped_acquire gil;
THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
if (!mod) {
throw python_error();
}
THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
if (!list) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
throw python_error();
}
}
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
pybind11::gil_scoped_acquire gil;
if (!PyDict_Check(dict())) {
throw std::runtime_error("Anomaly metadata is not a python dictionary.");
}
PyObject* trace_stack = nullptr;
if (PyDict_GetItemStringRef(dict(), ANOMALY_TRACE_KEY, &trace_stack) < 0) {
throw python_error();
}
_print_stack(trace_stack, current_node_name, false);
PyObject* pyparent = nullptr;
if (PyDict_GetItemStringRef(dict(), ANOMALY_PARENT_KEY, &pyparent) < 0) {
throw python_error();
}
// if there is no "parent_" in metadata, then it means this metadata's node
// is the root and stop printing the traceback
while (pyparent) {
THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
if (!parent_metadata) {
throw python_error();
}
THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
if (!parent_name_pyobj) {
throw python_error();
}
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
if (!parent_name_char) {
throw python_error();
}
const std::string parent_name(parent_name_char);
PyObject* parent_stack = nullptr;
if (PyDict_GetItemStringRef(
parent_metadata.get(), ANOMALY_TRACE_KEY, &parent_stack) < 0) {
throw python_error();
}
_print_stack(parent_stack, parent_name, true);
// get the parent of this node, if this node is a root, pyparent is simply
// null
if (PyDict_GetItemStringRef(
parent_metadata.get(), ANOMALY_PARENT_KEY, &pyparent) < 0) {
throw python_error();
}
}
}
void PyAnomalyMetadata::assign_parent(
const std::shared_ptr<Node>& parent_node) {
// assign the python object of parent_node in metadata["parent_"]
// if parent_node is nullptr, then do nothing (it can mean that "parent_" key
// is not in metadata)
pybind11::gil_scoped_acquire gil;
if (!parent_node)
return;
THPObjectPtr parent_node_(functionToPyObject(parent_node));
if (!parent_node_) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
throw python_error();
}
}
// Emits a TORCH_WARN containing one recorded forward traceback.
//
// @param stack              Sequence of Python strings, each ending in a
//                           newline, joined verbatim into one message. May be
//                           null when no trace was recorded; a hint to enable
//                           anomaly detection during forward is then emitted.
// @param current_node_name  Node name embedded in the warning text.
// @param is_parent          Selects the "induced by" wording used for ancestor
//                           nodes instead of the "error detected" wording.
// @throws python_error if building or joining the Python strings fails.
void _print_stack(
    PyObject* stack,
    const std::string& current_node_name,
    bool is_parent) {
  if (stack == nullptr) {
    TORCH_WARN(
        "Error detected in ",
        current_node_name,
        ". ",
        "No forward pass information available. Enable detect anomaly "
        "during forward pass for more information.");
    return;
  }

  // The entries already carry their trailing newlines, so join with "".
  THPObjectPtr separator(PyUnicode_FromString(""));
  if (!separator) {
    throw python_error();
  }
  THPObjectPtr joined(PyUnicode_Join(separator.get(), stack));
  if (!joined) {
    throw python_error();
  }
  const std::string traceback = THPUtils_unpackString(joined.get());

  if (is_parent) {
    TORCH_WARN(
        "\n\n",
        "Previous calculation was induced by ",
        current_node_name,
        ". "
        "Traceback of forward call that induced the previous calculation:\n",
        traceback);
    return;
  }

  TORCH_WARN(
      "Error detected in ",
      current_node_name,
      ". ",
      "Traceback of forward call that caused the error:\n",
      traceback);
}
} // namespace torch::autograd