def get_parameter_or_buffer_or_submodule_name_and_root(provenance):
    """Split a module-member lookup provenance into (kind, name path, root).

    ``provenance`` must be a lookup whose ``inputs[0]`` is a ``LOAD_ATTR`` on
    an object flagged with ``EXT_FLAG_IS_MODULE``; both facts are asserted.

    Starting from that module-flagged object, the provenance chain is followed
    upwards for as long as each step is a constant-key subscript of an
    attribute load on another module-flagged object. Every such attribute is
    asserted to be the constant ``"_modules"``.

    Returns:
        A 3-tuple ``(typ, name, root)`` where
        * ``typ`` is the ``.value`` of the attribute-name input of the first
          ``LOAD_ATTR`` (presumably "_parameters", "_buffers" or "_modules";
          this function does not check it),
        * ``name`` is a list of subscript keys, outermost first, ending with
          the key taken from ``provenance.inputs[1]``,
        * ``root`` is the provenance at which the walk stopped.
    """

    def is_constant_key_submodule_lookup(node):
        # Evaluated left to right with short-circuiting, so the deeper inputs
        # are only inspected once the shallower checks have passed.
        return (
            node.inst is PseudoInst.BINARY_SUBSCR
            and node.inputs[1].inst is PseudoInst.CONSTANT
            and node.inputs[0].inst is PseudoInst.LOAD_ATTR
            and node.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE
        )

    attr_load = provenance.inputs[0]
    assert attr_load.inst is PseudoInst.LOAD_ATTR
    assert attr_load.inputs[0].ext_flag & EXT_FLAG_IS_MODULE
    typ = attr_load.inputs[1].value

    # Components are collected innermost-first and flipped once at the end.
    reversed_path = [provenance.inputs[1].value]
    node = attr_load.inputs[0]
    while is_constant_key_submodule_lookup(node):
        container_load = node.inputs[0]
        attr_name = container_load.inputs[1]
        assert attr_name.inst is PseudoInst.CONSTANT and attr_name.value == "_modules"

        reversed_path.append(node.inputs[1].value)
        node = container_load.inputs[0]

    reversed_path.reverse()
    return typ, reversed_path, node
new_output: + output = Proxy(name=name) + else: + output = p param_ordering[id(output)] = (output, [3]) provenance.proxy = output bsym = prims.unpack_function_obj.bind(output, output=output) @@ -1202,28 +1230,8 @@ def from_constant(provenance, *, new_output=False): raise NotImplementedError(f"constant of type {type(provenance.value)} {provenance.value}") def unpack_parameter_or_buffer_or_submodule(provenance, *, new_output=False): - assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR - assert provenance.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE - typ = provenance.inputs[0].inputs[1].value - name = [provenance.inputs[1].value] - mprovenance = provenance.inputs[0].inputs[0] - - while ( - mprovenance.inst is PseudoInst.BINARY_SUBSCR - and mprovenance.inputs[1].inst is PseudoInst.CONSTANT - and mprovenance.inputs[0].inst is PseudoInst.LOAD_ATTR - and mprovenance.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE - ): - assert ( - mprovenance.inputs[0].inputs[1].inst is PseudoInst.CONSTANT - and mprovenance.inputs[0].inputs[1].value == "_modules" - ) - - name_component = mprovenance.inputs[1].value - name.insert(0, name_component) - mprovenance = mprovenance.inputs[0].inputs[0] - - root_module = from_provenance(mprovenance, new_output=True) + typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(provenance) + root_module = from_provenance(root_module_provenance, new_output=True) if new_output: output = Proxy("m") # name? collectify? else: @@ -1415,11 +1423,9 @@ def from_provenance(provenance, *, new_output=False): def process_recorded_modifications(ctx, epilogue_trace): + root_for_provenances = {} for modified_object, modifications in ctx._additional_outputs.items(): umodified_object = modified_object.value - ## we want this to created in the compute trace context for namespace... 
- modified_object_proxy = Proxy(history=modified_object.provenance) - epilogue_trace.add_name(modified_object_proxy.name) if isinstance(umodified_object, dict): last_modification = {} @@ -1435,8 +1441,24 @@ def process_recorded_modifications(ctx, epilogue_trace): (value,) = args assert isinstance(value.value, Proxy) + assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR + assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT + assert modified_object.provenance.inputs[1].value == "_buffers" + + typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root( + modified_object.provenance.inputs[0] + ) + assert typ == "_modules" + root_module_proxy = root_for_provenances.get(root_module_provenance) + if root_module_proxy is None: + ## we want this to created in the compute trace context for namespace... + root_module_proxy = Proxy(history=root_module_provenance) + epilogue_trace.add_name(root_module_proxy.name) + root_for_provenances[root_module_provenance] = root_module_proxy + + name = ".".join(name + [k]) with tracectx(epilogue_trace): - bsym = prims.pack_setitem.bind(modified_object_proxy, k, value.value, output=None) + bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None) epilogue_trace.bound_symbols.append(bsym) else: raise NotImplementedError(f"Modifications {inst} on dicts are not supported") diff --git a/thunder/core/module.py b/thunder/core/module.py index c0ffa021df..d3430817c8 100644 --- a/thunder/core/module.py +++ b/thunder/core/module.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +import itertools from typing import Any import torch as pytorch @@ -26,7 +27,13 @@ def __init__(self, model, compiled_model_call): self._forward_fn = compiled_model_call # overrides for parameters and buffers (see get_buffer/get_parameter) - self._overrides = {} + # we populate these here for performance reasons (sam as module cache), + # a single dict lookup is cheaper than traversin the 
def set_buffer(self, name, value):
    """Record ``value`` as the override stored under ``name``.

    The value is written to ``self._overrides[name]``, replacing any entry
    already stored under that key. ``name`` is used as the dictionary key
    as-is. Returns ``None``.
    """
    # The previous form (`p = self._overrides[name] = value`) also bound an
    # unused local; only the dictionary write has any effect.
    self._overrides[name] = value
def pack_buffer_impl(o: Any, key: Any, v: Any) -> None:
    """Write ``v`` as the buffer ``key`` on ``o``.

    Performs ``o.set_buffer(key, v)``, the same call that the pack_buffer
    printer renders as ``{obj}.set_buffer({key}, {value})``.

    Args:
        o: object exposing a ``set_buffer(key, value)`` method.
        key: buffer name, passed through unchanged.
        v: new buffer value, passed through unchanged.

    Returns:
        None.
    """
    # The body used to contain a bare placeholder name, which raised NameError
    # whenever this implementation was actually invoked.
    o.set_buffer(key, v)
    return None