gpt4 book ai didi

Tensorflow:如何获取多个矩阵的对角线(批处理模式)

转载 作者:行者123 更新时间:2023-12-05 06:41:29 25 4
gpt4 key购买 nike

我有一个张量 [100 X 16 X 16] .我想得到这个张量的对角线元素来得到形状为 [100 X 16] 的张量.我尝试了以下方法:

#sum_cov[100 X 16 X 16]diagonal_elements预计为 [100 X 16] .

diagonal_elements = tf.diag_part(sum_cov)

但是,我收到以下错误:

Input must have even rank <= 6, input rank is 3 for 'DiagPart'

谁能告诉我如何实现这一点?

最佳答案

https://www.tensorflow.org/api_docs/python/tf/diag_part https://www.tensorflow.org/api_docs/python/tf/matrix_diag_part

dm0_是正确的。你想要 tf.matrix_diag_part

tf.diag_part 计算张量对角线,如果输入张量的形状为 (D1, ..., Dk, D1, ..., Dk) 则输出张量的形状为 (D1, ..., Dk) 并且是这样的

tf.diag_part(input)[i1, ..., ik] = input[i1, ..., ik, i1, ..., ik]

这就是您收到错误的原因。为了满足上述先决条件,输入张量必须具有偶数秩。

tf.matrix_diag_part 另一方面,将输入张量视为一批二维矩阵,并计算每个矩阵的对角线。因此,如果输入张量的形状为 (I, J, K, ..., M, N),则输出张量的形状将为 (I, J, K, ..., min(M, N)) 并且是这样的

tf.matrix_diag_part(input)[i, j, k, ..., m] = input[i, j, k, ..., m, m]

对于 2 阶张量,这两个函数是相同的,但对于任何更高的张量,它们都是非常不同的野兽。

关于Tensorflow:如何获取多个矩阵的对角线(批处理模式),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40469216/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com