gpt4 book ai didi

python - 应用组卷积,其中每个组都被限制为具有相同的权重

转载 作者:行者123 更新时间:2023-12-03 19:14:31 27 4
gpt4 key购买 nike

在 Keras 中可以使用组卷积,使用此处的代码例如:https://github.com/tensorflow/tensorflow/issues/34024#issuecomment-552034933

但是,对于我的具体应用,我要求在训练时,组卷积中的每个组具有相同的权重。例如,如果我有一个形状为 8x8x32 的张量并且我想要 groups = 2 和 filter_size=3x3,那么正常的组卷积将使用两个 3x3x16 张量在 8x8x32 的前半部分和8x8x32 的后半部分。我想确保两个 3x3x16 张量具有相同的权重,即使在训练期间也是如此。

我可以通过摆脱我的组卷积框架并将我的 8x8x32 张量拆分为两个 8x8x16 张量,然后通过单个非分组 3x3x16 卷积运行它们中的每一个来实现这一点。但是,由于不使用组卷积框架,代码运行速度较慢,因为任务不是并行运行的。

如何在 Keras 中使用组卷积提供的速度提升,同时将每个组中的权重限制为相同?

最佳答案

显然,组卷积在 TensorFlow 中的工作方式(至少目前,因为它似乎还没有被记录,所以我猜它可能会改变)是,给定一个批处理 img形状 (n, h, w, c) 和一个形状为 (kh, kw, c1, c2) 的过滤器 k,它使一个g = c/c1 组中的卷积,其中结果具有 c2 channel 。 c 必须能被 c1 整除,c2 必须是 g 的倍数。据我了解,这意味着,如果我们调用每组输出 channel 的数量 a = c2/g,那么第一组使用过滤器 k[:, :, :, :a],第二组k[:, :, :, a:2*a],依此类推。如果你想对每个卷积组使用完全相同的过滤器,你只需要为单个组制作过滤器,形状为 (kh, kw, c1, a) 然后平铺它 g 最后一维的次数。

在您引用的代码中,您只需进行以下更改。 self.kernel 的定义将更改为:

# Make sure self.filters is divisible by self.groups
kernel_shape = self.kernel_size + (input_dim // self.groups, self.filters // self.groups)
# Filters for a single group
self.kernel_base = self.add_weight(
name='kernel',
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
trainable=True,
dtype=self.dtype)
# Tile filters for the rest of groups
self.kernel = tf.tile(self.kernel_base, [1, 1, 1, self.groups])

假设您也希望偏差以相同的方式工作,您可以对其执行相同的操作:

if self.use_bias:
# Bias for a single group
self.bias_base = self.add_weight(
name='bias',
shape=(self.filters // self.groups,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
trainable=True,
dtype=self.dtype)
# Bias for all groups
self.bias = tf.tile(self.kernel_base, [1, 1, 1, self.groups])
else:
self.bias = None

其余代码的工作方式类似,如tf.nn.conv2d将像以前一样使用 self.kernels,并且 self.bias 将被类似地添加。

关于python - 应用组卷积,其中每个组都被限制为具有相同的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61213956/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com