gpt4 book ai didi

python - 无法在 Tensorflow v2 中创建不可训练的变量

转载 作者:行者123 更新时间:2023-12-05 06:53:59 25 4
gpt4 key购买 nike

我手动实现了 batchnormalize 层。但是初始函数中创建 nontrainable 变量的代码似乎不起作用。代码:

import tensorflow as tf
class batchNormalization(tf.keras.layers.Layer):
def __init__(self, shape, Trainable, **kwargs):
super(batchNormalization, self).__init__(**kwargs)
self.shape = shape
self.Trainable = Trainable
self.beta = tf.Variable(initial_value=tf.zeros(shape), trainable=Trainable)
self.gamma = tf.Variable(initial_value=tf.ones(shape), trainable=Trainable)
self.moving_mean = tf.Variable(initial_value=tf.zeros(self.shape), trainable=False)
self.moving_var = tf.Variable(initial_value=tf.ones(self.shape), trainable=False)

def update_var(self,inputs):
wu, sigma = tf.nn.moments(inputs, axes=[0, 1, 2], shift=None, keepdims=False, name=None)
var = tf.math.sqrt(sigma)
self.moving_mean = self.moving_mean * 0.09 + wu * 0.01
self.moving_var = self.moving_var * 0.09 + var * 0.01
return wu,var

def call(self, inputs):
wu, var = self.update_var(inputs)
return tf.nn.batch_normalization(inputs, wu, var, self.beta,
self.gamma, variance_epsilon=0.001)


@tf.function
def train_step(model, inputs, label,optimizer):
with tf.GradientTape(persistent=False) as tape:
predictions = model(inputs, training=1)
loss = tf.keras.losses.mean_squared_error(predictions,label)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))


if __name__=='__main__':
f=tf.ones([2,256,256,8])
label=tf.ones([2,256,256,8])
inputs = tf.keras.Input(shape=(256,256,8))
outputs=batchNormalization([8],True)(inputs)
Model = tf.keras.Model(inputs=inputs, outputs=outputs)
Layer = batchNormalization([8],True)
print(len(Model.variables))
print(len(Model.trainable_variables))
print(len(Layer.variables))
print(len(Layer.trainable_variables))
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)
for i in range(0,100):
train_step(Layer, f, label,optimizer)
# train_step(Model,f,label,optimizer)

训练的时候,又报错了:TypeError:正在向函数构建代码之外的操作传递“Graph”张量。通过在函数构建中包含 tf.init_scope,可以使 Graph 张量从函数构建上下文中泄漏出来代码。

最佳答案

替换

self.moving_mean = self.moving_mean * 0.09 + wu * 0.01
self.moving_var = self.moving_var * 0.09 + var * 0.01
``
by
self.moving_mean.assign(self.moving_mean * 0.09 + wu * 0.01)
self.moving_var.assign(self.moving_var * 0.09 + var * 0.01)

可以解决这个问题。

关于python - 无法在 Tensorflow v2 中创建不可训练的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65661130/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com