Skip to content

Commit

Permalink
Add persistent keyword parameter to ThunderModule.named_buffers (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
shino16 authored Aug 9, 2024
1 parent 088798e commit 9f6e5b1
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions thunder/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward(self, *args, **kwargs):
def _named_parameters_or_buffers(self, overrides, orig_iter, prefix="", recurse=True, remove_duplicate=True):
seen_ids = set()
seen_names = set()
for k, v in itertools.chain(overrides.items(), orig_iter(remove_duplicate=remove_duplicate)):
for k, v in itertools.chain(overrides.items(), orig_iter):
if remove_duplicate:
id_v = id(v)
if id_v in seen_ids:
Expand All @@ -83,16 +83,21 @@ def _named_parameters_or_buffers(self, overrides, orig_iter, prefix="", recurse=
def named_parameters(self, prefix="", recurse=True, remove_duplicate=True):
yield from self._named_parameters_or_buffers(
self._overrides_parameters,
self._model.named_parameters,
self._model.named_parameters(remove_duplicate=remove_duplicate),
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,
)

def named_buffers(self, prefix="", recurse=True, remove_duplicate=True):
def named_buffers(self, prefix="", recurse=True, remove_duplicate=True, *, persistent=None):
if persistent is not None:
orig_buffers = self._model.named_buffers(remove_duplicate=remove_duplicate, persistent=persistent)
else:
orig_buffers = self._model.named_buffers(remove_duplicate=remove_duplicate)

yield from self._named_parameters_or_buffers(
self._overrides_buffers,
self._model.named_buffers,
orig_buffers,
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,
Expand Down

0 comments on commit 9f6e5b1

Please sign in to comment.