From 7c2f0afb1c0ff4dbfb8daeed8cef65074651c92a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 22 Dec 2024 16:44:13 -1000 Subject: [PATCH] update `get_parameter_dtype` (#10342) add: q --- src/diffusers/models/modeling_utils.py | 48 ++++++++++++++++++-------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 872d4d73d41f..d236ebb83983 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -99,21 +99,39 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: - try: - return next(parameter.parameters()).dtype - except StopIteration: - try: - return next(parameter.buffers()).dtype - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for param in parameter.parameters(): + last_dtype = param.dtype + if param.is_floating_point(): + return param.dtype + + for buffer in parameter.buffers(): + last_dtype = buffer.dtype + if buffer.is_floating_point(): + return buffer.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype class ModelMixin(torch.nn.Module, PushToHubMixin):