gpt4 book ai didi

tensorflow - 为稀疏密集乘法提供 tensorflow 稀疏矩阵。收到以下错误 : TypeError: Input must be a SparseTensor

转载 作者:行者123 更新时间:2023-12-03 10:04:57 25 4
gpt4 key购买 nike

我正在尝试使用 tensorflow tf.sparse_tensor_dense_matmul(X, W1)。X 定义为 tf.placeholder:

X = tf.placeholder("float", [None, size])

W 是 tf.Variable

在输入字典时,我传入了 tensorflow 稀疏矩阵。但我收到错误:

TypeError: Input must be a SparseTensor.

如何让 sparse_tensor_dense_matmul 模块知道我将传入稀疏张量?

最佳答案

要通过占位符传递 SparseTensor,您可以使用 sparse_placeholder :

sparse_place = tf.sparse_placeholder(tf.float64)
mul_result = tf.sparse_tensor_dense_matmul(sparse_place, some_dense_tensor)

您可以按如下方式使用它:

dense_version = tf.sparse_tensor_to_dense(sparse_place)
sess.run(dense_version, feed_dict={sparse_place: tf.SparseTensorValue([[0,1], [2,2]], [1.5, 2.9], [3, 3])})
Out: array([[ 0. , 1.5, 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 2.9]])

或者,您可以为值、形状和索引创建三个单独的占位符,例如:

indices = tf.placeholder(tf.int64)
shape = tf.placeholder(tf.int64)
values = tf.placeholder(tf.float64)
sparse_tensor = tf.SparseTensor(indices, shape, values)
sess.run(sparse_tensor, feed_dict={shape: [3, 3], indices: [[0,1], [2,2]], values: [1.5, 2.9]})

关于tensorflow - 为稀疏密集乘法提供 tensorflow 稀疏矩阵。收到以下错误 : TypeError: Input must be a SparseTensor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41105751/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com