TensorFlow：将 LSTM 模型转换为 TFLite

tktrz96b  于 2023-06-30  发布在  其他
关注(0)|答案(1)|浏览(273)

我在将 LSTM 模型转换为 TFLite 时遇到了一个问题。
我转换这个模型，是为了在我的 Flutter 应用程序中使用它。
该模型用于检测和翻译印度手语。
下面是我的转换代码：

import tensorflow as tf  # Reproduction script from the question; comments are inline so the traceback's line numbers (4, 5, 7) still match.
from keras.models import load_model  # NOTE(review): standalone `keras`, not `tf.keras` — confirm both come from the same installed version.
model=load_model("action.h5")  # Load the trained model (HDF5 file, judging by the .h5 extension).
tf.keras.models.save_model(model,'model.pbtxt')  # Writes a SavedModel *directory* named model.pbtxt (log: "Assets written to: model.pbtxt\assets"); the conversion below does not use it.
converter =tf.lite.TFLiteConverter.from_keras_model(model=model)  # No target_spec.supported_ops set, so Select TF ops are not enabled.

lite_model=converter.convert()  # Fails: 'tf.TensorListReserve' needs a static element_shape; the error suggests supported_ops=[TFLITE_BUILTINS, SELECT_TF_OPS] and _experimental_lower_tensor_list_ops=False.
with open("lite_model.tflite","wb") as f:  # Binary mode: convert() returns the serialized model as bytes.
    f.write(lite_model)

运行这段代码时，会出现以下错误：

INFO:tensorflow:Assets written to: model.pbtxt\assets
INFO:tensorflow:Assets written to: model.pbtxt\assets
INFO:tensorflow:Assets written to: C:\Users\gk\AppData\Local\Temp\tmp6276n3rh\assets
INFO:tensorflow:Assets written to: C:\Users\gk\AppData\Local\Temp\tmp6276n3rh\assets
---------------------------------------------------------------------------
ConverterError                            Traceback (most recent call last)
Input In [73], in <cell line: 7>()
      4 tf.keras.models.save_model(model,'model.pbtxt')
      5 converter =tf.lite.TFLiteConverter.from_keras_model(model=model)
----> 7 lite_model=converter.convert()
      8 with open("lite_model.tflite","wb") as f:
      9     f.write(lite_model)

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\lite.py:929, in _export_metrics.<locals>.wrapper(self, *args, **kwargs)
    926 @functools.wraps(convert_func)
    927 def wrapper(self, *args, **kwargs):
    928   # pylint: disable=protected-access
--> 929   return self._convert_and_export_metrics(convert_func, *args, **kwargs)

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\lite.py:908, in TFLiteConverterBase._convert_and_export_metrics(self, convert_func, *args, **kwargs)
    906 self._save_conversion_params_metric()
    907 start_time = time.process_time()
--> 908 result = convert_func(self, *args, **kwargs)
    909 elapsed_time_ms = (time.process_time() - start_time) * 1000
    910 if result:

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\lite.py:1338, in TFLiteKerasModelConverterV2.convert(self)
   1325 @_export_metrics
   1326 def convert(self):
   1327   """Converts a keras model based on instance variables.
   1328 
   1329   Returns:
   (...)
   1336       Invalid quantization parameters.
   1337   """
-> 1338   saved_model_convert_result = self._convert_as_saved_model()
   1339   if saved_model_convert_result:
   1340     return saved_model_convert_result

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\lite.py:1321, in TFLiteKerasModelConverterV2._convert_as_saved_model(self)
   1317   graph_def, input_tensors, output_tensors = (
   1318       self._convert_keras_to_saved_model(temp_dir))
   1319   if self.saved_model_dir:
   1320     return super(TFLiteKerasModelConverterV2,
-> 1321                  self).convert(graph_def, input_tensors, output_tensors)
   1322 finally:
   1323   shutil.rmtree(temp_dir, True)

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\lite.py:1131, in TFLiteConverterBaseV2.convert(self, graph_def, input_tensors, output_tensors)
   1126   logging.info("Using new converter: If you encounter a problem "
   1127                "please file a bug. You can opt-out "
   1128                "by setting experimental_new_converter=False")
   1130 # Converts model.
-> 1131 result = _convert_graphdef(
   1132     input_data=graph_def,
   1133     input_tensors=input_tensors,
   1134     output_tensors=output_tensors,
   1135     **converter_kwargs)
   1137 return self._optimize_tflite_model(
   1138     result, self._quant_mode, quant_io=self.experimental_new_quantizer)

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\convert_phase.py:212, in convert_phase.<locals>.actual_decorator.<locals>.wrapper(*args, **kwargs)
    210   else:
    211     report_error_message(str(converter_error))
--> 212   raise converter_error from None  # Re-throws the exception.
    213 except Exception as error:
    214   report_error_message(str(error))

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\convert_phase.py:205, in convert_phase.<locals>.actual_decorator.<locals>.wrapper(*args, **kwargs)
    202 @functools.wraps(func)
    203 def wrapper(*args, **kwargs):
    204   try:
--> 205     return func(*args, **kwargs)
    206   except ConverterError as converter_error:
    207     if converter_error.errors:

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\convert.py:794, in convert_graphdef(input_data, input_tensors, output_tensors, **kwargs)
    791   else:
    792     model_flags.output_arrays.append(util.get_tensor_name(output_tensor))
--> 794 data = convert(
    795     model_flags.SerializeToString(),
    796     conversion_flags.SerializeToString(),
    797     input_data.SerializeToString(),
    798     debug_info_str=debug_info.SerializeToString() if debug_info else None,
    799     enable_mlir_converter=enable_mlir_converter)
    800 return data

File ~\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\lite\python\convert.py:311, in convert(model_flags_str, conversion_flags_str, input_data_str, debug_info_str, enable_mlir_converter)
    309     for error_data in _metrics_wrapper.retrieve_collected_errors():
    310       converter_error.append_error(error_data)
--> 311     raise converter_error
    313 return _run_deprecated_conversion_binary(model_flags_str,
    314                                          conversion_flags_str, input_data_str,
    315                                          debug_info_str)

ConverterError: C:\Users\gk\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\saved_model\save.py:1325:0: error: 'tf.TensorListReserve' op requires element_shape to be static during TF Lite transformation pass
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
C:\Users\gk\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\saved_model\save.py:1325:0: error: failed to legalize operation 'tf.TensorListReserve' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
<unknown>:0: error: Lowering tensor list ops is failed. Please consider using Select TF ops and disabling `_experimental_lower_tensor_list_ops` flag in the TFLite converter object. For example, converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\n converter._experimental_lower_tensor_list_ops = False

!python --version

错误是在 `converter.convert()` 这一步抛出的。我是深度学习新手，尝试过很多其他方法，但都出现了同样的错误。
如果这个错误无法解决，请告诉我该怎么办……是否有其他模型既能有效地检测手语，又能在 Flutter 应用程序中使用？

hwamh0ep

hwamh0ep1#

# Convert a Keras LSTM model to TFLite, allowing TensorFlow ("Select") ops for
# the tensor-list ops the builtin TFLite op set cannot lower.
import tensorflow as tf

model = tf.keras.models.load_model('./model_save/best.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Optional: default post-training optimization (shrinks the model).
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Fall back to TensorFlow ops for anything without a TFLite builtin, such as
# the 'tf.TensorListReserve' op named in the ConverterError.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
# Keep tensor-list ops as Select TF ops instead of lowering them; this is the
# second setting the converter's error message recommends.
converter._experimental_lower_tensor_list_ops = False
tflite_model = converter.convert()

# NOTE: a model containing Select TF ops needs the TensorFlow Select-ops
# (Flex) runtime on the device — confirm the Flutter TFLite plugin in use
# ships or can load it.
with open("best63.tflite", 'wb') as f:
    f.write(tflite_model)

我用的是上面这段代码，可以正常转换！关键在于启用 `SELECT_TF_OPS`，这正是错误信息中建议的做法。

相关问题