From 79cedf926380cdf494dbfc6dd01eaa66122616d2 Mon Sep 17 00:00:00 2001 From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:15:47 +0800 Subject: [PATCH] Change the way to validate keep_num_dims attribute for new tf. (#2367) * Change the way to validate keep_num_dims attribute for new tf. Signed-off-by: Jay Zhang * Fix a lint issue. Signed-off-by: Jay Zhang --------- Signed-off-by: Jay Zhang --- tf2onnx/tflite_handlers/tfl_math.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tf2onnx/tflite_handlers/tfl_math.py b/tf2onnx/tflite_handlers/tfl_math.py index add2f7de3..35b377c56 100644 --- a/tf2onnx/tflite_handlers/tfl_math.py +++ b/tf2onnx/tflite_handlers/tfl_math.py @@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs): separate_fused_activation_function(ctx, node) utils.make_sure(node.attr['weights_format'].s == b'DEFAULT', "Only default weights format supported for fully connected op") - utils.make_sure(node.attr['keep_num_dims'].i == 0, - "Only keep_num_dims=False supported for fully connected op") if node.attr['asymmetric_quantize_inputs'].i == 1: dynamic_quantize_inputs(ctx, node) - if ctx.get_rank(node.input[0]) != 2: + if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2: # When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2") weights_shape = ctx.get_shape(node.input[1])[1]