gpt4 book ai didi

python - 带有 Keras 的 GradientTape 返回 0

转载 作者:行者123 更新时间:2023-12-02 04:37:45 25 4
gpt4 key购买 nike

我尝试将 GradientTape 与 Keras 模型(简化)一起使用,如下所示:

import tensorflow as tf
tf.enable_eager_execution()

input_ = tf.keras.layers.Input(shape=(28, 28))
flat = tf.keras.layers.Flatten()(input_)
output = tf.keras.layers.Dense(10, activation='softmax')(flat)
model = tf.keras.Model(input_, output)
model.compile(loss='categorical_crossentropy', optimizer='sgd')

import numpy as np
inp = tf.Variable(np.random.random((1,28,28)), dtype=tf.float32, name='input')
target = tf.constant([[1,0,0,0,0,0,0,0,0,0]], dtype=tf.float32)
with tf.GradientTape(persistent=True) as g:
g.watch(inp)
result = model(inp, training=False)

print(tf.reduce_max(tf.abs(g.gradient(result, inp))))

但是对于inp的一些随机值,梯度到处都是零,而对于其余的,梯度幅度非常小(<1e-7)。

我还使用 MNIST 训练的 3 层 MLP 进行了尝试,结果是相同的,但使用没有激活的 1 层线性模型进行了尝试。

这是怎么回事?

最佳答案

您正在计算 softmax 输出层的梯度 - 由于 softmax 的总和始终为 1,因此梯度(在多输入情况下,据我所知在维度上求和/平均)必须为 0 是有意义的 - - 该层的整体输出不能改变。我认为,获得 > 0 的小值的情况是数值问题。
当您删除激活函数时,此限制不再成立,并且激活可能会变得更大(意味着幅度 > 0 的梯度)。

您是否尝试使用梯度下降来构造输入,从而导致某个类别的概率非常大(如果不是,请忽略此...)? @jdehesa 已经包含了一种通过损失函数来做到这一点的方法。请注意,您也可以通过 softmax 来完成此操作,如下所示:

import tensorflow as tf
tf.enable_eager_execution()

input_ = tf.keras.layers.Input(shape=(28, 28))
flat = tf.keras.layers.Flatten()(input_)
output = tf.keras.layers.Dense(10, activation='softmax')(flat)
model = tf.keras.Model(input_, output)
model.compile(loss='categorical_crossentropy', optimizer='sgd')

import numpy as np
inp = tf.Variable(np.random.random((1,28,28)), dtype=tf.float32, name='input')
with tf.GradientTape(persistent=True) as g:
g.watch(inp)
result = model(inp, training=False)[:,0]

print(tf.reduce_max(tf.abs(g.gradient(result, inp))))

请注意,我仅获取第 0 列中的结果,对应于第一个类(我删除了 target 因为它未使用)。这将仅计算此类的 softmax 值的梯度,这是有意义的。

一些注意事项:

  • 在渐变磁带上下文管理器中进行索引非常重要!如果您在外部执行此操作(例如,在调用 g.gradient 的行中,这将不起作用(无渐变)
  • 您还可以使用 logits(softmax 之前的值)的梯度来代替。这是不同的,因为 softmax 概率可以通过降低其他类别的可能性来增加,而 logits 只能通过增加相关类别的“分数”来增加。

关于python - 带有 Keras 的 GradientTape 返回 0,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61771330/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com