python - Disconnected input in the gradient of a Theano Scan Op

Reposted. Author: 行者123. Updated: 2023-12-01 05:44:49

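I have a number of groups of items, and the groups vary in size. In each group, one (known) item is the "correct" one. A function assigns a score to every item. This yields a flat vector of item scores, plus vectors giving the index at which each group starts and how large it is. I want to apply a "softmax" over the scores within each group to assign probabilities to the items, then take the sum of the logs of the probabilities of the correct answers. Here is a simpler version that just returns the score of the correct answer, without the softmax and the logarithm.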

import numpy
import theano
import theano.tensor as T
from theano.printing import Print

def scoreForCorrectAnswer(groupSize, offset, correctAnswer, preds):
    # for each group, this will get called with the size of
    # the group, the offset of where the group begins in the
    # predictions vector, and which item in that group is correct
    relevantPredictions = preds[offset:offset+groupSize]
    ans = Print("CorrectAnswer")(correctAnswer)
    return relevantPredictions[ans]

groupSizes = T.ivector('groupSizes')
offsets = T.ivector('offsets')
x = T.fvector('x')
W = T.vector('W')
correctAnswers = T.ivector('correctAnswers')

# for this simple example, we'll just score the items by
# element-wise product with a weight vector
predictions = x * W

(values, updates) = theano.map(fn=scoreForCorrectAnswer,
                               sequences=[groupSizes, offsets, correctAnswers],
                               non_sequences=[predictions])

func = theano.function([groupSizes, offsets, correctAnswers,
                        W, x], [values])

sampleInput = numpy.array([0.1,0.7,0.3,0.05,0.3,0.3,0.3], dtype='float32')
sampleW = numpy.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype='float32')
sampleOffsets = numpy.array([0,4], dtype='int32')
sampleGroupSizes = numpy.array([4,3], dtype='int32')
sampleCorrectAnswers = numpy.array([1,2], dtype='int32')

data = func(sampleGroupSizes, sampleOffsets, sampleCorrectAnswers, sampleW, sampleInput)
print(data)

# all three of these raise the same exception (see below)
gW1 = T.grad(cost=T.sum(values), wrt=W)
gW2 = T.grad(cost=T.sum(values), wrt=W, disconnected_inputs='warn')
gW3 = T.grad(cost=T.sum(values), wrt=W, consider_constant=[groupSizes,offsets])

This computes the output correctly, but when I try to take the gradient with respect to the parameter W, I get (paths abbreviated):

Traceback (most recent call last):
  File "test_scan_for_stackoverflow.py", line 37, in <module>
    gW = T.grad(cost=T.sum(values), wrt=W)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 438, in grad
    outputs, wrt, consider_constant)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 698, in _populate_var_to_app_to_idx
    account_for(output)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 694, in account_for
    account_for(ipt)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 669, in account_for
    connection_pattern = _node_to_pattern(app)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 554, in _node_to_pattern
    connection_pattern = node.op.connection_pattern(node)
  File "Theano-0.6.0rc2-py2.7.egg/theano/scan_module/scan_op.py", line 1331, in connection_pattern
    ils)
  File "Theano-0.6.0rc2-py2.7.egg/theano/scan_module/scan_op.py", line 1266, in compute_gradient
    known_grads={y: g_y}, wrt=x)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 511, in grad
    handle_disconnected(elem)
  File "Theano-0.6.0rc2-py2.7.egg/theano/gradient.py", line 497, in handle_disconnected
    raise DisconnectedInputError(message)
theano.gradient.DisconnectedInputError: grad method was asked to compute
the gradient with respect to a variable that is not part of the
computational graph of the cost, or is used only by a
non-differentiable operator: groupSizes[t]

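Now, groupSizes is a constant, so there is no reason to need any gradient with respect to it. Normally you can handle this either by suppressing the DisconnectedInputError or by telling Theano to treat groupSizes as a constant in the T.grad call (see the last lines of the sample script). But there does not seem to be any way to pass these options down to the internal T.grad call made inside the ScanOp's gradient computation.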

Am I missing something? Is there a way to make the gradient computation work through the ScanOp here?
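
For reference, the full objective described at the top of the question (a softmax within each group, then the sum of the log-probabilities of the correct items) can be sketched in plain NumPy. This is only an illustration of the intended computation, not Theano code and not the asker's implementation; the function name grouped_log_likelihood and its argument names are made up for this sketch.

```python
import numpy as np

def grouped_log_likelihood(scores, offsets, group_sizes, correct_answers):
    """Sum, over groups, of the log-softmax probability of the correct item.

    Hypothetical helper for illustration; all names are assumptions.
    """
    total = 0.0
    for offset, size, correct in zip(offsets, group_sizes, correct_answers):
        group = scores[offset:offset + size]
        # Subtract the group maximum before exponentiating for numerical stability.
        shifted = group - group.max()
        log_probs = shifted - np.log(np.exp(shifted).sum())
        total += log_probs[correct]
    return total

# Same sample data as in the question, as float64 for simplicity.
scores = np.array([0.1, 0.7, 0.3, 0.05, 0.3, 0.3, 0.3])
offsets = np.array([0, 4])
group_sizes = np.array([4, 3])
correct_answers = np.array([1, 2])

print(grouped_log_likelihood(scores, offsets, group_sizes, correct_answers))
```

The Python loop over groups plays the role that theano.map plays in the question's script: each iteration sees one group's slice of the flat score vector.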

Best answer

This turned out to be a bug in Theano as of mid-February 2013 (0.6.0rc-2). As of the date of this post, it has been fixed in the development version on GitHub.
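
When testing against a fixed Theano build, it helps to know what gradient to expect. In the simplified example the cost is the sum of x[i] * W[i] over the selected positions i = offset + correctAnswer, so the gradient with respect to W should be x[i] at those positions and zero elsewhere. A minimal NumPy sketch of that hand-derived check follows; it does not use Theano, and the names flat_index and expected_grad are illustrative only.

```python
import numpy as np

# Sample data from the question.
x = np.array([0.1, 0.7, 0.3, 0.05, 0.3, 0.3, 0.3], dtype='float32')
W = np.ones(7, dtype='float32')
offsets = np.array([0, 4], dtype='int32')
correct_answers = np.array([1, 2], dtype='int32')

# Position of each group's correct item in the flat predictions vector.
flat_index = offsets + correct_answers

# Forward values: the score of the correct item in each group.
values = (x * W)[flat_index]

# d(sum(values))/dW is x at the selected positions, zero elsewhere.
# np.add.at accumulates correctly even if an index were repeated.
expected_grad = np.zeros_like(W)
np.add.at(expected_grad, flat_index, x[flat_index])

print(values)
print(expected_grad)
```

Because the simplified version is just an indexed read of the flat vector, it can be computed without a per-group loop at all. Whether the same indexing expression is a practical substitute for scan in a given Theano version is an assumption that was not checked here.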

The original question, "python - Disconnected input in the gradient of a Theano Scan Op", is on Stack Overflow: https://stackoverflow.com/questions/16426641/
