gpt4 book ai didi

python - 广播旋转矩阵乘法

转载 作者:太空宇宙 更新时间:2023-11-03 14:48:13 25 4
gpt4 key购买 nike

如何做标有# <----的行以更直接的方式?

在程序中,每一行x是一个点的坐标,rot_mat[0]rot_mat[1]是两个旋转矩阵。程序轮换x通过每个旋转矩阵。

改变每个旋转矩阵和坐标之间的乘法顺序很好,如果它能让事情变得更简单的话。我想要每一行 x或表示点坐标的结果。

结果应与检查相符。

程序:

# Rotation of coordinates of 4 points by 
# each of the 2 rotation matrices.
import numpy as np
from scipy.stats import special_ortho_group
rot_mats = special_ortho_group.rvs(dim=3, size=2) # 2 x 3 x 3
x = np.arange(12).reshape(4, 3)
result = np.dot(rot_mats, x.T).transpose((0, 2, 1)) # <----
print("---- result ----")
print(result)
print("---- check ----")
print(np.dot(x, rot_mats[0].T))
print(np.dot(x, rot_mats[1].T))

结果:

---- result ----
[[[ 0.20382264 1.15744672 1.90230739]
[ -2.68064533 3.71537598 5.38610452]
[ -5.56511329 6.27330525 8.86990165]
[ -8.44958126 8.83123451 12.35369878]]

[[ 1.86544623 0.53905202 -1.10884323]
[ 5.59236544 -1.62845022 -4.00918928]
[ 9.31928465 -3.79595246 -6.90953533]
[ 13.04620386 -5.9634547 -9.80988139]]]
---- check ----
[[ 0.20382264 1.15744672 1.90230739]
[ -2.68064533 3.71537598 5.38610452]
[ -5.56511329 6.27330525 8.86990165]
[ -8.44958126 8.83123451 12.35369878]]
[[ 1.86544623 0.53905202 -1.10884323]
[ 5.59236544 -1.62845022 -4.00918928]
[ 9.31928465 -3.79595246 -6.90953533]
[ 13.04620386 -5.9634547 -9.80988139]]

最佳答案

使用np.tensordot对于涉及此类 tensors 的乘法 -

np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)

这里有一些时间可以说服我们自己为什么 tensordottensors 配合得更好 -

In [163]: rot_mats = np.random.rand(20,30,30)
...: x = np.random.rand(40,30)

# With numpy.dot
In [164]: %timeit np.dot(rot_mats, x.T).transpose((0, 2, 1))
1000 loops, best of 3: 670 µs per loop

# With numpy.tensordot
In [165]: %timeit np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)
10000 loops, best of 3: 75.7 µs per loop

In [166]: rot_mats = np.random.rand(200,300,300)
...: x = np.random.rand(400,300)

# With numpy.dot
In [167]: %timeit np.dot(rot_mats, x.T).transpose((0, 2, 1))
1 loop, best of 3: 1.82 s per loop

# With numpy.tensordot
In [168]: %timeit np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)
10 loops, best of 3: 185 ms per loop

关于python - 广播旋转矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47892097/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com