gpt4 book ai didi

tensorflow - 来自 Tensorflow 中 4-D 张量的 tf.nn.top_k 索引的二进制掩码?

转载 作者:行者123 更新时间:2023-12-03 23:26:13 24 4
gpt4 key购买 nike

我有一个形状为 (10, 32, 32, 128) 的 4-D 张量。我想为所有前 N 个元素生成一个二进制掩码。

arr = tf.random_normal(shape=(10, 32, 32, 128))
values, indices = tf.nn.top_k(arr, N=64)

我的问题是如何获得与 arr 形状相同的二进制掩码使用 indices返回者 tf.nn.top_k

最佳答案

如果有人正在寻找答案:就在这里。

K = 64
arr = tf.random_normal(shape=(10, 32, 32, 128))
values, indices = tf.nn.top_k(arr, k=K, sorted=False)

temp_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(
tf.shape(arr)[:(arr.get_shape().ndims - 1)]) + [K])], indexing='ij')
temp_indices = tf.stack(temp_indices[:-1] + [indices], axis=-1)
full_indices = tf.reshape(temp_indices, [-1, arr.get_shape().ndims])
values = tf.reshape(values, [-1])

mask_st = tf.SparseTensor(indices=tf.cast(
full_indices, dtype=tf.int64), values=tf.ones_like(values), dense_shape=arr.shape)
mask = tf.sparse_tensor_to_dense(tf.sparse_reorder(mask_st))

关于tensorflow - 来自 Tensorflow 中 4-D 张量的 tf.nn.top_k 索引的二进制掩码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43294421/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com