gpt4 book ai didi

python - 使用 Tensorflow 2.0 的神经网络中的条件参数

转载 作者:太空宇宙 更新时间:2023-11-03 20:19:41 25 4
gpt4 key购买 nike

我对由两个神经网络 NN1()、NN2() 组成的架构感兴趣,这样

第一个神经网络的输出weights_for_NN2 = NN1(inputs1)是第二个神经网络的参数/权重。然后最终输出 outputs = NN2(inputs2) 将用于损失函数 L = loss(outputs)。所以当我们反向传播时,我们只需要更新NN1

的权重

现在我只是运行一个简单的线性回归作为玩具;这是我的代码

model = keras.Sequential([
keras.layers.Dense(128,input_shape=(1,), activation='relu'),
keras.layers.Dense(2,input_shape=(128,))
])


class ConditionalModel(object):
def __init__(self):

self.W = tf.Variable([[5.0]],name="kernel")
self.b = tf.Variable([0.0],name="bias")
self.variables = [self.W,self.b]

def __call__(self, x):
return self.W * x + self.b

def loss(predicted_y, target_y):
return tf.reduce_mean(tf.square(predicted_y - target_y))

# do the training
cond_model = ConditionalModel()
optimizer = tf.keras.optimizers.SGD(lr=.1)
def train(model, inputs, outputs, learning_rate):
with tf.GradientTape() as tape:

param = model(tf.reshape(inputs,[-1,1]))
w = tf.reduce_mean(param[:,:-1],axis=0)
b = tf.reduce_mean(param[:,-1],axis=0)
for var in cond_model.variables:
if "kernel" in var.name:
var.assign(tf.reshape(w,[1,1]))
elif "bias" in var.name:
var.assign(tf.reshape(b,[1,]))

current_loss = loss(tf.squeeze(cond_model(tf.reshape(inputs,[-1,1]))), outputs)
gradients = tape.gradient(current_loss, model.trainable_weights)
print(gradients) # it prints [None,None,None,None] here
optimizer.apply_gradients(zip(gradients, model.trainable_weights))

这是创建一些玩具数据的代码

TRUE_W = 3.0
TRUE_b = 2.0
NUM_EXAMPLES = 1000

inputs = tf.random.normal(shape=[NUM_EXAMPLES])
noise = tf.random.normal(shape=[NUM_EXAMPLES])
outputs = inputs * TRUE_W + TRUE_b + noise

上面的代码不起作用,因为ValueError:没有为我的NN1的任何变量提供渐变。我相信这是我为 NN2 赋值的方式,导致 NN1 计算图的路径损坏。知道如何解决这个问题吗?

最佳答案

修改Wb直接解决bug。

for var in cond_model.variables:
if "kernel" in var.name:
cond_model.W = tf.reshape(w,[1,1])
elif "bias" in var.name:
cond_model.b = tf.reshape(b,[1])

然后就可以得到model.trainable_weights的梯度。

关于python - 使用 Tensorflow 2.0 的神经网络中的条件参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58240762/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com