gpt4 book ai didi

python - 为什么这个最小的 RNN 代码会针对从未使用过的类型抛出类型错误?

转载 作者:行者123 更新时间:2023-12-01 04:21:02 24 4
gpt4 key购买 nike

我正在尝试在 theano 中实现一个最小的循环神经网络示例。我期望以下 python 脚本打印一个表示隐藏状态序列的 10×20 矩阵。

# import packages/functions
from theano import shared, scan, function, tensor as T
import numpy as np

# declare variables
X = T.dmatrix("X")
Wx = shared(np.random.uniform(-1.0, 1.0, (10, 20)))
Wh = shared(np.random.uniform(-1.0, 1.0, (20, 20)))
b = shared(np.random.uniform(-1.0, 1.0, (1, 20)))

# define recurrence function
def recurrence(x_t, h_tm1):
return T.nnet.sigmoid(T.dot(h_tm1, Wh) + T.dot(x_t, Wx) + b)

# compute hidden state sequence with scan
ht, _ = scan(fn = recurrence, sequences = X,
outputs_info = np.zeros((1, 20)))

# define function producing hidden state sequence
fn = function([X], ht)reshape((1,3))

# test function
print fn(np.eye(10))

相反,它返回错误: TypeError: Cannot convert Type TensorType(float64, 3D) (of Variable IncSubtensor{Set;:int64:}.0) into Type TensorType(float64, (False, True, False)). You can try to manually convert IncSubtensor{Set;:int64:}.0 into a TensorType(float64, (False, True, False)).

这尤其令人困惑,因为据我所知,我的变量都不是 3 张量!

最佳答案

代码中存在语法错误问题的代码:fn = 末尾的 reshape((1,3)) 无效,并且似乎已被错误添加。当该行的这个元素被简单删除时,代码就会运行。该答案的其余部分假设编辑是按照问题作者的意图进行的。

我没有重现所述错误。这可能是因为我使用的是 Theano 的前沿版本。这种情况的错误消息很可能在最新版本的代码中已更改。然而,上面代码中的语法错误暗示了另一种可能性:问题代码实际上并不是产生粘贴到问题中的错误的代码。

使用编辑后的代码和最新版本的 Theano,我收到错误

TypeError: ('The following error happened while compiling the node', forall_inplace,cpu,scan_fn}(Shape_i{0}.0, Subtensor{int64:int64:int8}.0, IncSubtensor{InplaceSet;:int64:}.0, , , ), '\n', "Inconsistency in the inner graph of scan 'scan_fn' : an input and an output are associated with the same recurrent state and should have the same type but have type 'TensorType(float64, row)' and 'TensorType(float64, matrix)' respectively.")

这在本质上与问题的错误类似,但指的是矩阵和行向量之间的不匹配;没有提及 3 张量。

避免此错误的最简单更改是更改 b 共享变量的形状。

而不是

b = shared(np.random.uniform(-1.0, 1.0, (1, 20)))

使用

b = shared(np.random.uniform(-1.0, 1.0, (20,)))

我还建议对 h_tm1 的初始值执行相同的操作。

而不是

outputs_info = np.zeros((1, 20))

使用

outputs_info = np.zeros((20,))

关于python - 为什么这个最小的 RNN 代码会针对从未使用过的类型抛出类型错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33696775/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com