
tensorflow - TF2.1: SegNet model architecture problem. Metric calculation is wrong: it stays constant and converges to a fixed value

Reposted. Author: 行者123. Updated: 2023-12-03 14:53:53

I am building a custom model (SegNet) in TensorFlow 2.1.0.

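The first problem I ran into is reusing the indices of the max-pooling operation, which the paper requires.
Since SegNet is an encoder-decoder architecture, the decoder needs the pooling indices from the encoder to upsample the feature maps and place each value back at its corresponding index.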

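In TF, however, the tf.keras.layers.MaxPool2D layer does not export these indices by default (as PyTorch does, for example).
To get the indices of a max-pooling operation you have to use tf.nn.max_pool_with_argmax.
This operation, however, returns the indices (argmax) in a flattened format, which needs further manipulation before it is useful in other parts of the network.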
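To make the "flattened" format concrete, here is a NumPy sketch (an illustration, not TensorFlow's implementation) of 2×2, stride-2 max pooling that returns argmax indices in the layout the TensorFlow docs give for `include_batch_in_index=False`: a maximum at `[b, y, x, c]` is stored as `(y * width + x) * channels + c`, counted within each batch element. The helpers `max_pool_with_argmax_np` and `unravel` are hypothetical, and the sketch assumes even height and width.

```python
import numpy as np


def max_pool_with_argmax_np(x):
    """Hypothetical helper: 2x2 / stride-2 max pooling on an NHWC array.

    Returns (pooled, argmax); argmax uses the per-example flat layout
    (y * W + x) * C + c. Assumes H and W are even (no padding).
    """
    n, h, w, c = x.shape
    pooled = np.empty((n, h // 2, w // 2, c), dtype=x.dtype)
    argmax = np.empty((n, h // 2, w // 2, c), dtype=np.int64)
    for b in range(n):
        for i in range(h // 2):
            for j in range(w // 2):
                for ch in range(c):
                    window = x[b, 2 * i:2 * i + 2, 2 * j:2 * j + 2, ch]
                    # Position of the maximum inside the 2x2 window
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    y, xx = 2 * i + dy, 2 * j + dx
                    pooled[b, i, j, ch] = window[dy, dx]
                    # Flatten (y, x, c) of the full-resolution map into one integer
                    argmax[b, i, j, ch] = (y * w + xx) * c + ch
    return pooled, argmax


def unravel(flat, w, c):
    """Decode a per-example flat index back into (y, x, c)."""
    return flat // (w * c), (flat // c) % w, flat % c
```

With an input built by `np.arange`, every element's value equals its own flat index, so `argmax` comes out equal to `pooled`; that makes the layout easy to check by eye.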

To implement a layer that performs MaxPooling2D and also exports these (flattened) indices, I defined a custom Keras layer.

import tensorflow as tf
from tensorflow.keras.layers import Layer


class MaxPoolingWithArgmax2D(Layer):

    def __init__(
            self,
            pool_size=(2, 2),
            strides=2,
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        # Returns the pooled values and the flattened index of each maximum
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs,
            ksize=pool_size,
            strides=strides,
            padding=padding.upper(),
            output_dtype=tf.int64)
        return output, argmax

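This layer is used in the encoding part of the network, so the corresponding decoding layers must perform the inverse operation (upsampling) using the indices (the paper describes this operation in more detail).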
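The inverse operation can be sketched in NumPy as a scatter: start from zeros at the un-pooled resolution and write each pooled value back at the position its argmax index points to. This is an illustration only; the helper `max_unpool_np` is hypothetical and assumes the per-example flat layout `(y * W + x) * C + c` produced by `tf.nn.max_pool_with_argmax` with `include_batch_in_index=False`.

```python
import numpy as np


def max_unpool_np(pooled, argmax, out_hw):
    """Hypothetical helper: inverse of max pooling given per-example flat indices.

    pooled, argmax: NHWC arrays of the same shape.
    out_hw: (H, W) of the tensor that was originally pooled.
    """
    n, _, _, c = pooled.shape
    h, w = out_hw
    out = np.zeros((n, h * w * c), dtype=pooled.dtype)
    for b in range(n):
        # Each example is scattered independently, so the batch size
        # never has to be known when the operation is defined
        out[b, argmax[b].ravel()] = pooled[b].ravel()
    return out.reshape(n, h, w, c)
```

Because the indices are relative to each example, the batch dimension only appears as a loop bound here; every position that was not a maximum stays zero, as in SegNet's unpooling.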

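After some research I found legacy code (TF < 2.1.0) and adapted it to perform this operation.
However, I am not 100% convinced that it works correctly; in fact, there are some things I don't like about it.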
import tensorflow as tf
from tensorflow.keras.layers import Layer


class MaxUnpooling2D(Layer):
    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.name_scope(self.name):
            mask = tf.cast(mask, 'int32')
            # input_shape = tf.shape(updates, out_type='int32')
            input_shape = updates.get_shape()

            # This statement is required if I don't want to specify a batch size
            # (the static batch dimension is None while the model is being built)
            if input_shape[0] is None:
                batches = 1
            else:
                batches = input_shape[0]

            # calculation new shape
            if output_shape is None:
                output_shape = (
                    batches,
                    input_shape[1] * self.size[0],
                    input_shape[2] * self.size[1],
                    input_shape[3])

            # calculation indices for batch, height, width and feature maps
            one_like_mask = tf.ones_like(mask, dtype='int32')
            batch_shape = tf.concat(
                [[batches], [1], [1], [1]],
                axis=0)
            batch_range = tf.reshape(
                tf.range(output_shape[0], dtype='int32'),
                shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = tf.transpose(tf.reshape(
                tf.stack([b, y, x, f]),
                [4, updates_size]))
            values = tf.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret

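What bothers me:
  • The operation that unravels the indices (MaxUnpooling2D) is strictly tied to knowing a specific batch size, whereas for model validation I would like the batch size to be None or unspecified.
  • I am not sure this code is 100% compatible with the rest of the library. In fact, during fit, if I use tf.keras.metrics.MeanIoU, the value converges to 0.341 and stays constant for every epoch after the first. The standard accuracy metric, on the other hand, works fine.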

Deep network architecture

Here is the complete definition of the model.
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Layer


class SegNet:
    def __init__(self, data_shape, classes=3, batch_size=None):
        self.MODEL_NAME = 'SegNet'
        self.MODEL_VERSION = '0.2'

        self.classes = classes
        self.batch_size = batch_size

        self.build_model(data_shape)

    def build_model(self, data_shape):
        input_shape = (data_shape, data_shape, 3)

        inputs = keras.Input(shape=input_shape, batch_size=self.batch_size, name='Input')

        # Build sequential model

        # Encoding
        encoders = 5
        feature_maps = [64, 128, 256, 512, 512]
        n_convolutions = [2, 2, 3, 3, 3]
        eb_input = inputs
        eb_argmax_indices = []
        for encoder_index in range(encoders):
            encoder_block, argmax_indices = self.encoder_block(
                eb_input, encoder_index, feature_maps[encoder_index], n_convolutions[encoder_index])
            eb_argmax_indices.append(argmax_indices)
            eb_input = encoder_block

        # Decoding: indices are consumed in reverse order (deepest encoder first)
        decoders = encoders
        db_input = encoder_block
        eb_argmax_indices.reverse()
        feature_maps.reverse()
        n_convolutions.reverse()
        d_feature_maps = [512, 512, 256, 128, 64]
        d_n_convolutions = n_convolutions
        for decoder_index in range(decoders):
            decoder_block = self.decoder_block(
                db_input, eb_argmax_indices[decoder_index], decoder_index,
                d_feature_maps[decoder_index], d_n_convolutions[decoder_index])
            db_input = decoder_block

        output = layers.Softmax()(decoder_block)

        self.model = keras.Model(inputs=inputs, outputs=output, name="SegNet")

    def encoder_block(self, x, encoder_index, feature_maps, n_convolutions):
        bank_input = x
        for conv_index in range(n_convolutions):
            bank = self.eb_layers_bank(
                bank_input, conv_index, feature_maps, encoder_index)
            bank_input = bank

        max_pool, indices = MaxPoolingWithArgmax2D(
            pool_size=(2, 2), strides=2, padding='same',
            name='EB_{}_MPOOL'.format(encoder_index + 1))(bank)

        return max_pool, indices

    def eb_layers_bank(self, x, bank_index, feature_maps, encoder_index):
        bank_input = x

        conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='EB_{}_BANK_{}_CONV'.format(
            encoder_index + 1, bank_index + 1))(bank_input)
        batch_norm = layers.BatchNormalization(
            name='EB_{}_BANK_{}_BN'.format(encoder_index + 1, bank_index + 1))(conv_l)
        relu = layers.ReLU(name='EB_{}_BANK_{}_RL'.format(
            encoder_index + 1, bank_index + 1))(batch_norm)

        return relu

    def decoder_block(self, x, max_pooling_indices, decoder_index, feature_maps, n_convolutions):
        # bank_input = self.unpool_with_argmax(x, max_pooling_indices)
        bank_input = MaxUnpooling2D(name='DB_{}_UPSAMP'.format(decoder_index + 1))([x, max_pooling_indices])
        # bank_input = layers.UpSampling2D()(x)
        for conv_index in range(n_convolutions):
            last_l_bank = conv_index == n_convolutions - 1
            bank = self.db_layers_bank(
                bank_input, conv_index, feature_maps, decoder_index, last_l_bank)
            bank_input = bank

        return bank

    def db_layers_bank(self, x, bank_index, feature_maps, decoder_index, last_l_bank):
        bank_input = x

        if last_l_bank and decoder_index == 4:
            # Last bank of the last decoder: 1x1 convolution, one channel per class
            conv_l = layers.Conv2D(self.classes, (1, 1), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                decoder_index + 1, bank_index + 1))(bank_input)
            # batch_norm = layers.BatchNormalization(
            #     name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            return conv_l
        else:
            if last_l_bank and decoder_index > 0:
                # Last bank of decoders 2-4 halves the number of feature maps
                conv_l = layers.Conv2D(feature_maps // 2, (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            else:
                conv_l = layers.Conv2D(feature_maps, (3, 3), padding='same', name='DB_{}_BANK_{}_CONV'.format(
                    decoder_index + 1, bank_index + 1))(bank_input)
            batch_norm = layers.BatchNormalization(
                name='DB_{}_BANK_{}_BN'.format(decoder_index + 1, bank_index + 1))(conv_l)
            relu = layers.ReLU(name='DB_{}_BANK_{}_RL'.format(
                decoder_index + 1, bank_index + 1))(batch_norm)

            return relu

    def get_model(self):
        return self.model

Here is the output of model.summary():
    Model: "SegNet"
    __________________________________________________________________________________________________
    Layer (type) Output Shape Param # Connected to
    ==================================================================================================
    Input (InputLayer) [(None, 416, 416, 3) 0
    __________________________________________________________________________________________________
    EB_1_BANK_1_CONV (Conv2D) (None, 416, 416, 64) 1792 Input[0][0]
    __________________________________________________________________________________________________
    EB_1_BANK_1_BN (BatchNormalizat (None, 416, 416, 64) 256 EB_1_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    EB_1_BANK_1_RL (ReLU) (None, 416, 416, 64) 0 EB_1_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    EB_1_BANK_2_CONV (Conv2D) (None, 416, 416, 64) 36928 EB_1_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    EB_1_BANK_2_BN (BatchNormalizat (None, 416, 416, 64) 256 EB_1_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    EB_1_BANK_2_RL (ReLU) (None, 416, 416, 64) 0 EB_1_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    EB_1_MPOOL (MaxPoolingWithArgma ((None, 208, 208, 64 0 EB_1_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_1_CONV (Conv2D) (None, 208, 208, 128 73856 EB_1_MPOOL[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_1_BN (BatchNormalizat (None, 208, 208, 128 512 EB_2_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_1_RL (ReLU) (None, 208, 208, 128 0 EB_2_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_2_CONV (Conv2D) (None, 208, 208, 128 147584 EB_2_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_2_BN (BatchNormalizat (None, 208, 208, 128 512 EB_2_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    EB_2_BANK_2_RL (ReLU) (None, 208, 208, 128 0 EB_2_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    EB_2_MPOOL (MaxPoolingWithArgma ((None, 104, 104, 12 0 EB_2_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_1_CONV (Conv2D) (None, 104, 104, 256 295168 EB_2_MPOOL[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_1_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_1_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_2_CONV (Conv2D) (None, 104, 104, 256 590080 EB_3_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_2_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_2_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_3_CONV (Conv2D) (None, 104, 104, 256 590080 EB_3_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_3_BN (BatchNormalizat (None, 104, 104, 256 1024 EB_3_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    EB_3_BANK_3_RL (ReLU) (None, 104, 104, 256 0 EB_3_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    EB_3_MPOOL (MaxPoolingWithArgma ((None, 52, 52, 256) 0 EB_3_BANK_3_RL[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_1_CONV (Conv2D) (None, 52, 52, 512) 1180160 EB_3_MPOOL[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_1_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_1_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_2_CONV (Conv2D) (None, 52, 52, 512) 2359808 EB_4_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_2_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_2_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_3_CONV (Conv2D) (None, 52, 52, 512) 2359808 EB_4_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_3_BN (BatchNormalizat (None, 52, 52, 512) 2048 EB_4_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    EB_4_BANK_3_RL (ReLU) (None, 52, 52, 512) 0 EB_4_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    EB_4_MPOOL (MaxPoolingWithArgma ((None, 26, 26, 512) 0 EB_4_BANK_3_RL[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_1_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_4_MPOOL[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_1_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_1_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_2_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_5_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_2_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_2_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_3_CONV (Conv2D) (None, 26, 26, 512) 2359808 EB_5_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_3_BN (BatchNormalizat (None, 26, 26, 512) 2048 EB_5_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    EB_5_BANK_3_RL (ReLU) (None, 26, 26, 512) 0 EB_5_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    EB_5_MPOOL (MaxPoolingWithArgma ((None, 13, 13, 512) 0 EB_5_BANK_3_RL[0][0]
    __________________________________________________________________________________________________
    DB_1_UPSAMP (MaxUnpooling2D) (1, 26, 26, 512) 0 EB_5_MPOOL[0][0]
    EB_5_MPOOL[0][1]
    __________________________________________________________________________________________________
    DB_1_BANK_1_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_UPSAMP[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_1_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_1_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_2_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_2_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_2_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_3_CONV (Conv2D) (1, 26, 26, 512) 2359808 DB_1_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_3_BN (BatchNormalizat (1, 26, 26, 512) 2048 DB_1_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    DB_1_BANK_3_RL (ReLU) (1, 26, 26, 512) 0 DB_1_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    DB_2_UPSAMP (MaxUnpooling2D) (1, 52, 52, 512) 0 DB_1_BANK_3_RL[0][0]
    EB_4_MPOOL[0][1]
    __________________________________________________________________________________________________
    DB_2_BANK_1_CONV (Conv2D) (1, 52, 52, 512) 2359808 DB_2_UPSAMP[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_1_BN (BatchNormalizat (1, 52, 52, 512) 2048 DB_2_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_1_RL (ReLU) (1, 52, 52, 512) 0 DB_2_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_2_CONV (Conv2D) (1, 52, 52, 512) 2359808 DB_2_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_2_BN (BatchNormalizat (1, 52, 52, 512) 2048 DB_2_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_2_RL (ReLU) (1, 52, 52, 512) 0 DB_2_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_3_CONV (Conv2D) (1, 52, 52, 256) 1179904 DB_2_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_3_BN (BatchNormalizat (1, 52, 52, 256) 1024 DB_2_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    DB_2_BANK_3_RL (ReLU) (1, 52, 52, 256) 0 DB_2_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    DB_3_UPSAMP (MaxUnpooling2D) (1, 104, 104, 256) 0 DB_2_BANK_3_RL[0][0]
    EB_3_MPOOL[0][1]
    __________________________________________________________________________________________________
    DB_3_BANK_1_CONV (Conv2D) (1, 104, 104, 256) 590080 DB_3_UPSAMP[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_1_BN (BatchNormalizat (1, 104, 104, 256) 1024 DB_3_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_1_RL (ReLU) (1, 104, 104, 256) 0 DB_3_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_2_CONV (Conv2D) (1, 104, 104, 256) 590080 DB_3_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_2_BN (BatchNormalizat (1, 104, 104, 256) 1024 DB_3_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_2_RL (ReLU) (1, 104, 104, 256) 0 DB_3_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_3_CONV (Conv2D) (1, 104, 104, 128) 295040 DB_3_BANK_2_RL[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_3_BN (BatchNormalizat (1, 104, 104, 128) 512 DB_3_BANK_3_CONV[0][0]
    __________________________________________________________________________________________________
    DB_3_BANK_3_RL (ReLU) (1, 104, 104, 128) 0 DB_3_BANK_3_BN[0][0]
    __________________________________________________________________________________________________
    DB_4_UPSAMP (MaxUnpooling2D) (1, 208, 208, 128) 0 DB_3_BANK_3_RL[0][0]
    EB_2_MPOOL[0][1]
    __________________________________________________________________________________________________
    DB_4_BANK_1_CONV (Conv2D) (1, 208, 208, 128) 147584 DB_4_UPSAMP[0][0]
    __________________________________________________________________________________________________
    DB_4_BANK_1_BN (BatchNormalizat (1, 208, 208, 128) 512 DB_4_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    DB_4_BANK_1_RL (ReLU) (1, 208, 208, 128) 0 DB_4_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    DB_4_BANK_2_CONV (Conv2D) (1, 208, 208, 64) 73792 DB_4_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    DB_4_BANK_2_BN (BatchNormalizat (1, 208, 208, 64) 256 DB_4_BANK_2_CONV[0][0]
    __________________________________________________________________________________________________
    DB_4_BANK_2_RL (ReLU) (1, 208, 208, 64) 0 DB_4_BANK_2_BN[0][0]
    __________________________________________________________________________________________________
    DB_5_UPSAMP (MaxUnpooling2D) (1, 416, 416, 64) 0 DB_4_BANK_2_RL[0][0]
    EB_1_MPOOL[0][1]
    __________________________________________________________________________________________________
    DB_5_BANK_1_CONV (Conv2D) (1, 416, 416, 64) 36928 DB_5_UPSAMP[0][0]
    __________________________________________________________________________________________________
    DB_5_BANK_1_BN (BatchNormalizat (1, 416, 416, 64) 256 DB_5_BANK_1_CONV[0][0]
    __________________________________________________________________________________________________
    DB_5_BANK_1_RL (ReLU) (1, 416, 416, 64) 0 DB_5_BANK_1_BN[0][0]
    __________________________________________________________________________________________________
    DB_5_BANK_2_CONV (Conv2D) (1, 416, 416, 3) 195 DB_5_BANK_1_RL[0][0]
    __________________________________________________________________________________________________
    softmax (Softmax) (1, 416, 416, 3) 0 DB_5_BANK_2_CONV[0][0]
    ==================================================================================================
    Total params: 29,459,075
    Trainable params: 29,443,203
    Non-trainable params: 15,872
    __________________________________________________________________________________________________


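As you can see, I am forced to specify a batch size in MaxUnpooling2D; otherwise I get an error that the operation cannot be performed, because there are None values and the shapes cannot be converted correctly.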

When I try to predict on images, I have to specify the correct batch dimension, otherwise I get errors like this:
    InvalidArgumentError:  Shapes of all inputs must match: values[0].shape = [4,208,208,64] != values[1].shape = [1,208,208,64]
    [[{{node SegNet/DB_5_UPSAMP/PartitionedCall/PartitionedCall/DB_5_UPSAMP/stack}}]] [Op:__inference_predict_function_70839]

This is caused by the implementation needed to unravel the indices from the max-pooling operation.
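The mismatch plausibly comes from MaxUnpooling2D reading the static shape (`updates.get_shape()`): while the model is being built the batch entry is None, so the layer falls back to `batches = 1` and bakes that value into its shape computations, which then disagree with a real batch of 4. The sketch below shows the difference between the static shape and the dynamic `tf.shape`; the function name `inspect_batch` and the `seen` dictionary are hypothetical.

```python
import tensorflow as tf

seen = {}


@tf.function(input_signature=[tf.TensorSpec([None, 13, 13, 512], tf.float32)])
def inspect_batch(t):
    # Static shape, fixed at trace time: the batch entry is None here
    seen["static_batch"] = t.shape[0]
    # Dynamic shape, evaluated at run time: always a concrete value
    return tf.shape(t)[0]


dynamic_batch = inspect_batch(tf.zeros([4, 13, 13, 512]))
```

Any shape arithmetic that must work for every batch size therefore has to be built from `tf.shape(...)`, not from `get_shape()`.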

Training graphs

Here is a trace of a 20-epoch training run.

As you can see, the MeanIoU metric is flat: apart from epoch 1 it makes no progress and is never updated.
[Figure: Mean intersection over union]

The other metric works fine, and the loss decreases correctly.

[Figure: Loss and accuracy]
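This is not part of the original answer and is offered only as a likely explanation to check. `tf.keras.metrics.MeanIoU` builds its confusion matrix from integer class ids, and the predictions it receives are cast to integers rather than reduced with argmax. Softmax probabilities below 1 therefore truncate to class 0, which would produce a score that no longer changes with training. A common workaround is a subclass that applies argmax first; the class name `ArgmaxMeanIoU` is hypothetical, and the sketch assumes one-hot labels and class probabilities on the last axis.

```python
import tensorflow as tf


class ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """Hypothetical MeanIoU wrapper for one-hot labels and softmax outputs."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert one-hot labels and class probabilities to integer class ids
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


labels = tf.one_hot([0, 1, 2, 2], depth=3)
probs = tf.constant([[0.8, 0.1, 0.1],
                     [0.2, 0.7, 0.1],
                     [0.1, 0.2, 0.7],
                     [0.3, 0.4, 0.3]])
m = ArgmaxMeanIoU(num_classes=3)
m.update_state(labels, probs)
# argmax predictions are [0, 1, 2, 1]: per-class IoU is 1, 1/2, 1/2, so the mean is 2/3
```

In the model above it would be passed as `metrics=[ArgmaxMeanIoU(num_classes=3)]` in `compile`, under the same one-hot assumption.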

    ––––––––––

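Conclusion
  • Is there a better way, more compatible with recent versions of TF, to implement unpooling and upsampling using the indices from the max-pooling operation?
  • If the implementation is correct, why does my metric get stuck at a specific value? Am I doing something wrong in the model?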

Thanks!

Best answer

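There are two ways to reshape with an unknown batch size inside a custom layer.

First, if you know the rest of the shape (that is, the size of the expected array), reshape using -1 as the batch size: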

    import tensorflow.keras.backend as K
    reshaped = K.reshape(original, (-1, x, y, channels))
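A quick way to see what `-1` does: TensorFlow infers that dimension from the total element count, and when the remaining dimensions are constants the static shape keeps them, so only the batch stays unknown. The sketch calls `tf.reshape` directly (`K.reshape` in tf.keras delegates to it); the sizes and the name `to_images` are hypothetical.

```python
import tensorflow as tf

# Hypothetical data: 5 examples of 2x2x3 = 12 values each, stored flat
flat = tf.reshape(tf.range(60, dtype=tf.float32), [5, 12])
# -1 lets TensorFlow infer the batch dimension from the element count
restored = tf.reshape(flat, (-1, 2, 2, 3))


@tf.function(input_signature=[tf.TensorSpec([None, 12], tf.float32)])
def to_images(t):
    # Inside a graph with an unknown batch, the static shape is (None, 2, 2, 3)
    return tf.reshape(t, (-1, 2, 2, 3))
```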

Second, if you don't know the size, use K.shape to get the shape as a tensor:
    inputs_shape = K.shape(inputs)
    batch_size = inputs_shape[:1]
    x = inputs_shape[1:2]
    y = inputs_shape[2:3]
    ch = inputs_shape[3:]

    #you can then concatenate these and operate them (notice I kept them as 1D vector, not as scalar)
    newShape = K.concatenate([batch_size, x, y, ch]) #of course you will make your operations
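Applying this dynamic-shape approach to the question's MaxUnpooling2D gives a layer that never reads the static batch dimension. This is a sketch under assumptions, not tested against TF 2.1 specifically: the name `DynamicMaxUnpooling2D` is hypothetical, it uses plain `tf.*` ops in place of the `K.*` calls, it assumes the indices come from `tf.nn.max_pool_with_argmax` with `include_batch_in_index=False` (the default, as in `MaxPoolingWithArgmax2D` above), and it assumes the output is exactly `size` times larger than the pooled input.

```python
import tensorflow as tf


class DynamicMaxUnpooling2D(tf.keras.layers.Layer):
    """Hypothetical unpooling layer that works with an unknown batch size."""

    def __init__(self, size=(2, 2), **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        updates, mask = inputs
        mask = tf.cast(mask, tf.int64)
        # Dynamic shape: the batch entry is a run-time tensor, never None
        in_shape = tf.shape(updates, out_type=tf.int64)
        out_shape = tf.stack([in_shape[0],
                              in_shape[1] * self.size[0],
                              in_shape[2] * self.size[1],
                              in_shape[3]])
        per_example = out_shape[1] * out_shape[2] * out_shape[3]
        # Shift each example's per-example indices into its own slice of one flat buffer
        offsets = tf.range(in_shape[0], dtype=tf.int64) * per_example
        flat_idx = tf.reshape(mask + tf.reshape(offsets, [-1, 1, 1, 1]), [-1, 1])
        flat_out = tf.scatter_nd(flat_idx, tf.reshape(updates, [-1]),
                                 tf.reshape(in_shape[0] * per_example, [1]))
        out = tf.reshape(flat_out, out_shape)
        # Restore the known static dims so following Conv2D layers can build
        s = updates.shape
        out.set_shape([s[0],
                       None if s[1] is None else s[1] * self.size[0],
                       None if s[2] is None else s[2] * self.size[1],
                       s[3]])
        return out
```

An alternative design would be to request `include_batch_in_index=True` in the pooling layer, so the indices already contain the batch offset and the `offsets` step can be dropped.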

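When I made my own version of SegNet, I didn't use the indices; instead I kept a one-hot version of them. It does require extra operations, but it may work well: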
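import tensorflow.keras.backend as K
from tensorflow.keras.layers import Lambda, MaxPooling2D, UpSampling2D

def get_indices(tensors):
    # Lambda passes the list of inputs as a single argument
    original, unpooled = tensors
    # 1.0 where the pre-pooling activation equals the upsampled maximum, else 0.0
    is_equal = K.equal(original, unpooled)
    return K.cast(is_equal, K.floatx())

previous_output = ...
pooled = MaxPooling2D()(previous_output)
unpooled = UpSampling2D()(pooled)

one_hot_indices = Lambda(get_indices)([previous_output, unpooled])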
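The equality trick can be checked in NumPy: 2×2 pooling, nearest-neighbour upsampling (what `UpSampling2D` does by default), then an element-wise comparison with the original activations. The helper `pooling_mask_np` is hypothetical and assumes even height and width. One difference from argmax indices: if several values in a window tie for the maximum, all of them are marked.

```python
import numpy as np


def pooling_mask_np(x):
    """Hypothetical sketch of the one-hot pooling mask on an NHWC array."""
    n, h, w, c = x.shape
    # 2x2 / stride-2 max pooling via a reshape (assumes even H and W)
    pooled = x.reshape(n, h // 2, 2, w // 2, 2, c).max(axis=(2, 4))
    # Nearest-neighbour upsampling back to the original resolution
    unpooled = pooled.repeat(2, axis=1).repeat(2, axis=2)
    # 1.0 where the activation equals its window's maximum
    return (x == unpooled).astype(np.float32)
```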

Then, after upsampling, I concatenate these indices and pass the result through a new conv:
from tensorflow.keras.layers import Concatenate, Conv2D, UpSampling2D

some_output = ...
upsampled = UpSampling2D()(some_output)
with_indices = Concatenate()([upsampled, one_hot_indices])
upsampled = Conv2D(...)(with_indices)
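Putting both halves of the answer together on concrete tensors shows how the channel counts line up: the mask adds as many channels as the pre-pooling feature map had, and the new convolution maps the concatenation back down. This is an eager sketch with hypothetical sizes (batch 2, 8×8, 4 channels); it computes the mask with `tf.cast(tf.equal(...))` directly instead of a `Lambda` layer, and the decoder input is simply the pooled tensor as a stand-in.

```python
import tensorflow as tf

# Hypothetical encoder activation: batch of 2, 8x8 feature maps, 4 channels
previous_output = tf.random.uniform([2, 8, 8, 4], seed=1)

# Encoder side: pool, then mark where each window's maximum was
pooled = tf.keras.layers.MaxPooling2D()(previous_output)
one_hot_indices = tf.cast(
    tf.equal(previous_output, tf.keras.layers.UpSampling2D()(pooled)), tf.float32)

# Decoder side: upsample, append the mask as extra channels, then convolve
some_output = pooled  # stand-in for the decoder's input at this level
upsampled = tf.keras.layers.UpSampling2D()(some_output)
with_indices = tf.keras.layers.Concatenate()([upsampled, one_hot_indices])
decoded = tf.keras.layers.Conv2D(4, 3, padding='same')(with_indices)
```

Unlike scatter-based unpooling, this keeps everything as standard layers with no index arithmetic, so the batch size never needs to be known; the cost is the extra mask channels and a convolution that has to learn how to use them.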

A similar question, "tensorflow - TF2.1: SegNet model architecture problem. Metric calculation is wrong: it stays constant and converges to a fixed value", can be found on Stack Overflow: https://stackoverflow.com/questions/61099167/
