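I'm trying to use the Keras implementation of DeepLabV3+ for a binary segmentation task with a custom number of input channels (e.g. 4 instead of 3).
What I have changed so far:
- I know I can't use the ImageNet weights (trained on 3 channels), so I set the weights to None (random initialization).
- I changed the input shape for the 4-channel case: keras.Input(shape=(image_size, image_size, 4))
- Since this is a binary segmentation task, I changed the output layer to use a sigmoid activation: layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same", activation='sigmoid')(x)
Here is the original code, in case the link becomes unavailable: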
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    # Conv2D -> BatchNormalization -> ReLU
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    dims = dspp_input.shape
    # Image-level pooling branch, upsampled back to the input resolution
    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
    x = convolution_block(x, kernel_size=1, use_bias=True)
    out_pool = layers.UpSampling2D(
        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
    )(x)
    # Parallel atrous convolutions with dilation rates 1, 6, 12 and 18
    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
    output = convolution_block(x, kernel_size=1)
    return output


def DeeplabV3Plus(image_size, num_classes):
    model_input = keras.Input(shape=(image_size, image_size, 3))
    resnet50 = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=model_input
    )
    # High-level features (stride 16) feed the ASPP module
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # Low-level features (stride 4) for the decoder skip connection
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # No activation: outputs raw per-class logits
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    return keras.Model(inputs=model_input, outputs=model_output)


# IMAGE_SIZE and NUM_CLASSES are defined elsewhere in the original example
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
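The size=(...) arguments of the two UpSampling2D layers depend on the resolution of the two ResNet50 feature maps the model taps. The arithmetic can be sketched in plain Python, assuming square inputs and assuming each ResNet50 stage yields ceil(size / stride) pixels (stride 4 at conv2_block3_2_relu, stride 16 at conv4_block6_2_relu). deeplab_spatial_sizes is a hypothetical helper, not part of the Keras example:

```python
import math


def deeplab_spatial_sizes(image_size):
    """Hypothetical helper: feature-map sizes and upsampling factors.

    Assumes ResNet50 reduces a square input to ceil(size / stride) pixels:
    stride 4 at conv2_block3_2_relu, stride 16 at conv4_block6_2_relu.
    """
    low_level = math.ceil(image_size / 4)    # conv2_block3_2_relu (skip connection)
    high_level = math.ceil(image_size / 16)  # conv4_block6_2_relu (ASPP input)
    # Same integer arithmetic as the UpSampling2D size=(...) arguments
    aspp_upsample = image_size // 4 // high_level
    final_upsample = image_size // low_level
    return {
        "low_level": low_level,
        "high_level": high_level,
        "aspp_upsample": aspp_upsample,
        "final_upsample": final_upsample,
        # Concatenate needs the upsampled ASPP output to match the skip branch
        "concat_ok": high_level * aspp_upsample == low_level,
    }


print(deeplab_spatial_sizes(512))
print(deeplab_spatial_sizes(500)["concat_ok"])
```

Under these assumptions, image_size=512 gives factors of 4 and 4, and both decoder inputs line up at 128x128. With 500, the upsampled ASPP output is 96 pixels wide while the skip branch is 125, so the Concatenate layer could not join them; image_size should be a multiple of 16. The number of input channels does not affect any of these sizes.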
Here is my code with the changes for 4-channel input and binary output:
def DeeplabV3Plus(image_size, num_classes=1):
    # 4 input channels instead of 3
    model_input = keras.Input(shape=(image_size, image_size, 4))
    # weights=None: the ImageNet weights expect 3-channel input
    resnet50 = keras.applications.ResNet50(
        weights=None, include_top=False, input_tensor=model_input
    )
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # Sigmoid output: one foreground probability per pixel
    model_output = layers.Conv2D(
        num_classes, kernel_size=(1, 1), padding="same", activation='sigmoid'
    )(x)
    return keras.Model(inputs=model_input, outputs=model_output)


# For a single-channel sigmoid mask, NUM_CLASSES should be 1
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
However, I get the following error:
InvalidArgumentError: Graph execution error:
Detected at node 'assert_greater_equal/Assert/AssertGuard/Assert'
and
Node: 'assert_greater_equal/Assert/AssertGuard/Assert'
2 root error(s) found.
(0) INVALID_ARGUMENT: assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model_5/conv2d_59/Sigmoid:0) = ] [[[[nan][nan][nan]]]...] [y (Cast_8/x:0) = ] [0]
[[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
[[assert_greater_equal_1/Assert/AssertGuard/pivot_f/_23/_65]]
(1) INVALID_ARGUMENT: assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model_5/conv2d_59/Sigmoid:0) = ] [[[[nan][nan][nan]]]...] [y (Cast_8/x:0) = ] [0]
[[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
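A sigmoid output can only lie in [0, 1], so the failing "predictions must be >= 0" check (most likely a metric's internal range assertion) together with the [[nan]...] values in the log points to NaN predictions rather than negative ones: NaN compares False against every number. A minimal NumPy sketch of how a single NaN in one input channel propagates through a toy 1x1 "convolution" (a dot product over the 4 channels) and a sigmoid; the pixel and weight values are made up for illustration. NaN weights from a diverging training run would propagate the same way.

```python
import numpy as np


def sigmoid(x):
    # Stand-in for the final activation='sigmoid' layer
    return 1.0 / (1.0 + np.exp(-x))


# Healthy logits always map into [0, 1]
healthy = sigmoid(np.array([-3.0, 0.0, 3.0]))
print(np.all(healthy >= 0))  # True

# One corrupted channel is enough to poison the prediction
pixel = np.array([0.2, 0.5, 0.9, np.nan])  # 4 channels, last one is NaN
weights = np.array([0.3, -0.1, 0.4, 0.2])  # toy 1x1 convolution weights
prob = sigmoid(pixel @ weights)

print(np.isnan(prob))  # True
print(prob >= 0)       # False: NaN fails every comparison
print(prob <= 1)       # False
```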
Note: my images and masks are scaled to 0-1.
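The question doesn't say how the four channels were scaled to 0-1. Assuming per-channel min-max scaling, one way such an image can still contain NaN is a constant channel (for example an empty extra band): max - min is 0, and the division becomes 0/0. A sketch of a guarded version; scale_channels_to_unit_range and the constant fourth channel are hypothetical:

```python
import numpy as np


def scale_channels_to_unit_range(image, eps=1e-8):
    """Min-max scale each channel of an (H, W, C) image to [0, 1].

    A constant channel has max - min == 0; the denominator is clamped
    to eps so such a channel becomes all zeros instead of NaN.
    """
    image = image.astype(np.float32)
    ch_min = image.min(axis=(0, 1), keepdims=True)
    ch_max = image.max(axis=(0, 1), keepdims=True)
    scale = np.maximum(ch_max - ch_min, eps)
    return (image - ch_min) / scale


rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(4, 4, 4))
img[..., 3] = 7.0  # hypothetical constant 4th channel

with np.errstate(invalid="ignore"):  # silence the 0/0 warning for the demo
    lo, hi = img.min(axis=(0, 1)), img.max(axis=(0, 1))
    naive = (img - lo) / (hi - lo)
print(np.isnan(naive[..., 3]).all())  # True

safe = scale_channels_to_unit_range(img)
print(np.isnan(safe).any(), safe.min() >= 0, safe.max() <= 1)  # False True True
```

Note that this guard does not help if the raw image already contains NaN; such values survive any scaling and have to be caught separately.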
Answer:
The code above is a correct implementation of a 4-channel DeepLabV3+ for binary segmentation in Keras. My error was caused by badly formatted 4-channel input data.
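The answer doesn't say what exactly was wrong with the data. A sketch of sanity checks one could run on a batch (converted to NumPy) before calling fit, assuming images of shape (N, H, W, 4) scaled to [0, 1] and binary masks of shape (N, H, W) or (N, H, W, 1); check_batch is a hypothetical helper:

```python
import numpy as np


def check_batch(images, masks, num_channels=4):
    """Hypothetical helper: return a list of problems in one batch.

    Assumes images (N, H, W, C) in [0, 1] and masks (N, H, W) or
    (N, H, W, 1) containing only 0 and 1. An empty list means no
    problems were found.
    """
    problems = []
    if images.ndim != 4 or images.shape[-1] != num_channels:
        problems.append(f"images shape {images.shape}, expected (N, H, W, {num_channels})")
    if not np.isfinite(images).all():
        problems.append("images contain NaN or inf")
    elif images.min() < 0 or images.max() > 1:
        problems.append(f"images range [{images.min()}, {images.max()}], expected [0, 1]")
    if masks.shape[:3] != images.shape[:3]:
        problems.append(f"masks shape {masks.shape} does not match images {images.shape}")
    if not np.isin(masks, (0, 1)).all():
        problems.append("masks contain values other than 0 and 1")
    return problems


good_x = np.random.default_rng(0).uniform(size=(2, 8, 8, 4)).astype(np.float32)
good_y = (good_x[..., :1] > 0.5).astype(np.float32)
print(check_batch(good_x, good_y))  # []

bad_x = good_x.copy()
bad_x[0, 0, 0, 3] = np.nan  # one corrupted value in the 4th channel
print(check_batch(bad_x, good_y))  # ['images contain NaN or inf']
```

A single NaN like the one in bad_x is enough to produce NaN predictions and trip the assertion shown in the question.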