Skip to content

Commit

Permalink
Change the way to validate keep_num_dims attribute for new tf. (#2367)
Browse files Browse the repository at this point in the history
* Change the way to validate keep_num_dims attribute for new tf.

Signed-off-by: Jay Zhang <[email protected]>

* Fix a lint issue.

Signed-off-by: Jay Zhang <[email protected]>

---------

Signed-off-by: Jay Zhang <[email protected]>
  • Loading branch information
fatcat-z authored Nov 20, 2024
1 parent f85e88e commit 79cedf9
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions tf2onnx/tflite_handlers/tfl_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs):
separate_fused_activation_function(ctx, node)
utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
"Only default weights format supported for fully connected op")
utils.make_sure(node.attr['keep_num_dims'].i == 0,
"Only keep_num_dims=False supported for fully connected op")
if node.attr['asymmetric_quantize_inputs'].i == 1:
dynamic_quantize_inputs(ctx, node)

if ctx.get_rank(node.input[0]) != 2:
if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2:
# When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed
utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2")
weights_shape = ctx.get_shape(node.input[1])[1]
Expand Down

0 comments on commit 79cedf9

Please sign in to comment.