
Model quantization issue #374

Open

yoummiegao opened this issue Jan 12, 2023 · 0 comments

yoummiegao commented Jan 12, 2023

Because the model needs to be deployed on non-GPU edge devices, it has to be converted to an int8 TFLite model. I am using the default Keras QAT (quantization-aware training) approach:

    # Imports implied by the snippet; densenet and ctc_lambda_func are project-local.
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot
    from tensorflow.keras import layers as KL
    from tensorflow.keras.models import Model

    input = KL.Input(shape=(args.height, None, 3), name='input')
    labels = KL.Input(name='labels', shape=[None], dtype='int64')
    # labels = KL.Input(name='labels', shape=[None], dtype='float32')
    # labels = KL.Input(name='labels', shape=[1,1], dtype='float32')
    input_length = KL.Input(name='input_length', shape=[1], dtype='int64')
    label_length = KL.Input(name='label_length', shape=[1], dtype='int64')
    y_predict = densenet(input=input, num_classes=nclass)
    basemodel = Model(inputs=input, outputs=y_predict)
    basemodel.summary()
    loss_out = KL.Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_predict, labels, input_length, label_length])
    model = Model(inputs=[input, labels, input_length, label_length], outputs=loss_out)
    model.summary()
    # model.compile(optimizer='adam', loss='categorical_crossentropy') #need to further check
    # this is automatically done by inner built-in calculation
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])
    # model.compile(loss={'ctc':loss_out}, optimizer='adam', metrics=['accuracy'])
    # customcallback = CustomCallback()
    init_epoch = 0
    if args.restore is not None:
        restore_path=''
        if args.restore_path is not None:
            restore_path = args.restore_path
        else:
            restore_path = chkp_dir+'/chkp-{:04}.chkp'.format(args.restore)
        print('Restore from epoch: ', args.restore, restore_path)
        if restore_path.endswith('h5'):
            model.load_weights(restore_path, by_name=True, skip_mismatch=True)
        else:
            model.load_weights(restore_path)
        init_epoch = int(args.restore)
    quantize_model = tfmot.quantization.keras.quantize_model
    q_aware_model = quantize_model(model)
    # q_aware_model = quantize_model(basemodel)
    q_aware_model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam', metrics=['accuracy'])
    if args.restore_qat is not None:
        if args.restore_qat_path is not None:
            restore_qat_path = args.restore_qat_path
        else:
            restore_qat_path = chkp_qat_dir+'/chkp-{:04}.chkp'.format(args.restore_qat)
        print('Restore from epoch: ', restore_qat_path)
        q_aware_model.load_weights(restore_qat_path)
        # model.load_model(restore_path)
        # model = tf.keras.models.load_model(save_path)
        init_epoch_qat = int(args.restore_qat)
        q_aware_model.summary()
        print('-----------Start qat training-----------')
        q_aware_model.fit(
            train_dataset,
            epochs=args.epoch_qat,
            validation_data=val_dataset,
            initial_epoch=init_epoch,
            callbacks=[
                tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1, update_freq=500, write_images=True),
                tf.keras.callbacks.ModelCheckpoint(filepath=chkp_path, save_weights_only=True, verbose=1),
                tf.keras.callbacks.LearningRateScheduler(tf.keras.optimizers.schedules.PiecewiseConstantDecay(lr_boundaries, lr_values), verbose=1)
                # customcallback
                # tf.keras.callbacks.LearningRateScheduler(lambda epoch: float(learning_rate[epoch]))
        ])
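For reference, the int8 TFLite export this is all aiming at would look roughly like the sketch below. The tiny stand-in model and random calibration data are placeholders for the trained `q_aware_model` and real training images:

```python
import numpy as np
import tensorflow as tf

# Placeholder for the trained q_aware_model; any Keras model illustrates the step.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, activation='relu'),
])

def representative_dataset():
    # A few calibration samples; with a QAT-trained model this is often optional,
    # but it is required for post-training full-integer quantization.
    for _ in range(10):
        yield [np.random.rand(1, 32, 32, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()  # bytes of the int8 .tflite flatbuffer
```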

However, running QAT training fails with the error below. Has anyone run into a similar problem? Any suggestions on how to fix it? Thanks a lot!

File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/ops.py", line 1939, in _create_c_op
    raise ValueError(e.message)
ValueError: Exception encountered when calling layer "batch_normalization" (type BatchNormalization).
Shape must be rank 4 but is rank 7 for '{{node batch_normalization/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=1.1e-05, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization/ReadVariableOp, batch_normalization/ReadVariableOp_1, batch_normalization/FusedBatchNormV3/ReadVariableOp, batch_normalization/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,1,1,?,16,?,64], [64], [64], [64], [64].
Call arguments received:
  • inputs=tf.Tensor(shape=(1, 1, 1, None, 16, None, 64), dtype=float32)
  • training=None