We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
由于希望将模型部署到端侧非 GPU 设备,需要将模型转换为 int8 的 tflite 模型,因此采用 Keras 默认的 QAT(量化感知训练)实现方式。
# Build a CTC training graph around a DenseNet backbone, optionally restore
# float weights, wrap the model for quantization-aware training (QAT) with
# tensorflow_model_optimization, optionally restore QAT weights, then train.

# Four Keras inputs: the image plus the three extra tensors handed to the CTC
# Lambda below.
input = KL.Input(shape=(args.height, None, 3), name='input')
labels = KL.Input(name='labels', shape=[None], dtype='int64')
input_length = KL.Input(name='input_length', shape=[1], dtype='int64')
label_length = KL.Input(name='label_length', shape=[1], dtype='int64')

# Backbone: image -> per-class predictions.
y_predict = densenet(input=input, num_classes=nclass)
basemodel = Model(inputs=input, outputs=y_predict)
basemodel.summary()

# Training model: the 'ctc' Lambda output is the model output.
loss_out = KL.Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [y_predict, labels, input_length, label_length])
model = Model(inputs=[input, labels, input_length, label_length],
              outputs=loss_out)
model.summary()

# The Keras-level loss simply returns y_pred, i.e. the 'ctc' layer output is
# used as the loss value as-is.
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred},
              optimizer='adam', metrics=['accuracy'])

# Optionally restore float (pre-QAT) weights.
init_epoch = 0
if args.restore is not None:
    if args.restore_path is not None:
        restore_path = args.restore_path
    else:
        restore_path = chkp_dir + '/chkp-{:04}.chkp'.format(args.restore)
    print('Restore from epoch: ', args.restore, restore_path)
    if restore_path.endswith('h5'):
        # .h5 files are loaded by layer name, skipping mismatching layers.
        model.load_weights(restore_path, by_name=True, skip_mismatch=True)
    else:
        model.load_weights(restore_path)
    init_epoch = int(args.restore)

# Wrap the full training model for quantization-aware training.
# NOTE(review): the reported FusedBatchNormV3 rank error may come from
# quantizing the whole 4-input graph (including the 'ctc' Lambda); quantizing
# only `basemodel` and rebuilding the CTC Lambda on top of it is a candidate
# workaround -- unverified, confirm against the tfmot documentation.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)
q_aware_model.compile(loss={'ctc': lambda y_true, y_pred: y_pred},
                      optimizer='adam', metrics=['accuracy'])

# Optionally restore QAT weights. The QAT epoch counter is tracked separately
# from the float one and always has a defined value.
init_epoch_qat = 0
if args.restore_qat is not None:
    if args.restore_qat_path is not None:
        restore_qat_path = args.restore_qat_path
    else:
        restore_qat_path = chkp_qat_dir + '/chkp-{:04}.chkp'.format(args.restore_qat)
    print('Restore from epoch: ', args.restore_qat, restore_qat_path)
    q_aware_model.load_weights(restore_qat_path)
    init_epoch_qat = int(args.restore_qat)

q_aware_model.summary()
print('-----------Start qat training-----------')
q_aware_model.fit(
    train_dataset,
    epochs=args.epoch_qat,
    validation_data=val_dataset,
    # Resume from the QAT checkpoint epoch, not the float-model epoch.
    initial_epoch=init_epoch_qat,
    callbacks=[
        tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1,
                                       update_freq=500, write_images=True),
        tf.keras.callbacks.ModelCheckpoint(filepath=chkp_path,
                                           save_weights_only=True, verbose=1),
        tf.keras.callbacks.LearningRateScheduler(
            tf.keras.optimizers.schedules.PiecewiseConstantDecay(
                lr_boundaries, lr_values),
            verbose=1),
    ])
但运行QAT训练就会报以下错误,大家有遇到类似问题的吗?针对报错提示,有什么解决办法吗?非常感谢!
File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/framework/ops.py", line 1939, in _create_c_op raise ValueError(e.message) ValueError: Exception encountered when calling layer "batch_normalization" (type BatchNormalization). Shape must be rank 4 but is rank 7 for '{{node batch_normalization/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=1.1e-05, exponential_avg_factor=1, is_training=false](Placeholder, batch_normalization/ReadVariableOp, batch_normalization/ReadVariableOp_1, batch_normalization/FusedBatchNormV3/ReadVariableOp, batch_normalization/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [1,1,1,?,16,?,64], [64], [64], [64], [64]. Call arguments received: • inputs=tf.Tensor(shape=(1, 1, 1, None, 16, None, 64), dtype=float32) • training=None
The text was updated successfully, but these errors were encountered:
No branches or pull requests
由于希望将模型部署到端侧非 GPU 设备,需要将模型转换为 int8 的 tflite 模型,因此采用 Keras 默认的 QAT(量化感知训练)实现方式。
但运行QAT训练就会报以下错误,大家有遇到类似问题的吗?针对报错提示,有什么解决办法吗?非常感谢!
The text was updated successfully, but these errors were encountered: