gpt4 book ai didi

python - 如何创建具有共享权重的两层,其中一层是另一层的转置?

转载 作者:行者123 更新时间:2023-12-01 22:39:17 25 4
gpt4 key购买 nike

我在 Tensorflow 2.0 中使用 Keras API。

举个例子,假设我想在我的模型中有两个密集层,称为layer1layer2。但我想绑定(bind)它们的权重,使得layer1中的权重矩阵始终等于layer2中的权重矩阵的转置。

我该怎么做?

最佳答案

您可以为此定义一个自定义 Keras 层,您可以在其中传递引用 Dense 层。

自定义密集层:

class CustomDense(Layer):
def __init__(self, reference_layer):
super(CustomDense, self).__init__()
self.ref_layer = reference_layer

def call(self, inputs):
weights = self.ref_layer.get_weights()[0]
bias = self.ref_layer.get_weights()[1]
weights = tf.transpose(weights)
x = tf.linalg.matmul(inputs, weights) + bias
return x

现在,您可以使用Functional-API将此层添加到您的模型中.

inp = Input(shape=(5))
dense = Dense(5)
transposed_dense = CustomDense(dense)

#model
x = dense(inp)
x = transposed_dense(x)
model = Model(inputs=inp, outputs=x)
model.summary()
'''
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 5)] 0
_________________________________________________________________
dense_1 (Dense) (None, 5) 30
_________________________________________________________________
custom_dense_1 (CustomDense) (None, 5) 30
=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
_________________________________________________________________
'''

如您所见,densecustom_dense 共享 30 个参数。这里 custom_dense 只是使用 dense 层的转置权重进行密集操作,它没有自己的参数。

编辑1:回答评论中的问题(子类层如何获取#params?):

Layer 类跟踪传递给其 __init__ 方法的所有对象。

transposed_dense._layers
# [<tensorflow.python.keras.layers.core.Dense at 0x7fc3e0874f28>]

以上参数将给出正在跟踪的依赖层。所有子属性权重可以查看为:

transposed_dense._gather_children_attribute("weights")
#[<tf.Variable 'dense_9/kernel:0' shape=(10, 5) dtype=float32>,
# <tf.Variable 'dense_9/bias:0' shape=(5,) dtype=float32>]

因此,当我们调用 model.summary() 时,它会在内部为每个 Layer 调用 count_params(),从而对所有 trainable_variable 进行计数。包括自身和子属性。

关于python - 如何创建具有共享权重的两层,其中一层是另一层的转置?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59663963/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com