gpt4 book ai didi

python - 形状为 b*n*3 的 Tensorflow 张量 T1。形状为 b*n 的 T2 -> 一个 bool 张量,指示 T1 中要采用哪些行

转载 作者:太空宇宙 更新时间:2023-11-03 20:53:25 28 4
gpt4 key购买 nike

我正在尝试提取 n 点 3D 坐标和 b 批处理中的特定行。本质上,我的张量 T1 的形状为 b*n*3。我有另一个形状为 b * n 的 bool 张量 T2,指示需要获取 n 的哪些行。本质上我的输出应该是 b*?*3 因为 T2 在每批中可以有不同数量的 1。

我已经使用 bool 掩码实现了以下内容,但输出不符合预期,并且输出形状为 (?,) 但不是 (b*?*3)

# expand T2 to (b,n,3). i.e. 0 replicates to (0,0,0) and so is 1

mask = tf.tile(tf.expand_dims(T2,2), [1,1,3])

# query using boolean mask where there are 1s

valid_KPs = tf.boolean_mask(T1, tf.cast(mask, tf.int32))

最佳答案

由于每个示例选择的元素数量可能不同,因此无法表示为适当的张量。一种选择是使用 ragged tensor 。它们不能做普通张量可以做的所有事情,但可以实现你想要的,例如这样:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
# Input data
t1 = tf.constant([
[
[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
],
[
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
],
])
t2 = tf.constant([
[1, 0, 1],
[0, 1, 0],
])
# Count the number of ones for each row in T2
c = tf.reduce_sum(t2, axis=1)
# Ragged ranges for each row
r = tf.ragged.range(c)
# Sorting indices so indices with a one are first
s = tf.argsort(t2, axis=1, direction='DESCENDING', stable=True)
# First axis dimension index
idx0 = tf.expand_dims(tf.range(tf.shape(t1)[0]), 1) * tf.ones_like(r)
# 2D index for getting indices of ones on each row
idx_s = tf.stack([idx0, r], axis=-1)
# Get indices of ones
idx1 = tf.gather_nd(s, idx_s)
# 2D index to get indices of selected vectors in T1
idx = tf.stack([idx0, idx1], axis=-1)
# Get selected vectors
result = tf.gather_nd(t1, idx)
# Print result
print(sess.run(result))
# <tf.RaggedTensorValue [[[1, 2, 3], [7, 8, 9]], [[13, 14, 15]]]>

关于python - 形状为 b*n*3 的 Tensorflow 张量 T1。形状为 b*n 的 T2 -> 一个 bool 张量,指示 T1 中要采用哪些行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56167596/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com