Problem saving a Keras model (.h5) in TensorFlow SavedModel format (KeyError: 'inputs')

9rnv2umw posted 12 months ago in Other

In TF 2.4.0, I am training a Keras RetinaNet model (code from https://github.com/fizyr/keras-retinanet). After training finishes, I want to convert model.h5 to the TensorFlow SavedModel format, but I get the error KeyError: 'inputs'.
Conversion code:

# Import libraries
import tensorflow as tf
from tensorflow import keras
from keras_retinanet import models
from keras_retinanet.models import load_model

# Load the model
model = load_model("model.h5", backbone_name="resnet50")

# Save the model
model.save('model_tf', save_format='tf')

Error: KeyError: 'inputs'

Traceback (most recent call last):
  File "convert_h5_2_pb.py", line 11, in <module>
    model.save('model_tf', save_format='tf') 
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2001, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 156, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 89, in save
    save_lib.save(model, filepath, signatures, options)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1032, in save
    _, exported_graph, object_saver, asset_info = _build_meta_graph(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1198, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1132, in _build_meta_graph_impl
    signatures = signature_serialization.find_function_to_export(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 150, in list_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2612, in _list_functions_for_serialization
    functions = super(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3086, in _list_functions_for_serialization
    return (self._trackable_saved_model_saver
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 94, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 78, in functions_to_serialize
    return (self._get_serialized_attributes(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 56, in _get_serialized_attributes_internal
    super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 155, in wrap_layer_functions
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 273, in _replace_child_layer_functions
    child_layer._trackable_saved_model_saver._get_serialized_attributes(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 94, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 104, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 163, in wrap_layer_functions
    call_fn_with_losses = call_collection.add_function(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 505, in add_function
    self.add_trace(*self._input_signature)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 420, in add_trace
    trace_with_training(True)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 418, in trace_with_training
    fn.get_concrete_function(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 550, in get_concrete_function
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 515, in wrapper
    inputs = call_collection.get_input_arg_value(args, kwargs)
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 454, in get_input_arg_value
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
  File "/home/egorundel/venvs/test_venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2603, in _get_call_arg_value
    return args_dict[arg_name]
KeyError: 'inputs'

What can I do to fix this?
I searched online and tried changing the code, but it did not help.


acruukt9 #1

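First, check which Keras version you are using. You wrote that you are using TensorFlow 2.4, but the latest version is 2.14. The repo also says "This project should work with keras 2.4 and tensorflow 2.3.0, newer versions might break support." That is fine for Keras, but I have some doubts about TensorFlow.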
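A minimal sketch for checking which versions are actually installed in the active environment, using only the standard library. The helper name `installed_versions` is hypothetical, and the distribution names listed are assumptions about how the packages were installed (for example, TensorFlow may be present under a different distribution name such as a GPU build):

```python
from importlib import metadata


def installed_versions(packages):
    """Return a dict mapping each distribution name to its version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust to match your environment.
    names = ["tensorflow", "keras", "keras-retinanet", "keras-resnet"]
    for name, version in installed_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing the printed versions against the ones the repo names (keras 2.4, tensorflow 2.3.0) shows quickly whether the environment is newer than what the project says it should work with.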
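Try saving it as an inference model (stripping out everything that is only used for training should make it easier to save). I suggest converting it like this: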

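from keras_retinanet import models

# Load the training model saved as .h5 (same file and backbone as in the question)
model = models.load_model("model.h5", backbone_name="resnet50")

# Convert the training model to an inference model
inference_model = models.convert_model(model)

# Save the model in SavedModel format
inference_model.save('model_tf')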

dzhpxtsq #2

The solution has been found!
You need to change the file keras_resnet/layers/_batch_normalization.py in the pip-installed package, using the lines of code described here: github.com/broadinstitute/keras-resnet/commit/73c50f
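The traceback in the question ends in `_get_call_arg_value`, where `args_dict[arg_name]` fails for the name `'inputs'`. Below is a minimal sketch of why a by-name lookup like that can fail, assuming the affected layer declares `call(self, *args, **kwargs)` instead of an explicit `inputs` parameter. That is an assumption, since the linked commit is not reproduced here. `get_call_arg_value`, `FrozenLayerOld` and `FrozenLayerNew` are hypothetical names for illustration; this is plain Python, not TensorFlow's actual implementation:

```python
import inspect


def get_call_arg_value(fn, arg_name, args, kwargs):
    """Simplified stand-in for a by-name argument lookup (not TensorFlow's code)."""
    if arg_name in kwargs:
        return kwargs[arg_name]
    # Names of the explicitly declared positional parameters, skipping `self`.
    names = [
        p.name
        for p in inspect.signature(fn).parameters.values()
        if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
    ][1:]
    args_dict = dict(zip(names, args))
    return args_dict[arg_name]  # KeyError if `arg_name` was never declared


class FrozenLayerOld:
    # Hypothetical layer whose call() hides its arguments behind *args.
    def call(self, *args, **kwargs):
        return args[0]


class FrozenLayerNew:
    # Hypothetical layer with an explicit `inputs` parameter.
    def call(self, inputs, training=None, **kwargs):
        return inputs


if __name__ == "__main__":
    print(get_call_arg_value(FrozenLayerNew.call, "inputs", ("x",), {}))
    try:
        get_call_arg_value(FrozenLayerOld.call, "inputs", ("x",), {})
    except KeyError as err:
        print("KeyError:", err)
```

With the explicit signature the lookup finds `inputs`; with `*args` there is no declared parameter of that name, so the lookup raises `KeyError: 'inputs'`, the same error the question shows. If the installed keras_resnet layer has this shape, giving its `call` method an explicit `inputs` parameter is the kind of change to look for in the linked commit.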
