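I have a simple test script that just saves and loads a TensorFlow model.
It passes when I run it from Python on my system. But when I run it from a PyInstaller bundle, the save/load round trip fails with `TypeError: Parameter to MergeFrom() must be instance of same class: expected tensorflow.TensorShapeProto got tensorflow.TensorShapeProto.`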
The full script is:
```python
import tempfile
from dataclasses import dataclass

import cv2
import tensorflow as tf


@dataclass
class ColorFilterModel(tf.Module):
    rgb_filter_strengths = (0., 1., 1.)

    def compute_mahalonabis_sq_heatmap(self, image):
        return tf.reduce_sum(tf.cast(image, tf.float32) * tf.constant(self.rgb_filter_strengths), axis=-1)


def test_basic_save_and_load_model():
    print("Testing basic model save/load")
    model = ColorFilterModel()
    numpy_image = cv2.imread(tf.keras.utils.get_file('basalt_canyon.jpg', "https://raw.githubusercontent.com/petered/data/master/images/basalt_canyon.jpg"))
    print(f"Image shape: {numpy_image.shape}")
    with tempfile.TemporaryDirectory() as tmpdirname:
        concrete_func = tf.function(model.compute_mahalonabis_sq_heatmap, input_signature=[tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8)]).get_concrete_function()
        result_original = concrete_func(numpy_image)
        print(f"Result shape: {result_original.shape}")
        tf.saved_model.save(model, tmpdirname, signatures={'mahal': concrete_func})  # THIS LINE FAILS
        loaded_func = tf.saved_model.load(tmpdirname).signatures['mahal']
        result_loaded = loaded_func(image=numpy_image)['output_0']
        assert tf.reduce_all(tf.equal(result_original, result_loaded))
    print("Passed test_basic_save_and_load_model")


if __name__ == "__main__":
    test_basic_save_and_load_model()
```
When run from PyInstaller, it first prints this warning:
```
WARNING:tensorflow:AutoGraph is not available in this environment: functions lack code information. This is typical of some environments like the interactive Python shell. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.
```
and then outputs:
```
Testing basic model save/load
Image shape: (1500, 2000, 3)
2023-05-24 17:20:13.912138: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Result shape: (1500, 2000)
Traceback (most recent call last):
  File "main.py", line 26, in <module>
    test_basic_save_and_load_model()
  File "video_scanner/detection_utils/test_standalone_model_io.py", line 22, in test_basic_save_and_load_model
    tf.saved_model.save(model, tmpdirname, signatures={'mahal': concrete_func})
  File "tensorflow/python/saved_model/save.py", line 1232, in save
  File "tensorflow/python/saved_model/save.py", line 1268, in save_and_return_nodes
  File "tensorflow/python/saved_model/save.py", line 1441, in _build_meta_graph
  File "tensorflow/python/saved_model/save.py", line 1396, in _build_meta_graph_impl
  File "tensorflow/python/saved_model/save.py", line 794, in _fill_meta_graph_def
  File "tensorflow/python/saved_model/save.py", line 607, in _generate_signatures
  File "tensorflow/python/saved_model/save.py", line 474, in _tensor_dict_to_tensorinfo
  File "tensorflow/python/saved_model/save.py", line 475, in <dictcomp>
  File "tensorflow/python/saved_model/utils_impl.py", line 78, in build_tensor_info_internal
TypeError: Parameter to MergeFrom() must be instance of same class: expected tensorflow.TensorShapeProto got tensorflow.TensorShapeProto.
[11509] Failed to execute script 'main' due to unhandled exception: Parameter to MergeFrom() must be instance of same class: expected tensorflow.TensorShapeProto got tensorflow.TensorShapeProto.
```
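The message is confusing because both names are identical. One plausible reading, which the thread does not confirm, is that there are two distinct class objects with the same name. That could happen, for example, if two copies of the generated protobuf module get loaded inside the bundle. The sketch below reproduces that failure mode in plain Python. It uses a hypothetical stand-in class and does not touch protobuf or TensorFlow. It loads the same source file twice as two independent modules, then passes an instance of one copy's class to the other copy's `MergeFrom`.

```python
import importlib.util
import os
import tempfile

# Hypothetical stand-in for a generated *_pb2 module; NOT real protobuf code.
FAKE_PB2_SOURCE = '''
class TensorShapeProto:
    def MergeFrom(self, other):
        # Mimic protobuf's check: the argument must be this exact class object.
        if not isinstance(other, TensorShapeProto):
            raise TypeError(
                "Parameter to MergeFrom() must be instance of same class: "
                f"expected {TensorShapeProto.__name__} got {type(other).__name__}."
            )
'''


def load_fresh_copy(path, module_name):
    """Execute the file at `path` as a new, independent module object."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def reproduce_same_name_mismatch():
    """Return the TypeError message produced by mixing two copies of one class."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "fake_pb2.py")
        with open(path, "w") as f:
            f.write(FAKE_PB2_SOURCE)
        copy_a = load_fresh_copy(path, "fake_pb2_copy_a")
        copy_b = load_fresh_copy(path, "fake_pb2_copy_b")

    # Same source and same class name, but two distinct class objects.
    assert copy_a.TensorShapeProto is not copy_b.TensorShapeProto
    try:
        copy_a.TensorShapeProto().MergeFrom(copy_b.TensorShapeProto())
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce_same_name_mismatch())
    # Parameter to MergeFrom() must be instance of same class: expected TensorShapeProto got TensorShapeProto.
```

Real protobuf reports the full message name (`tensorflow.TensorShapeProto`) rather than the bare class name. The shape of the failure is the same, though: the type check compares class identity, not names.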
Strangely, this only seems to happen on Mac, not on Windows.
If I do the save outside of PyInstaller and only do the load from inside PyInstaller, I instead get an error on load, `AttributeError: as_proto`, which I'm guessing has the same root cause.
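The `TypeError` names the same class on both sides. That is consistent with, though not proof of, several separate copies of that class being loaded. One way to check this from inside the bundle is a small diagnostic. The helper below is hypothetical and not part of TensorFlow, protobuf, or PyInstaller. It walks `sys.modules` and collects every distinct class object bound to a given name at module level. You could call it with `"TensorShapeProto"` after TensorFlow is imported, in both the plain and the PyInstaller run. More than one result in the frozen app would support the duplicate-copy idea. It is only a rough check, because it misses classes that are not bound at module level.

```python
import sys
from types import ModuleType


def find_class_copies(class_name):
    """Return every distinct class object named `class_name` that is bound at
    module level in any currently imported module.

    Re-exports of the same class are counted once (deduplicated by identity),
    so more than one entry means independent class objects share that name.
    """
    distinct = {}
    # Snapshot the values so the loop is safe even if sys.modules changes.
    for module in list(sys.modules.values()):
        if not isinstance(module, ModuleType):
            continue  # sys.modules may hold None or non-module placeholders
        # Read the namespace dict directly so lazy-loading __getattr__
        # hooks on module objects are not triggered.
        candidate = vars(module).get(class_name)
        if isinstance(candidate, type):
            distinct[id(candidate)] = candidate
    return list(distinct.values())


if __name__ == "__main__":
    # Hypothetical usage: run after TensorFlow has been imported.
    for cls in find_class_copies("TensorShapeProto"):
        print(cls.__module__, cls)
```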
What is going on here, and how can I save and load TensorFlow models in a PyInstaller app?
Answer:
My TensorFlow was installed with
When I started a new env and installed
the problem went away.
So it looks like a TensorFlow installed via `conda install -c apple tensorflow` doesn't work well with `tf.saved_model.save`/`tf.saved_model.load` when the program is built with PyInstaller.
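The fix depended on which TensorFlow build the environment contained. It can therefore help to print what an environment actually has before bundling it. The sketch below reports whether the code is running frozen and which versions of a few distributions are visible. PyInstaller's bootloader sets `sys.frozen` and `sys._MEIPASS` in bundled apps. The distribution names are assumptions (pip names) and do not come from the thread. A conda-channel install may not register metadata under the same name. A PyInstaller bundle may also omit package metadata entirely, so `None` inside the bundle does not prove a package is missing. Running it in each unfrozen environment is the more reliable comparison.

```python
import importlib.metadata
import sys

# Distribution names to look up; these pip names are assumptions, not from the thread.
DEFAULT_DISTRIBUTIONS = ("tensorflow", "tensorflow-macos", "protobuf")


def environment_report(distributions=DEFAULT_DISTRIBUTIONS):
    """Summarize whether we are running frozen and which distributions are visible."""
    report = {
        # PyInstaller's bootloader sets sys.frozen (and sys._MEIPASS) in bundled apps.
        "frozen": bool(getattr(sys, "frozen", False)),
        "bundle_dir": getattr(sys, "_MEIPASS", None),
        "python": sys.version.split()[0],
    }
    for name in distributions:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # Not installed, or (inside a bundle) its metadata was not collected.
            report[name] = None
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```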