diff --git a/thunder/__init__.py b/thunder/__init__.py index 54c94855d..c5eb46819 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -71,7 +71,7 @@ AnyProxy, ) from thunder.core.interpreter import print_interpreter_log, print_to_log -from thunder.core.jit_ext import thunder_general_jit +from thunder.core.jit_ext import thunder_general_jit, InnerException from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this @@ -814,7 +814,49 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi): return result + def unwrap_inner_exception(c: Callable) -> Callable: + def _thunder_unwrap_inner_exception(*args, **kwargs): + # Run the function, and caputre the exception if there is one. + try: + return c(*args, **kwargs) + except InnerException as e: + exc = e.value + + def internal_to_thunder(co): + if co is thunder_general_jit.__code__ or co is _thunder_unwrap_inner_exception.__code__: + return True + return co.co_filename.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and ( + co.co_name in ("fn_", "fn_2") + ) + + # Iterate over the traceback and collect frames that don't correspond to thunder internal functions. + tb = exc.__traceback__ + tb_frames = [] + while tb != None: + co = tb.tb_frame.f_code + if not internal_to_thunder(co): + tb_frames.append(tb) + tb = tb.tb_next + + # Relink the non-internal traceback frames + if tb_frames: + top_tb = tb = tb_frames[0] + for _tb in tb_frames[1:]: + tb.tb_next = _tb + tb = _tb + exc.__traceback__ = top_tb + + # Re-raise the exception without retaining it in this stack frame to avoid leaking tensors. + try: + raise exc + except Exception: + del exc + raise # re-raises current exception + + return _thunder_unwrap_inner_exception + @wraps(fn) + @unwrap_inner_exception @update_call_statistics def fn_(*args, **kwargs) -> Any: if is_tracing(): diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index cf186b1ea..e822cb5b6 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -132,6 +132,11 @@ } +class InnerException(BaseException): + def __init__(self, *, value: BaseException): + self.value = value + + class JITSharpEdgeError(RuntimeError): """ Thrown when the program cannot be safely translated to a thunder program, @@ -1741,7 +1746,11 @@ def thunder_general_jit( with jit_ctx(ctx): with tracectx(computation_trace): - result = jfn(*args, **kwargs) + try: + result = jfn(*args, **kwargs) + except BaseException as e: + raise InnerException(value=e) + prims.python_return(result) computation_trace.set_current_source_location(None, None) process_recorded_modifications(ctx, epilogue_trace) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 580ac914d..08f53f0b7 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -850,6 +850,36 @@ def main(): assert weak_x() is None +def test_backtrace_filter(): + import thunder + + def fn1(): + fn2() + + def fn2(): + fn3() + + def fn3(): + raise ValueError + + jfn = thunder.jit(fn1) + + expected_frame_names = ["test_backtrace_filter", "_thunder_unwrap_inner_exception", "fn1", "fn2", "fn3"] + + try: + jfn() + except ValueError as e: + tb_frames = [] + tb = e.__traceback__ + while tb != None: + tb_frames.append(tb) + tb = tb.tb_next + frame_names = [tb.tb_frame.f_code.co_name for tb in tb_frames] + assert frame_names == expected_frame_names + except BaseException as e: + assert False, e # Wrong exception type. + + def test_walrus_operator(jit): def foo(a, b): c = (a := b)