Skip to content

Commit

Permalink
cache params/buffers/submodules in ThunderModule (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored May 17, 2024
1 parent 6deb2cc commit 54bb614
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 29 deletions.
78 changes: 50 additions & 28 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,30 @@ def get_computation_inputs_and_intermediates(computation_trace):
return inputs_list, intermediates_set


def get_parameter_or_buffer_or_submodule_name_and_root(provenance):
    """Split the provenance of a module attribute access into its parts.

    ``provenance`` is expected to describe a lookup of the form
    ``<module>.<typ>[<leaf>]`` where ``<module>`` may itself have been reached
    through a chain of ``<parent>._modules[<child>]`` lookups.

    Returns a ``(typ, name, root_provenance)`` tuple:

    * ``typ`` -- the value of the attribute that was loaded from the owning
      module (callers compare it against strings such as ``"_modules"``),
    * ``name`` -- list of name components, outermost submodule first and the
      looked-up leaf key last,
    * ``root_provenance`` -- provenance of the module at which the
      ``_modules[...]`` chain stops.

    Raises ``AssertionError`` if the outer lookup is not an attribute load from
    a module, or if a matching subscript step loads an attribute other than
    the constant ``"_modules"``.
    """
    attr_load = provenance.inputs[0]
    assert attr_load.inst is PseudoInst.LOAD_ATTR
    owner = attr_load.inputs[0]
    assert owner.ext_flag & EXT_FLAG_IS_MODULE
    typ = attr_load.inputs[1].value

    def is_submodule_step(node):
        # A step looks like `<module>.<attr>[<constant>]`; the checks are kept in
        # this order so that attribute access short-circuits on the first mismatch.
        if node.inst is not PseudoInst.BINARY_SUBSCR:
            return False
        if node.inputs[1].inst is not PseudoInst.CONSTANT:
            return False
        if node.inputs[0].inst is not PseudoInst.LOAD_ATTR:
            return False
        return bool(node.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE)

    # Components are gathered leaf-first while walking towards the root and
    # flipped once at the end.
    leaf_first = [provenance.inputs[1].value]
    mprovenance = owner
    while is_submodule_step(mprovenance):
        container_attr = mprovenance.inputs[0].inputs[1]
        assert container_attr.inst is PseudoInst.CONSTANT and container_attr.value == "_modules"

        leaf_first.append(mprovenance.inputs[1].value)
        mprovenance = mprovenance.inputs[0].inputs[0]

    name = leaf_first[::-1]
    return typ, name, mprovenance


def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, *, has_epilogue: bool):
already_unpacked: dict[int, Proxy] = {}
orig_modules: dict[int, Proxy] = {}
Expand All @@ -1160,19 +1184,23 @@ def unpack(v: Variable | Proxy) -> Proxy:
prologue_trace.add_name(p.name)

def from_input(provenance, *, new_output=False):
assert new_output
if provenance.inst == PseudoInst.INPUT_ARGS:
assert new_output
param_ordering[id(pro_args_proxy)] = (pro_args_proxy, [0])
return pro_args_proxy
elif provenance.inst == PseudoInst.INPUT_KWARGS:
assert new_output
param_ordering[id(pro_kwargs_proxy)] = (pro_kwargs_proxy, [1])
return pro_kwargs_proxy
elif provenance.inst == PseudoInst.INPUT_FN:
if provenance.ext_flag & EXT_FLAG_IS_MODULE:
name = "module"
else:
name = "fn"
output = Proxy(name=name)
if new_output:
output = Proxy(name=name)
else:
output = p
param_ordering[id(output)] = (output, [3])
provenance.proxy = output
bsym = prims.unpack_function_obj.bind(output, output=output)
Expand Down Expand Up @@ -1202,28 +1230,8 @@ def from_constant(provenance, *, new_output=False):
raise NotImplementedError(f"constant of type {type(provenance.value)} {provenance.value}")

def unpack_parameter_or_buffer_or_submodule(provenance, *, new_output=False):
assert provenance.inputs[0].inst is PseudoInst.LOAD_ATTR
assert provenance.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE
typ = provenance.inputs[0].inputs[1].value
name = [provenance.inputs[1].value]
mprovenance = provenance.inputs[0].inputs[0]

while (
mprovenance.inst is PseudoInst.BINARY_SUBSCR
and mprovenance.inputs[1].inst is PseudoInst.CONSTANT
and mprovenance.inputs[0].inst is PseudoInst.LOAD_ATTR
and mprovenance.inputs[0].inputs[0].ext_flag & EXT_FLAG_IS_MODULE
):
assert (
mprovenance.inputs[0].inputs[1].inst is PseudoInst.CONSTANT
and mprovenance.inputs[0].inputs[1].value == "_modules"
)

name_component = mprovenance.inputs[1].value
name.insert(0, name_component)
mprovenance = mprovenance.inputs[0].inputs[0]

root_module = from_provenance(mprovenance, new_output=True)
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(provenance)
root_module = from_provenance(root_module_provenance, new_output=True)
if new_output:
output = Proxy("m") # name? collectify?
else:
Expand Down Expand Up @@ -1415,11 +1423,9 @@ def from_provenance(provenance, *, new_output=False):


def process_recorded_modifications(ctx, epilogue_trace):
root_for_provenances = {}
for modified_object, modifications in ctx._additional_outputs.items():
umodified_object = modified_object.value
## we want this to be created in the compute trace context for namespace...
modified_object_proxy = Proxy(history=modified_object.provenance)
epilogue_trace.add_name(modified_object_proxy.name)

if isinstance(umodified_object, dict):
last_modification = {}
Expand All @@ -1435,8 +1441,24 @@ def process_recorded_modifications(ctx, epilogue_trace):
(value,) = args
assert isinstance(value.value, Proxy)

assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR
assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
assert modified_object.provenance.inputs[1].value == "_buffers"

typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
## we want this to be created in the compute trace context for namespace...
root_module_proxy = Proxy(history=root_module_provenance)
epilogue_trace.add_name(root_module_proxy.name)
root_for_provenances[root_module_provenance] = root_module_proxy

name = ".".join(name + [k])
with tracectx(epilogue_trace):
bsym = prims.pack_setitem.bind(modified_object_proxy, k, value.value, output=None)
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
epilogue_trace.bound_symbols.append(bsym)
else:
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
Expand Down
15 changes: 14 additions & 1 deletion thunder/core/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import contextmanager
import itertools
from typing import Any

import torch as pytorch
Expand Down Expand Up @@ -26,7 +27,13 @@ def __init__(self, model, compiled_model_call):
self._forward_fn = compiled_model_call

# overrides for parameters and buffers (see get_buffer/get_parameter)
self._overrides = {}
# we populate these here for performance reasons (same as module cache),
# a single dict lookup is cheaper than traversing the module
# hierarchy, see https://github.com/Lightning-AI/lightning-thunder/issues/396#issuecomment-2113231498
self._overrides = {
k: v for k, v in itertools.chain(self._model.named_parameters(), self._model.named_buffers())
}
self._module_cache = {k: v for k, v in self._model.named_modules()}

self._null = object()

Expand All @@ -36,13 +43,19 @@ def get_buffer(self, name):
return p
return self._model.get_buffer(name)

def set_buffer(self, name, value):
    """Store ``value`` as the override for the buffer called ``name``.

    The value is written into ``self._overrides``, the same dict that
    ``get_parameter`` consults before falling back to the wrapped model.
    An existing entry under ``name`` is replaced. Returns ``None``.
    """
    self._overrides[name] = value

def get_parameter(self, name):
    """Return the parameter called ``name``.

    An entry in ``self._overrides`` takes precedence; ``self._null`` is the
    sentinel for "no entry", so a stored ``None`` is still returned as-is.
    Only when there is no entry is the lookup delegated to
    ``self._model.get_parameter(name)``.
    """
    override = self._overrides.get(name, self._null)
    if override is self._null:
        return self._model.get_parameter(name)
    return override

def get_submodule(self, name):
    """Return the submodule called ``name``.

    ``self._module_cache`` is checked first, with ``self._null`` as the
    "not cached" sentinel; on a miss the lookup is delegated to
    ``self._model.get_submodule(name)``.
    """
    cached = self._module_cache.get(name, self._null)
    if cached is self._null:
        return self._model.get_submodule(name)
    return cached

def forward(self, *args, **kwargs):
Expand Down
44 changes: 44 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class PrimIDs(Enum):
UNPACK_SUBMODULE = auto()
UNPACK_THUNDER_MODULE = auto()
CONSTRUCT_TUPLE = auto()
PACK_BUFFER = auto()
PACK_SETITEM = auto()
# TODO: UNPACK_SET
# Utility prims
Expand Down Expand Up @@ -1115,6 +1116,49 @@ def unpack_buffer_impl(o: Any, key: str, /) -> Any:
)


# NOTE PACK_BUFFER is intended only to be bound to directly, and not called
def pack_buffer_meta(o: Any, key: Any, value: Any) -> Any:
    """Meta function for the PACK_BUFFER prim; always raises.

    The prim is only meant to be bound to directly, so calling the meta
    function is unsupported and raises ``NotImplementedError`` regardless of
    the arguments.
    """
    raise NotImplementedError()


def pack_buffer_printer(
    bsym: BoundSymbol, out_printables: Any, arg_printables: Sequence[Printable], kwarg_printables: dict[str, Printable]
):
    """Render a bound pack_buffer symbol as a ``set_buffer`` call.

    ``arg_printables`` must hold exactly three entries (object, key, value)
    and ``kwarg_printables`` must be empty; both conditions are verified with
    ``utils.check`` using ``AssertionError`` as the exception type.

    Returns the string ``"<obj>.set_buffer(<key>, <value>)"`` where each piece
    is produced by ``codeutils.prettyprint``. ``bsym`` and ``out_printables``
    are not used.
    """
    has_three_args = len(arg_printables) == 3
    utils.check(
        has_three_args,
        lambda: f"Expected three arguments for pack_buffer but got {arg_printables}",
        exception_type=AssertionError,
    )
    has_no_kwargs = len(kwarg_printables) == 0
    utils.check(
        has_no_kwargs,
        lambda: f"Expected no kwargs for pack_buffer but got {kwarg_printables}",
        exception_type=AssertionError,
    )

    # Pretty-print object, key and value, in that order.
    obj, key, value = arg_printables
    obj_str, key_str, value_str = map(codeutils.prettyprint, (obj, key, value))
    return f"{obj_str}.set_buffer({key_str}, {value_str})"


def pack_buffer_impl(o: Any, key: Any, v: Any) -> None:
    """Python implementation of the PACK_BUFFER prim.

    Stores ``v`` as the buffer named ``key`` on ``o`` by calling
    ``o.set_buffer(key, v)`` -- the same call that ``pack_buffer_printer``
    emits for this prim. ``o`` therefore has to provide a ``set_buffer``
    method. Always returns ``None``.
    """
    # The previous body was the bare placeholder `XXX`, which raised a
    # NameError as soon as the prim was executed.
    o.set_buffer(key, v)
    return None


# Prim that writes a value back into a module buffer. It is registered under
# its own name "pack_buffer"; the earlier "unpack_buffer" was a copy-paste
# slip from the unpack prim defined above.
# NOTE(review): tagged DONT_DCE, presumably so that the output-less symbol is
# not removed by dead-code elimination -- confirm against OpTags.
pack_buffer = make_prim(
    PrimIDs.PACK_BUFFER,
    "pack_buffer",
    meta=pack_buffer_meta,
    python_printer=pack_buffer_printer,
    python_impl=pack_buffer_impl,
    tags=(OpTags.DONT_DCE,),
)


# NOTE PACK_SETITEM is intended only to be bound to directly, and not called
def pack_setitem_meta(o: Any, key: Any, value: Any) -> Any:
raise NotImplementedError
Expand Down

0 comments on commit 54bb614

Please sign in to comment.