Keras和Pytorch Conv2D使用相同的权重给予不同的结果

wbrvyc0a  于 2023-04-12  发布在  其他
关注(0)|答案(2)|浏览(202)

我的代码:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D
import torch, torchvision
import torch.nn as nn
import numpy as np

# Define the PyTorch layer
pt_layer = torch.nn.Conv2d(3, 12, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

# Get the weight tensor from the PyTorch layer
pt_weights = pt_layer.weight.detach().numpy()

# Create the equivalent Keras layer
keras_layer = Conv2D(12, kernel_size=(3, 3), strides=(2, 2), padding='same', use_bias=False, input_shape=(None, None, 3))

# Build the Keras layer to initialize its weights
keras_layer.build((None, None, None, 3))

# Transpose the PyTorch weights to match the expected shape of the Keras layer
keras_weights = pt_weights.transpose(2, 3, 1, 0)

# Set the weights of the Keras layer to the PyTorch weights
keras_layer.set_weights([keras_weights])

#Test both models
arr = np.random.normal(0,1,(1, 3, 224, 224))

print(pt_layer(torch.from_numpy(arr).float())[0,0])
print(keras_layer(arr.transpose(0,2,3,1))[0,:,:,0])

我希望这两个打印非常相似,但实际上是不同的。我在Colab上运行了它,以确保它不是由于旧的Pytorch/Keras版本。我肯定我错过了一些琐碎的东西,但我找不到它...欢迎任何帮助,请。

smtd7mpg

smtd7mpg1#

找到了答案:Keras和Pytorch中的填充似乎完全不同。要修复,请使用ZeroPadding2D:

keras_layer = tf.keras.Sequential([
    ZeroPadding2D(padding=(1, 1)),
    Conv2D(12, kernel_size=(3, 3), strides=(2, 2), padding='valid', use_bias=False, input_shape=(None, None, 3))
])
rekjcdws

rekjcdws2#

对上述内容进行补充:对于那些好奇的人来说,它与如何根据图像大小计算填充有关。输出图像的高度计算为
output height = (input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
因此,对于大小为5的图像,大小为3的内核,步长为2,我们得到
output height = (5 + 1 + 1 - 3) / 2 + 1 = 3
这是一个整数。当输出不是整数时,PyTorch和Keras的行为不同。例如,在上面的例子中,目标图像大小将是122.5,将被舍入为122。
PyTorch,不管舍入与否,总是会在所有侧面添加填充(由于层定义)。另一方面,Keras不会在图像的顶部和左侧添加填充,导致卷积从图像的原始左上角开始,而不是填充的那个,给出不同的结果。
示意性地,如果我们有一个4x4的图像,它看起来像这样:(其中xs表示卷积,0表示填充,点表示“缺失”填充):

PyTorch       Keras
xxx000        ......
xxx110        .xxx10
xxx110        .xxx10
011110        .xxx10
011110        .11110
000000        .00000

相关问题