Looking for a ResNet implementation in TensorFlow

biswetbf · posted 2023-08-06 in Other

Is there a ResNet implementation in TensorFlow? I came across a few (e.g. https://github.com/ry/tensorflow-resnet, https://github.com/xuyuwei/resnet-tf), but these implementations have some bugs (e.g. see the Issues sections on the corresponding GitHub pages). I want to train ResNet on ImageNet and am looking for a TensorFlow implementation.

tyu7yeag1#

There are some in TensorFlow (50/101/152): models/slim.
The example notebook shows how to get a pre-trained Inception network running; ResNet is probably not much different.
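
For what it's worth, a more recent option (not mentioned in this answer) is tf.keras.applications, which ships pre-trained ResNet50/101/152 weights. A minimal sketch of loading one and classifying an image; the file name img.jpg is just a placeholder:

import tensorflow as tf

# Load ResNet50 with ImageNet weights (ResNet101/ResNet152 work the same way)
model = tf.keras.applications.ResNet50(weights="imagenet")

# Classify a single image; "img.jpg" is a placeholder path
img = tf.keras.utils.load_img("img.jpg", target_size=(224, 224))
x = tf.expand_dims(tf.keras.utils.img_to_array(img), 0)
x = tf.keras.applications.resnet50.preprocess_input(x)
preds = model.predict(x)
print(tf.keras.applications.resnet50.decode_predictions(preds, top=3)[0])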

qv7cva1a2#

I implemented a CIFAR-10 version of ResNet in TensorFlow. The validation errors of ResNet-32, ResNet-56 and ResNet-110 are 6.7%, 6.5% and 6.2% respectively. (You can easily change the number of layers; it is a hyperparameter.)
I tried to be friendly to people new to ResNet and wrote everything in a straightforward way. You can run the cifar10_train.py file directly, without any downloads.
https://github.com/wenxinxu/resnet_in_tensorflow
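
(Side note, not from the original answer: in the ResNet paper the CIFAR-10 variants have depth 6n+2, which is where the 32/56/110 layer counts above come from.)

for n in (5, 9, 18):
    # 6n + 2: 3 stages of n blocks, 2 convs per block, plus the first conv and the FC layer
    print(f"n={n} -> ResNet-{6 * n + 2}")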

e5njpo683#

Please find below the code for a custom ResNet-34 implementation. You can use this model as the basis for your image classification model.

# Dependencies

import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    GlobalAveragePooling2D, Dense, Layer, MaxPooling2D,
    Activation, Conv2D, Add, BatchNormalization
)

CONFIGURATIONS = {
    "NUM_CLASSES": 3
}

# Custom model definition inherited from the Model class
class Resnet34(Model):

    def __init__(self):
        super(Resnet34, self).__init__(name="resnet_34")

        # Stem: 7x7 conv, stride 2, followed by 3x3 max pooling
        self.conv_1 = CustomConv2D(64, 7, 2, padding="same")
        self.max_pool = MaxPooling2D(3, 2)

        # conv2_x: three 64-channel residual blocks
        self.conv_2_1 = ResidualBlock(64)
        self.conv_2_2 = ResidualBlock(64)
        self.conv_2_3 = ResidualBlock(64)

        # conv3_x: four 128-channel residual blocks (stride 2 for downsampling)
        self.conv_3_1 = ResidualBlock(128, 2)
        self.conv_3_2 = ResidualBlock(128)
        self.conv_3_3 = ResidualBlock(128)
        self.conv_3_4 = ResidualBlock(128)

        # conv4_x: six 256-channel residual blocks (stride 2 for downsampling)
        self.conv_4_1 = ResidualBlock(256, 2)
        self.conv_4_2 = ResidualBlock(256)
        self.conv_4_3 = ResidualBlock(256)
        self.conv_4_4 = ResidualBlock(256)
        self.conv_4_5 = ResidualBlock(256)
        self.conv_4_6 = ResidualBlock(256)

        # conv5_x: three 512-channel residual blocks (stride 2 for downsampling)
        self.conv_5_1 = ResidualBlock(512, 2)
        self.conv_5_2 = ResidualBlock(512)
        self.conv_5_3 = ResidualBlock(512)

        # Global average pooling and softmax classifier head
        self.global_pool = GlobalAveragePooling2D()
        self.fc_3 = Dense(CONFIGURATIONS["NUM_CLASSES"], activation="softmax")

    def call(self, x, training=True):

        x = self.conv_1(x, training=training)
        x = self.max_pool(x)

        x = self.conv_2_1(x, training=training)
        x = self.conv_2_2(x, training=training)
        x = self.conv_2_3(x, training=training)

        x = self.conv_3_1(x, training=training)
        x = self.conv_3_2(x, training=training)
        x = self.conv_3_3(x, training=training)
        x = self.conv_3_4(x, training=training)

        x = self.conv_4_1(x, training=training)
        x = self.conv_4_2(x, training=training)
        x = self.conv_4_3(x, training=training)
        x = self.conv_4_4(x, training=training)
        x = self.conv_4_5(x, training=training)
        x = self.conv_4_6(x, training=training)

        x = self.conv_5_1(x, training=training)
        x = self.conv_5_2(x, training=training)
        x = self.conv_5_3(x, training=training)

        x = self.global_pool(x)

        return self.fc_3(x)

# Custom Conv2D layer (Conv2D + BatchNormalization) inherited from the Layer class
class CustomConv2D(Layer):

    def __init__(self, n_filters, kernel_size, n_strides, padding="valid"):
        super(CustomConv2D, self).__init__(name="custom_conv2D")

        self.conv = Conv2D(
            filters=n_filters,
            kernel_size=kernel_size,
            activation="relu",
            strides=n_strides,
            padding=padding
        )
        self.batch_norm = BatchNormalization()

    def call(self, x, training=True):

        x = self.conv(x)
        x = self.batch_norm(x, training=training)

        return x

# Custom residual block inherited from the Layer class
class ResidualBlock(Layer):

    def __init__(self, n_channels, n_strides=1):
        super(ResidualBlock, self).__init__(name="res_block")

        # A projection ("dotted") shortcut is needed whenever the block downsamples
        self.dotted = (n_strides != 1)

        self.custom_conv_1 = CustomConv2D(n_channels, 3, n_strides, padding="same")
        self.custom_conv_2 = CustomConv2D(n_channels, 3, 1, padding="same")

        self.activation = Activation("relu")

        if self.dotted:
            self.custom_conv_3 = CustomConv2D(n_channels, 1, n_strides)  # 1x1 conv on the shortcut

    def call(self, inputs, training=True):

        x = self.custom_conv_1(inputs, training=training)
        x = self.custom_conv_2(x, training=training)

        if self.dotted:
            shortcut = self.custom_conv_3(inputs, training=training)
            x_add = Add()([x, shortcut])
        else:
            x_add = Add()([x, inputs])

        return self.activation(x_add)

# Build the model by calling it once on a dummy batch, then print a summary
resnet_34 = Resnet34()
resnet_34(tf.zeros([1, 256, 256, 3]), training=False)
resnet_34.summary()
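
If you want to train this model, a minimal compile/fit sketch might look like the following. Here train_ds is a hypothetical tf.data.Dataset of (image, label) batches with 256x256x3 images and integer labels below NUM_CLASSES; it is not part of the code above.

# Hypothetical training setup for the Resnet34 model defined above
resnet_34.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# resnet_34.fit(train_ds, epochs=10)  # train_ds is an assumed dataset pipeline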

htzpubme4#

I implemented ResNet with ronnie.ai and Keras. Both tools are great, and ronnie makes it easier to start from scratch.
