gpt4 book ai didi

python - 将 tf.contrib.layers.batch_norm 迁移到 Tensorflow 2.0

转载 作者:行者123 更新时间:2023-12-05 06:18:11 24 4
gpt4 key购买 nike

我正在将 TensorFlow 代码迁移到 TensorFlow 2.1.0。

原代码如下:

conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
conv = tf.nn.relu(conv)
conv = tf.contrib.layers.max_pool2d(conv, 2)

这就是我所做的:

conv1 = Conv2D(out_channels, (3, 3), activation='relu', padding='same', data_format='channels_last', name=name)(inputs)
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', data_format="channels_last")(conv1)
#conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
pool1 = MaxPooling2D(pool_size=(2, 2), data_format="channels_last")(conv1)

我的问题是我不知道如何处理 tf.contrib.layers.batch_norm

如何将 tf.contrib.layers.batch_norm 迁移到 Tensorflow 2.x?

更新:
使用评论建议,我认为我已经正确迁移:

conv1 = BatchNormalization(momentum=0.99, scale=True, center=True)(conv1)

但我不确定 decay 是否像 momentum 并且我不知道如何在 中设置 updates_collections BatchNormalization 方法。

最佳答案

我在使用我要微调的训练模型时遇到了这个问题。只是像 OP 一样用 tf.keras.layers.BatchNormalization 替换 tf.contrib.layers.batch_norm 确实给了我一个错误,其修复如下所述。

旧代码如下所示:

tf.contrib.layers.batch_norm(
tensor,
scale=True,
center=True,
is_training=self.use_batch_statistics,
trainable=True,
data_format=self._data_format,
updates_collections=None,
)

更新后的工作代码如下所示:

tf.keras.layers.BatchNormalization(
name="BatchNorm",
scale=True,
center=True,
trainable=True,
)(tensor)

我不确定我删除的所有关键字参数是否会成为问题,但一切似乎都有效。请注意 name="BatchNorm" 参数。这些层使用不同的命名模式,因此我不得不使用 inspect_checkpoint.py 工具来查看模型并找到恰好是 BatchNorm 的层名称。

关于python - 将 tf.contrib.layers.batch_norm 迁移到 Tensorflow 2.0,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61362819/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com