gpt4 book ai didi

gradient - 仅计算共享变量数组的一部分的梯度

转载 作者:行者123 更新时间:2023-12-02 01:27:29 26 4
gpt4 key购买 nike

我想做以下事情:

import theano, numpy, theano.tensor as T

a = T.fvector('a')

w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX))
w_sub = w[1]

b = T.sum(a * w)

grad = T.grad(b, w_sub)

这里,w_sub 例如 w[1] 但我不想在 w_sub 的函数中显式写出 b。尽管经历了 this以及其他相关问题我无法解决。

这只是为了向您展示我的问题。实际上,我真正想做的是与 Lasagne 进行稀疏卷积。权重矩阵中的零项不需要更新,因此不需要为 w 的这些项计算梯度。

这是完整的错误信息:

Traceback (most recent call last):
File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 9, in <module>
grad = T.grad(b, w_sub)
File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 545, in grad
handle_disconnected(elem)
File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 532, in handle_disconnected
raise DisconnectedInputError(message)
theano.gradient.DisconnectedInputError: grad method was asked to compute the gradient with respect to a variable that is not part of the computational graph of the cost, or is used only by a non-differentiable operator: Subtensor{int64}.0
Backtrace when the node is created:
File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 6, in <module>
w_sub = w[1]

最佳答案

当 theano 编译图表时,它只会看到图表中明确定义的变量。在您的示例中,w_sub 未明确用于 b 的计算,因此不是计算图的一部分。

使用theano打印库,代码如下,可以在上面看到 graph vizualization确实 w_sub 不是 b 图的一部分。

import theano
import theano.tensor as T
import numpy
import theano.d3viz as d3v

a = T.fvector('a')
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX))
w_sub = w[1]
b = T.sum(a * w)

o = b, w_sub

d3v.d3viz(o, 'b.html')

要解决此问题,您需要在 b 的计算中显式使用 w_sub

然后您将能够计算 b wrt w_sub 的梯度并更新共享变量的值,如下例所示:

import theano
import theano.tensor as T
import numpy


a = T.fvector('a')
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX))
w_sub = w[1]
b = T.sum(a * w_sub)
grad = T.grad(b, w_sub)
updates = [(w, T.inc_subtensor(w_sub, -0.1*grad))]

f = theano.function([a], b, updates=updates, allow_input_downcast=True)

f(numpy.arange(10))

关于gradient - 仅计算共享变量数组的一部分的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36197759/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com