gpt4 book ai didi

python - Tensorflow,获取矩阵中每一行的非零值索引

转载 作者:太空宇宙 更新时间:2023-11-04 04:10:51 29 4
gpt4 key购买 nike

所以我想获取矩阵中每一行不为零的值的索引。我已尝试使用 tf.where,但输出与我预期的不同。

我现在的代码是:

b = tf.constant([[1,0,0,0,0],
[1,0,1,0,1]],dtype=tf.float32)
zero = tf.constant(0, dtype=tf.float32)
where = tf.not_equal(b, zero)
indices = tf.where(where)

索引输出是:

<tf.Tensor: id=136, shape=(4, 2), dtype=int64, numpy=
array([[0, 0],
[1, 0],
[1, 2],
[1, 4]])>

但我希望输出是:

[[0],
[0,2,4]]

我有一个包含每行索引的列表。

谢谢。

最佳答案

这不可能是一个合适的张量,因为尺寸不均匀。如果您可以使用 ragged tensor你可以这样做:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
b = tf.constant([[1, 0, 0, 0, 0],
[1, 0, 1, 0, 1]],dtype=tf.float32)
num_rows = tf.shape(b)[0]
zero = tf.constant(0, dtype=tf.float32)
where = tf.not_equal(b, zero)
indices = tf.where(where)
s = tf.ragged.segment_ids_to_row_splits(indices[:, 0], num_rows)
row_start = s[:-1]
elem_per_row = s[1:] - row_start
idx = tf.expand_dims(row_start, 1) + tf.ragged.range(elem_per_row)
result = tf.gather(indices[:, 1], idx)
print(sess.run(result))
# <tf.RaggedTensorValue [[0], [0, 2, 4]]>

编辑:如果您不想或不能使用参差不齐的张量,这里有一个替代方案。您可以生成一个填充有“无效”值的张量。例如,您可以在这些无效值中使用 -1,或者只使用一维张量来告诉您每行有多少个有效值:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
b = tf.constant([[1, 0, 0, 0, 0],
[1, 0, 1, 0, 1]],dtype=tf.float32)
num_rows = tf.shape(b)[0]
zero = tf.constant(0, dtype=tf.float32)
where = tf.not_equal(b, zero)
indices = tf.where(where)
num_indices = tf.shape(indices)[0]
elem_per_row = tf.bincount(tf.cast(indices[:, 0], tf.int32), minlength=num_rows)
row_start = tf.concat([[0], tf.cumsum(elem_per_row[:-1])], axis=0)
max_elem_per_row = tf.reduce_max(elem_per_row)
r = tf.range(max_elem_per_row)
idx = tf.expand_dims(row_start, 1) + r
idx = tf.minimum(idx, num_indices - 1)
result = tf.gather(indices[:, 1], idx)
# Optional: replace invalid elements with -1
result = tf.where(tf.expand_dims(elem_per_row, 1) > r, result, -tf.ones_like(result))
print(sess.run(result))
# [[ 0 -1 -1]
# [ 0 2 4]]
print(sess.run(elem_per_row))
# [1 3]

关于python - Tensorflow,获取矩阵中每一行的非零值索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56292809/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com