keras tensorflow 版本的Pytorch变换

xe55xuns  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(177)

pytorch模型中执行推理之前,我使用以下代码准备图像:

def image_loader(transform, image_name):
    image = Image.open(image_name)
    #transform
    image = transform(image).float()
    image = torch.tensor(image)
    image = image.unsqueeze(0)
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

我已经将模型转换为Tensorflow模型,但是,我不确定在推理之前如何对图像进行类似的转换,因为似乎没有tensorflowkeras的等价物。有什么建议吗?

gojuced7

gojuced71#

这是一些指针,在pytorch中,您有

from torchvision import transforms
from PIL import Image 
import torch 

def image_loader(transform, image_name):
    image = Image.open(image_name).convert('RGB')
    image = transform(image).float()
    image = torch.tensor(image)
    image = image.unsqueeze(0)
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# check: visualize 
i = image_loader(data_transforms, '/content/1.png')
i.shape

plt.figure(figsize=(25,10))
subplot(121); imshow(np.array(i[0]).transpose(1, 2, 0));

tensorflow中,您可以通过以下方式实现这一点

def transform(image, mean, std):
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) \
            / std[channel]
    return image

def image_loader(image_name):
    image = Image.open(image_name).convert('RGB')
    image = transform(np.array(image) / 255, mean=[0.485, 0.456,
                      0.406], std=[0.229, 0.224, 0.225])
    image = tf.cast(image, tf.float32)
    image = tf.expand_dims(image, 0)
    return image

# check: visualize 
i = image_loader('/content/1.png')
i.shape 

plt.figure(figsize=(25,10))
subplot(121); imshow(i[0]);

**这应该输出相同的结果。**注意,在第二种情况下,我们定义了transform函数,来自另一个OP,here,这很好,但是,您也可以检查tf. keras...Normalization,有关详细信息,请参见this answer

相关问题