python TF Keras如何在加载模型时获得预期的输入形状?

xwbd5t1u  于 2023-02-11  发布在  Python
关注(0)|答案(2)|浏览(141)

是否有可能从“model.h5”文件中获得预期的输入形状?我有两个模型用于相同的数据集,但具有不同的选项和形状。第一个模型需要dim(None,64,48,1),第二个模型需要输入形状(None,128,96,3)。(注:宽度或高度不是固定的,当我再次训练时可以改变)。通道问题很容易“修复”(或者绕过),只需使用try:除了因为只有两个选项(1用于灰度图像,3用于RGB图像):

channels = self.df["channels"][0]
        file = ""
        try:
            images, src_images, data = self.get_images()
            images = self.preprocess_data(images, channels)
            predictions, file = self.load_model(images, file)
            self.predict_data(src_images, predictions, data)
        except:
            if channels == 1:
                print("Except channels =", channels)
                channels = 3
                images, src_images, data = self.get_images()
                images = self.preprocess_data(images, channels)
                predictions = self.load_model(images, file)
                self.predict_data(src_images, predictions, data)
            else:
                channels = 1
                print("Except channels =", channels)
                images, src_images, data = self.get_images()
                images = self.preprocess_data(images, channels)
                predictions = self.load_model(images, file)
                self.predict_data(src_images, predictions, data)

但是这个解决方案不能用于图像的宽度和高度,因为基本上有无限数量的选项。此外,它相当慢,因为我读了所有的数据两次,并没有任何原因地预处理两次。

是否有办法加载model.h5文件并以如下形式打印预期的输入形状?:

[None, 128, 96, 3]
unhi4e5o

unhi4e5o1#

我终于自己找到了答案。

config = model.get_config() # Returns pretty much every information about your model
print(config["layers"][0]["config"]["batch_input_shape"]) # returns a tuple of width, height and channels

这将输出以下内容:

(None, 128, 96, 3)
jrcvhitl

jrcvhitl2#

我发现here的答案更简洁:

model.layers[0].input_shape[0]

我的理解是,用这种方式处理多个输入也应该更容易。

相关问题