gpt4 book ai didi

python - 在 tensorflow 中,如何迭代存储在张量中的一系列输入?

转载 作者:太空狗 更新时间:2023-10-29 19:37:19 25 4
gpt4 key购买 nike

我正在尝试使用 RNN 解决可变长度多变量序列分类问题。

我定义了以下函数来获取序列的输出(即在输入序列的最终输入后​​ RNN 单元的输出)

def get_sequence_output(x_sequence, initial_hidden_state):
previous_hidden_state = initial_hidden_state
for x_single in x_sequence:
hidden_state = gru_unit(previous_hidden_state, x_single)
previous_hidden_state = hidden_state
final_hidden_state = hidden_state
return final_hidden_state

这里 x_sequence 是形状为 (?, ?, 10) 的张量,第一个?是批量大小和第二?用于序列长度,每个输入元素的长度为 10。gru 函数采用前一个隐藏状态和当前输入,并吐出下一个隐藏状态(标准门控循环单元)。

我收到一个错误:'Tensor' 对象不可迭代。如何按顺序迭代张量(一次读取单个元素)?

我的目标是对序列中的每个输入应用 gru 函数并获得最终的隐藏状态。

最佳答案

在TF>=1.0中,tf.packtf.unpack重命名为tf.stacktf.unstack 分别

关于python - 在 tensorflow 中,如何迭代存储在张量中的一系列输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38524776/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com