python 如何量化优化tflite模型的输入和输出

xienkqul  于 2023-04-10  发布在  Python
关注(0)|答案(3)|浏览(314)

我使用下面的代码生成一个量化的tflite模型

import tensorflow as tf

def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    # Get sample input data as a numpy array in a method of your choosing.
    yield [input]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

根据post training quantization
结果模型将被完全量化,但为了方便起见,仍然采用浮点输入和输出
要编译tflite模型谷歌珊瑚边缘TPU我需要量化的输入和输出以及。
在模型中,我看到第一个网络层将float输入转换为input_uint8,最后一个层将output_uint8转换为float输出。如何编辑tflite模型以摆脱第一个和最后一个float层?
我知道我可以在转换过程中将输入和输出类型设置为uint8,但这与任何优化都不兼容。唯一可用的选择是使用假量化,这会导致糟糕的模型。

q0qdq0h2

q0qdq0h21#

你可以通过将inference_input_type和inference_output_type(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/lite.py#L460-L476)设置为int 8来避免浮点数变成int 8和int 8变成浮点数“quant/dequant”操作。

flvtvl50

flvtvl502#

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir
converter.optimizations = [tf.lite.Optimize.DEFAULT] 
converter.representative_dataset = representative_dataset
#The below 3 lines performs the input - output quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
bf1o4zei

bf1o4zei3#

这一点:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quant = converter.convert()

生成具有Float32输入和输出的Float32模型。这:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()

生成具有UINT8输入和输出的UINT8模型
您可以通过以下方式确保这一点:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

其返回:

input:  <class 'numpy.uint8'>
output:  <class 'numpy.uint8'>

如果你想要一个完整的UINT8量化,你可以通过使用netron可视化地检查你的模型来进行双重检查

相关问题