diff --git a/thunder/core/module.py b/thunder/core/module.py index 2faab4566f..974aa092db 100644 --- a/thunder/core/module.py +++ b/thunder/core/module.py @@ -64,7 +64,7 @@ def forward(self, *args, **kwargs): def _named_parameters_or_buffers(self, overrides, orig_iter, prefix="", recurse=True, remove_duplicate=True): seen_ids = set() seen_names = set() - for k, v in itertools.chain(overrides.items(), orig_iter(remove_duplicate=remove_duplicate)): + for k, v in itertools.chain(overrides.items(), orig_iter): if remove_duplicate: id_v = id(v) if id_v in seen_ids: @@ -83,16 +83,21 @@ def _named_parameters_or_buffers(self, overrides, orig_iter, prefix="", recurse= def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): yield from self._named_parameters_or_buffers( self._overrides_parameters, - self._model.named_parameters, + self._model.named_parameters(remove_duplicate=remove_duplicate), prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate, ) - def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): + def named_buffers(self, prefix="", recurse=True, remove_duplicate=True, *, persistent=None): + if persistent is not None: + orig_buffers = self._model.named_buffers(remove_duplicate=remove_duplicate, persistent=persistent) + else: + orig_buffers = self._model.named_buffers(remove_duplicate=remove_duplicate) + yield from self._named_parameters_or_buffers( self._overrides_buffers, - self._model.named_buffers, + orig_buffers, prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate,