gpt4 book ai didi

python - 如何在较高排名的张量中找到第一个匹配张量的索引

转载 作者:行者123 更新时间:2023-11-30 09:28:54 24 4
gpt4 key购买 nike

与这个问题非常相似 How to find an index of the first matching element in TensorFlow

我尝试了解决方案但不同之处在于 val 不是单个数字,而是一个张量

举个例子

np.array([1, 1, 1],
[1, 0, 1],
[0, 0, 1])
val = np.array([1, 0, 1])


some tensorflow magic happens here!

result = 1

我知道我可以使用 while 循环,但这看起来很困惑。我可以尝试映射函数,但是有更优雅的东西吗?

最佳答案

这是一种方法 -

(arr == val).all(axis=-1).argmax()

示例运行 -

In [977]: arr
Out[977]:
array([[1, 1, 1],
[1, 0, 1],
[0, 0, 1]])

In [978]: val
Out[978]: array([1, 0, 1])

In [979]: (arr == val).all(axis=1).argmax()
Out[979]: 1

使用 View 可能会提高性能 -

# https://stackoverflow.com/a/44999009/ @Divakar
def view1D(a): # a is array
a = np.ascontiguousarray(a)
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
return a.view(void_dt).ravel()

out = (view1D(arr) == view1D(val[None])).argmax()

扩展到 n 维情况

扩展到 n 维数组情况需要更多步骤 -

def first_match_index_along_axis(arr, val, axis):    
s = [None]*arr.ndim
s[axis] = Ellipsis
mask = val[np.s_[s]] == arr
idx = mask.all(axis=axis,keepdims=True).argmax()
shp = list(arr.shape)
del shp[axis]
return np.unravel_index(idx, shp)

示例运行 -

In [74]: arr = np.random.randint(0,9,(4,5,6,7))

In [75]: first_match_index_along_axis(arr, arr[2,:,1,0], axis=1)
Out[75]: (2, 1, 0)

In [76]: first_match_index_along_axis(arr, arr[2,1,3,:], axis=3)
Out[76]: (2, 1, 3)

关于python - 如何在较高排名的张量中找到第一个匹配张量的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47744964/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com