gpt4 book ai didi

python - tensorflow 中 softmax 的 InvalidArgumentError

转载 作者:太空宇宙 更新时间:2023-11-04 02:51:49 25 4
gpt4 key购买 nike

我有以下功能:

def forward_propagation(self, x):
# The total number of time steps
T = len(x)
# During forward propagation we save all hidden states in s because need them later.
# We add one additional element for the initial hidden, which we set to 0
s = tf.zeros([T+1, self.hidden_dim])
# The outputs at each time step. Again, we save them for later.
o = tf.zeros([T, self.word_dim])


a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)

s_t = tf.nn.tanh(a + tf.reduce_sum(tf.multiply(b, c)))
o_t = tf.nn.softmax(tf.reduce_sum(tf.multiply(a, b)))
# For each time step...
with tf.Session() as sess:
s = sess.run(s)
o = sess.run(o)
for t in range(T):
# Note that we are indexing U by x[t]. This is the same as multiplying U with a one-hot vector.
s[t] = sess.run(s_t, feed_dict={a: self.U[:, x[t]], b: self.W, c: s[t-1]})
o[t] = sess.run(o_t, feed_dict={a: self.V, b: s[t]})
return [o, s]

self.U、self.V 和 self.W 是 numpy 数组。我试着让 softmax 上

o_t = tf.nn.softmax(tf.reduce_sum(tf.multiply(a, b)))

图表,它在这一行给我错误:

o[t] = sess.run(o_t, feed_dict={a: self.V, b: s[t]})

错误是:

InvalidArgumentError (see above for traceback): Expected begin[0] == 0 (got -1) and size[0] == 0 (got 1) when input.dim_size(0) == 0
[[Node: Slice = Slice[Index=DT_INT32, T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](Shape_1, Slice/begin, Slice/size)]]

我应该如何在 tensorflow 中获得 softmax?

最佳答案

问题的出现是因为您在 tf.nn.softmax 的参数上调用了 tf.reduce_sum。结果,softmax 函数失败,因为标量不是有效的输入参数。您是要使用 tf.matmul 而不是 tf.reduce_sumtf.multiply 的组合吗?

编辑:Tensorflow 不提供开箱即用的 np.dot 等价物。如果你想计算矩阵和向量的点积,你需要显式地对索引求和:

# equivalent to np.dot(a, b) if a.ndim == 2 and b.ndim == 1
c = tf.reduce_sum(a * b, axis=1)

关于python - tensorflow 中 softmax 的 InvalidArgumentError,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43632476/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com