gpt4 book ai didi

python - 稀疏张量的 while_loop 中的 InvalidArgumentError

转载 作者:太空宇宙 更新时间:2023-11-03 21:19:06 25 4
gpt4 key购买 nike

我正在使用 while_loop 迭代更新矩阵。对于密集张量,循环运行良好,但是当我使用稀疏张量时,出现以下错误:

InvalidArgumentError: Number of rows of a_indices does not match number of entries in a_values [[Node: while/SparseTensorDenseMatMul/SparseTensorDenseMatMul = SparseTensorDenseMatMul[T=DT_FLOAT, Tindices=DT_INT64, adjoint_a=false, adjoint_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](while/SparseTensorDenseMatMul/SparseTensorDenseMatMul/Enter, while/SparseTensorDenseMatMul/SparseTensorDenseMatMul/Enter_1, ConstantFolding/dense_to_sparse/Shape_enter/_1, while/Switch_1:1)]]
[[Node: while/Exit_1/_5 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_62_while/Exit_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

我在两个版本之间唯一改变的是用 HH=tf.contrib.layers.dense_to_sparse(HH) 转换 HH 并使用 tf.sparse_tensor_dense_matmul(HH,f) 而不是 tf.matmul(HH,f) - 如下面的注释代码所示。

with tf.device('/gpu:0'):
g=tf.constant(g,shape=[np.size(g),1],dtype=tf.float32)
H=tf.constant(H,dtype=tf.float32);
Ht=tf.transpose(H)
HH=tf.matmul(Ht,H)
#HH=tf.contrib.layers.dense_to_sparse(HH)
a=tf.matmul(Ht,g)
i=tf.constant(0,dtype=tf.int32)
f=tf.constant(f,dtype=tf.float32)
body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.matmul(HH,f)+10e-9))
#body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.sparse_tensor_dense_matmul(HH,f)+10e-9))
cond= lambda i,f:tf.less(i,iterations)
i,f=tf.while_loop(cond,body,(i,f))
sess=tf.Session()
i,f=sess.run([i,f])

请注意,只要 H、g 和 f 足够小,此代码就可以工作。例如,此错误发生在 H.shape=(8000,3840) 、g.shape=(8000,1)、f.shape=(3840,1) 和更大的情况下,但对于 H.shape=(8000,第3584章 ,g.shape=(8000,1),f.shape=(3584,1)或更小我是否需要在 while 循环中对稀疏张量做一些特殊的事情以确保它们保持其形状?

最佳答案

我尝试从tensorflow 1.8更新到1.12,但tensorflow完全停止工作(ts.Session将无限期挂起)。因此,我破坏了 anaconda 环境,并从头开始使用 TensorFlow 1.12。更新/重新安装后,稀疏张量的问题消失了,但尚不清楚问题是否出在我的 anaconda 环境中的 tensorflow 版本或其他问题上。

关于python - 稀疏张量的 while_loop 中的 InvalidArgumentError,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54444626/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com