gpt4 book ai didi

keras - 如何获得 keras 模型相对于其输入的梯度?

转载 作者:行者123 更新时间:2023-12-03 16:37:49 26 4
gpt4 key购买 nike

我只是问了一个关于相同主题的问题,但针对自定义模型( How do I find the derivative of a custom model in Keras? ),但很快意识到这是在我可以走路之前试图运行,因此该问题已被标记为该问题的重复。

我试图简化我的场景,现在有一个(非自定义)keras模型包括 2 Dense层:

inputs = tf.keras.Input((cols,), name='input')

layer_1 = tf.keras.layers.Dense(
10,
name='layer_1',
input_dim=cols,
use_bias=True,
kernel_initializer=tf.constant_initializer(0.5),
bias_initializer=tf.constant_initializer(0.1))(inputs)

outputs = tf.keras.layers.Dense(
1,
name='alpha',
use_bias=True,
kernel_initializer=tf.constant_initializer(0.1),
bias_initializer=tf.constant_initializer(0))(layer_1)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

prediction = model.predict(input_data)
# gradients = ...

现在想知道 outputs的导数关于 inputsinputs = input_data .

到目前为止我尝试过的:

This answer to a different question建议运行 grads = K.gradients(model.output, model.input) .但是,如果我运行它,则会出现此错误;

tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.



我只能假设这与现在默认的急切执行有关。

另一种方法是回答我关于自定义 keras 模型的问题,其中包括添加以下内容:
with tf.GradientTape() as tape:
x = tf.Variable(np.random.normal(size=(10, rows, cols)), dtype=tf.float32)
out = model(x)

我不明白这种方法是我应该如何加载数据。它需要 x成为 variable ,但我的 xtf.keras.Input目的。我也不明白那是什么 with声明正在做,某种魔法,但我不明白。

这里有一个与此非常相似的问题: Get Gradients with Keras Tensorflow 2.0尽管应用程序和场景大不相同,我很难将答案应用于此场景。它确实让我将以下内容添加到我的代码中:
with tf.GradientTape() as t:
t.watch(outputs)

那确实有效,但现在呢?我跑 model.predict(...) ,但是我如何获得我的渐变?答案说我应该运行 t.gradient(outputs, x_tensor).numpy() ,但我要为 x_tensor 投入什么? ?我没有输入变量。我试过运行 t.gradient(outputs, model.inputs)运行后 predict ,但这导致了:

enter image description here

最佳答案

我最终得到了这个问题的答案的变体:Get Gradients with Keras Tensorflow 2.0

x_tensor = tf.convert_to_tensor(input_data, dtype=tf.float32)
with tf.GradientTape() as t:
t.watch(x_tensor)
output = model(x_tensor)

result = output
gradients = t.gradient(output, x_tensor)

这使我无需冗余计算即可获得输出和梯度。

关于keras - 如何获得 keras 模型相对于其输入的梯度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59590766/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com