gpt4 book ai didi

python - 从连续的数组切片或卷制作矩阵

转载 作者:行者123 更新时间:2023-11-28 20:33:59 25 4
gpt4 key购买 nike

我有一个像这样的数组:

[10 20 30 40]

我想像这样构建一个矩阵M1:

10  0  0  0
20 10 0 0
30 20 10 0
40 30 20 10

我的方法是首先从数组的连续“滚动”中构建以下矩阵 M2:

10 20 30 40
20 10 40 30
30 20 10 40
40 30 20 10

然后用np.tril取下三角矩阵.我会对直接构建 M2M1 而不通过 M2 的有效方法感兴趣。

构建 M2 的简单方法可能是:

import numpy as np

def M2_simple(a):
a = np.asarray(a)
return np.stack([a[np.arange(-i, len(a) - i)] for i in range(len(a))]).T

print(M2_simple(np.array([10, 20, 30, 40])))
# [[10 40 30 20]
# [20 10 40 30]
# [30 20 10 40]
# [40 30 20 10]]

经过一些尝试,我得出以下更好的解决方案,基于 advanced indexing :

def M2_indexing(a):
a = np.asarray(a)
r = np.arange(len(a))[np.newaxis]
return a[np.newaxis][np.zeros_like(r), r - r.T].T

这显然比以前快得多,但测量性能似乎仍然没有那么快(例如,它比平铺花费的时间长一个数量级,不是那么不同操作),这需要我构建大型索引矩阵。

有没有更好的方法来构建这些矩阵?

最佳答案

编辑:

实际上,您可以使用相同的方法直接构建M1:

import numpy as np

def M1_strided(a):
a = np.asarray(a)
n = len(a)
s, = a.strides
a0 = np.concatenate([np.zeros(len(a) - 1, a.dtype), a])
return np.lib.stride_tricks.as_strided(
a0, (n, n), (s, s), writeable=False)[:, ::-1]

print(M1_strided(np.array([10, 20, 30, 40])))
# [[10 0 0 0]
# [20 10 0 0]
# [30 20 10 0]
# [40 30 20 10]]

在这种情况下,速度优势甚至更好,因为您正在保存对 np.tril 的调用:

N = 100
a = np.square(np.arange(N))
%timeit np.tril(M2_simple(a))
# 792 µs ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.tril(M2_indexing(a))
# 259 µs ± 9.45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.tril(M2_strided(a))
# 134 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit M1_strided(a)
# 45.2 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

您可以使用 np.lib.stride_tricks.as_strided 更有效地构建 M2 矩阵:

import numpy as np
from numpy.lib.stride_tricks import as_strided

def M2_strided(a):
a = np.asarray(a)
n = len(a)
s, = a.strides
return np.lib.stride_tricks.as_strided(
np.tile(a[::-1], 2), (n, n), (s, s), writeable=False)[::-1]

作为一个额外的好处,您将只使用原始数组两倍的内存(而不是平方大小)。你只需要小心不要写入这样创建的数组(如果你稍后要调用 np.tril 应该不是问题) - 我添加了 writeable=False 来禁止写操作。

与 IPython 的快速速度比较:

N = 100
a = np.square(np.arange(N))
%timeit M2_simple(a)
# 693 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit M2_indexing(a)
# 163 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit M2_strided(a)
# 38.3 µs ± 348 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

关于python - 从连续的数组切片或卷制作矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49532575/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com