Skip to content

Commit

Permalink
Refactor error msg stack handling, add TORCH_RETHROW (pytorch#37101)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#37101

Fixes pytorch#36954.

The basic concept is to streamline the process of rethrowing
c10::Error with extra error information.  This is done in two
steps:

- I completely remodeled the Error data type and the internal
  invariants.  Instead of manually adding in newlines, the
  message stack formatting process is responsible for inserting
  newlines and spacing as necessary.  Call sites are then
  modified to respect the new API model.
- TORCH_RETHROW macro is added, which adds context to an error
  message and then rethrows it.

New internal assert failure looks like:

```
0 INTERNAL ASSERT FAILED at ../c10/test/util/exception_test.cpp:64, please report a bug to PyTorch.
Exception raised from TestBody at ../c10/test/util/exception_test.cpp:64 (most recent call first):
frame #0: <unknown function> + 0x6aab9 (0x7ff611d3aab9 in /data/users/ezyang/pytorch-tmp/build/lib/libc10.so)
frame #1: ...
```

Error message with context looks like:

```
This is an error
  This is context 1
  This is context 2
```

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Differential Revision: D21202891

Pulled By: ezyang

fbshipit-source-id: 361cadd16bc52e5886dba08e79277771ada76169
  • Loading branch information
ezyang authored and facebook-github-bot committed May 4, 2020
1 parent efd8f70 commit a058e93
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 57 deletions.
43 changes: 43 additions & 0 deletions c10/test/util/exception_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@
#include <gtest/gtest.h>
#include <stdexcept>

using c10::Error;

namespace {
// Test helper that never returns normally: it always throws a
// std::runtime_error carrying a fixed message.
bool throw_func() {
  const std::runtime_error failure("I'm throwing...");
  throw failure;
}

// Invokes `functor` and expects it to throw a c10::Error whose
// what_without_backtrace() text equals `expectedMessage`.  If the functor
// returns without throwing an Error, a non-fatal gtest failure is recorded.
// Exceptions of any other type are not caught here.
template<class Functor>
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
  bool sawError = false;
  try {
    std::forward<Functor>(functor)();
  } catch (const Error& e) {
    sawError = true;
    EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
  }
  if (!sawError) {
    ADD_FAILURE() << "Expected to throw exception with message \""
                  << expectedMessage << "\" but didn't throw";
  }
}
} // namespace

TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
Expand All @@ -22,3 +36,32 @@ TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
TEST(WarningTest, JustPrintWarning) {
TORCH_WARN("I'm a warning");
}

// Checks the text returned by what_without_backtrace() for an Error that
// carries zero, one, and two pieces of context added via TORCH_RETHROW.
TEST(ExceptionTest, ErrorFormatting) {
// No context: the message is expected to come back unchanged.
expectThrowsEq([]() {
TORCH_CHECK(false, "This is invalid");
}, "This is invalid");

// Exactly one context: expected to be folded onto the message line in
// parentheses.
expectThrowsEq([]() {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
}, "This is invalid (While checking X)");

// Two contexts (inner rethrow first, outer rethrow second): each context is
// expected on its own line after the message, in the order it was added.
// NOTE(review): the expected literal below has no leading whitespace on the
// context lines, whereas Error::compute_what writes whitespace after each
// newline -- confirm the literal's indentation is intact.
expectThrowsEq([]() {
try {
try {
TORCH_CHECK(false, "This is invalid");
} catch (Error& e) {
TORCH_RETHROW(e, "While checking X");
}
} catch (Error& e) {
TORCH_RETHROW(e, "While checking Y");
}
},
R"msg(This is invalid
While checking X
While checking Y)msg");
}
55 changes: 34 additions & 21 deletions c10/util/Exception.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@
#include <c10/util/Logging.h>

#include <iostream>
#include <sstream>
#include <numeric>
#include <string>

namespace c10 {

Error::Error(
const std::string& new_msg,
const std::string& backtrace,
const void* caller)
: msg_stack_{new_msg}, backtrace_(backtrace), caller_(caller) {
msg_ = msg();
msg_without_backtrace_ = msg_without_backtrace();
Error::Error(std::string msg, std::string backtrace, const void* caller)
: msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
refresh_what();
}

// PyTorch-style error message
Expand All @@ -38,29 +35,45 @@ Error::Error(
"] ",
condition,
". ",
msg,
"\n"),
msg),
backtrace,
caller) {}

std::string Error::msg() const {
return std::accumulate(
msg_stack_.begin(), msg_stack_.end(), std::string("")) +
backtrace_;
// Renders the error as a single string.
//
// Layout: the message comes first.  A single context entry is folded onto
// the message line in parentheses; any other number of entries is written
// one per line after the message.  When `include_backtrace` is true and a
// backtrace was recorded, the backtrace follows on a new line.
//
// @param include_backtrace whether to append backtrace_ to the result.
// @return the formatted text; never ends with a newline added by this
//         function when there is no backtrace to print.
std::string Error::compute_what(bool include_backtrace) const {
  std::ostringstream oss;

  oss << msg_;

  if (context_.size() == 1) {
    // Fold error and context in one line
    oss << " (" << context_[0] << ")";
  } else {
    for (const auto& c : context_) {
      oss << "\n " << c;
    }
  }

  // backtrace_ is allowed to be empty (no valid backtrace); in that case
  // skip the separator too so the result has no dangling trailing newline.
  if (include_backtrace && !backtrace_.empty()) {
    oss << "\n" << backtrace_;
  }

  return oss.str();
}

std::string Error::msg_without_backtrace() const {
return std::accumulate(msg_stack_.begin(), msg_stack_.end(), std::string(""));
// Rebuilds both cached strings (with and without the backtrace) from the
// current state of the object by calling compute_what() once for each.
void Error::refresh_what() {
  what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
  what_ = compute_what(/*include_backtrace*/ true);
}

void Error::AppendMessage(const std::string& new_msg) {
msg_stack_.push_back(new_msg);
// Refresh the cache
// TODO: Calling AppendMessage O(n) times has O(n^2) cost. We can fix
void Error::add_context(std::string new_msg) {
context_.push_back(std::move(new_msg));
// TODO: Calling add_context O(n) times has O(n^2) cost. We can fix
// this perf problem by populating the fields lazily... if this ever
// actually is a problem.
msg_ = msg();
msg_without_backtrace_ = msg_without_backtrace();
// NB: If you do fix this, make sure you do it in a thread safe way!
// what() is almost certainly expected to be thread safe even when
// accessed across multiple threads
refresh_what();
}

namespace Warning {
Expand Down
77 changes: 60 additions & 17 deletions c10/util/Exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,26 @@ namespace c10 {
/// NB: c10::Error is handled specially by the default torch to suppress the
/// backtrace, see torch/csrc/Exceptions.h
class C10_API Error : public std::exception {
std::vector<std::string> msg_stack_;
// The actual error message.
std::string msg_;

// Context for the message (in order of decreasing specificity). Context will
// be automatically formatted appropriately, so it is not necessary to add
// extra leading/trailing newlines to strings inside this vector
std::vector<std::string> context_;

// The C++ backtrace at the point when this exception was raised. This
// may be empty if there is no valid backtrace. (We don't use optional
// here to reduce the dependencies this file has.)
std::string backtrace_;

// These two are derived fields from msg_, context_ and backtrace_, but we need
// fields for the strings so that we can return a const char* (as the
// signature of std::exception requires).
std::string msg_;
std::string msg_without_backtrace_;
// signature of std::exception requires). Currently, the invariant
// is that these fields are ALWAYS populated consistently with respect
// to msg_, context_ and backtrace_.
std::string what_;
std::string what_without_backtrace_;

// This is a little debugging trick: you can stash a relevant pointer
// in caller, and then when you catch the exception, you can compare
Expand All @@ -43,11 +55,11 @@ class C10_API Error : public std::exception {
const void* caller_;

public:
Error(
const std::string& msg,
const std::string& backtrace,
const void* caller = nullptr);
Error(SourceLocation source_location, const std::string& msg);
// PyTorch-style Error constructor. NB: the implementation of this
// is actually in Logging.cpp
Error(SourceLocation source_location, std::string msg);

// Caffe2-style error message
Error(
const char* file,
const uint32_t line,
Expand All @@ -56,30 +68,51 @@ class C10_API Error : public std::exception {
const std::string& backtrace,
const void* caller = nullptr);

void AppendMessage(const std::string& msg);
// Base constructor
Error(
std::string msg,
std::string backtrace,
const void* caller = nullptr);

const std::vector<std::string>& msg_stack() const {
return msg_stack_;
// Add some new context to the message stack. The last added context
// will be formatted at the end of the context list upon printing.
// WARNING: This method is O(n) in the size of the stack, so don't go
// wild adding a ridiculous amount of context to error messages.
void add_context(std::string msg);

const std::string& msg() const {
return msg_;
}

const std::vector<std::string>& context() const {
return context_;
}

const std::string& backtrace() const {
return backtrace_;
}

/// Returns the complete error message, including the source location.
/// The returned pointer is invalidated if you call add_context() on
/// this object.
const char* what() const noexcept override {
return msg_.c_str();
return what_.c_str();
}

const void* caller() const noexcept {
return caller_;
}

/// Returns only the error message string, without source location.
/// The returned pointer is invalidated if you call add_context() on
/// this object.
const char* what_without_backtrace() const noexcept {
return msg_without_backtrace_.c_str();
return what_without_backtrace_.c_str();
}

private:
// Compute the full message from msg_ and msg_without_backtrace_
std::string msg() const;
std::string msg_without_backtrace() const;
void refresh_what();
std::string compute_what(bool include_backtrace) const;
};

class C10_API WarningHandler {
Expand Down Expand Up @@ -204,6 +237,16 @@ inline std::string if_empty_then(std::string x, std::string y) {
// Error reporting macros
// ----------------------------------------------------------------------------

// TORCH_RETHROW(e, ...): add context to a caught error and rethrow it.
//
// `e` is the caught exception object; the remaining arguments are passed to
// ::c10::str() to build the context string handed to e.add_context().  The
// macro then executes a bare `throw;`, so it must only be used inside a
// catch handler.
//
// When STRIP_ERROR_MESSAGES is defined the macro expands to a bare `throw`
// and none of its arguments are evaluated.
//
// `e` is parenthesized in the expansion so that expression arguments such as
// `*err_ptr` bind to the whole argument rather than to the result of
// add_context().
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_RETHROW(e, ...) throw
#else
#define TORCH_RETHROW(e, ...)                 \
  do {                                        \
    (e).add_context(::c10::str(__VA_ARGS__)); \
    throw;                                    \
  } while (false)
#endif

// A utility macro to provide assert()-like functionality; that is, enforcement
// of internal invariants in code. It supports an arbitrary number of extra
// arguments (evaluated only on failure), which will be printed in the assert
Expand Down
6 changes: 3 additions & 3 deletions c10/util/Logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void ThrowEnforceNotMet(
const void* caller) {
c10::Error e(file, line, condition, msg, (*GetFetchStackTrace())(), caller);
if (FLAGS_caffe2_use_fatal_for_enforce) {
LOG(FATAL) << e.msg_stack()[0];
LOG(FATAL) << e.msg();
}
throw e;
}
Expand All @@ -63,8 +63,8 @@ void ThrowEnforceFiniteNotMet(

// PyTorch-style error message
// (This must be defined here for access to GetFetchStackTrace)
Error::Error(SourceLocation source_location, const std::string& msg)
: Error(msg, str(" (", source_location, ")\n", (*GetFetchStackTrace())())) {
// PyTorch-style constructor.  Delegates to the base constructor, passing
// the message through unchanged and building the backtrace string from a
// header naming `source_location` followed by the output of the function
// returned by GetFetchStackTrace().
Error::Error(SourceLocation source_location, std::string msg)
    : Error(
          std::move(msg),
          str("Exception raised from ",
              source_location,
              " (most recent call first):\n",
              (*GetFetchStackTrace())())) {}

using APIUsageLoggerType = std::function<void(const std::string&)>;
Expand Down
29 changes: 15 additions & 14 deletions caffe2/core/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <string>
#include <typeinfo>
#include <vector>
#include <sstream>

#include <c10/macros/Macros.h>
#include <c10/util/Registry.h>
Expand Down Expand Up @@ -154,9 +155,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return inputs_.at(idx)->template Get<T>();
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
}
throw enf;
}
Expand All @@ -180,9 +179,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return tensor;
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
}
throw enf;
}
Expand Down Expand Up @@ -521,26 +518,30 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return;
}

bool found_input;
bool found_input = false;
bool found_output = false;
if (err->caller() != nullptr) {
std::ostringstream oss;
for (size_t i = 0; i < inputs_.size(); i++) {
if (inputs_[i]->GetRaw() == err->caller()) {
found_input = true;
err->AppendMessage(
"\n** while accessing input: " + debug_def().input(i));
oss << "while accessing input: " << debug_def().input(i);
break;
}
}
for (size_t i = 0; i < outputs_.size(); i++) {
if (outputs_[i]->GetRaw() == err->caller()) {
found_output = true;
if (found_input) {
err->AppendMessage("\n OR ");
oss << " OR ";
}
err->AppendMessage(
"\n** while accessing output: " + debug_def().output(i));
oss << "while accessing output: " << debug_def().output(i);
break;
}
}
if (found_input || found_output) {
err->add_context(oss.str());
}
}
}

Expand Down Expand Up @@ -1071,7 +1072,7 @@ class Operator : public OperatorBase {
return result;
} catch (EnforceNotMet& err) {
if (has_debug_def()) {
err.AppendMessage(
err.add_context(
"Error from operator: \n" + ProtoDebugString(debug_def()));
AddRelatedBlobInfo(&err);
}
Expand Down Expand Up @@ -1109,7 +1110,7 @@ class Operator : public OperatorBase {
return result;
} catch (EnforceNotMet& err) {
if (has_debug_def()) {
err.AppendMessage(
err.add_context(
"Error from operator: \n" + ProtoDebugString(debug_def()));
AddRelatedBlobInfo(&err);
}
Expand Down
3 changes: 1 addition & 2 deletions caffe2/ideep/utils/ideep_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class IDEEPOperator : public OperatorBase {
StopAllObservers();
return result;
} catch (EnforceNotMet& err) {
err.AppendMessage(getErrorMsg());
throw;
TORCH_RETHROW(err, getErrorMsg());
} catch (ideep::error& e) {
LOG(ERROR) << "IDEEP error:" << e.message;
throw;
Expand Down

0 comments on commit a058e93

Please sign in to comment.