From ad5b1388269446c7d73d8005312d0ea4a65c7af5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 11 Mar 2024 23:10:44 +0000 Subject: [PATCH] pre-commit: running and fixing... --- thunder/executors/cudnn_layernormex.py | 8 ++++++-- thunder/executors/cudnnex.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/thunder/executors/cudnn_layernormex.py b/thunder/executors/cudnn_layernormex.py index c0c93d314f..9f0021dbc9 100644 --- a/thunder/executors/cudnn_layernormex.py +++ b/thunder/executors/cudnn_layernormex.py @@ -35,7 +35,9 @@ def cudnn_available() -> bool: def make_cacheable_cudnn_graph_inputs(func): def wrapper(*args, **kwargs): cudnn_input_args = [ - CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype, args.device_index) if isinstance(arg, torch.Tensor) else arg + CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype, args.device_index) + if isinstance(arg, torch.Tensor) + else arg for arg in args ] return func(*cudnn_input_args, **kwargs) @@ -93,7 +95,9 @@ def _transform_layer_norm_inputs(a, normalized_shape, weight, bias): # Assume strides to be NCHW contiguous assumed_stride = (elements_to_normalize, 1, 1, 1) a_4d = CudnnTensorAttributes((batch_size, elements_to_normalize, 1, 1), assumed_stride, a.dtype, a.device.index) - weight_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype, weight.device.index) + weight_4d = CudnnTensorAttributes( + (1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype, weight.device.index + ) bias_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, bias.dtype, bias.device.index) return a_4d, weight_4d, bias_4d diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 64df20cb15..0361575437 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -72,6 +72,7 @@ class CudnnTensorAttributes: dtype: torch.dtype device_index: int + from collections import OrderedDict @@ -222,7 +223,9 @@ def compute_NHWC_strides(shape): # cudnn does not support boolean attn_mask, so make one with -inf attn_mask_dtype = query.dtype if attn_mask.dtype in [torch.bool, dtypes.bool8] else attn_mask.dtype - attn_mask_4d = CudnnTensorAttributes(attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index) + attn_mask_4d = CudnnTensorAttributes( + attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index + ) return query_4d, key_4d, value_4d, attn_mask_4d