From d2f0a132f0ff488d6af8e51770e356a909ea3d10 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 30 Apr 2024 23:14:50 +0100 Subject: [PATCH] Feat (graph/standardize): default keepdim value --- src/brevitas/graph/standardize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 7bbbb201d..93e99eac9 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -59,7 +59,7 @@ def match_node(self, node: Node) -> bool: is_adaptive_2d_mean = ((2, 3) in node.args or [2, 3] in node.args or 'dim' in node.kwargs and (node.kwargs['dim'] == (2, 3) or node.kwargs['dim'] == [2, 3])) - is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs['keepdim'] + is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs.get('keepdim', False) return spr and is_adaptive_2d_mean def move_node_args_to_kwargs(self, node: Node):