gpt4 book ai didi

python - Strassen 矩阵乘法——接近,但仍然存在错误

转载 作者:行者123 更新时间:2023-11-30 23:40:10 26 4
gpt4 key购买 nike

我正在尝试在 Python 中实现 Strassen 矩阵乘法。我已经让它发挥了一些作用。这是我的代码:

a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]

def new_m(p, q): # create a matrix filled with 0s
matrix = [[0 for row in range(p)] for col in range(q)]
return matrix

def straight(a, b): # multiply the two matrices
if len(a[0]) != len(b): # if # of col != # of rows:
return "Matrices are not m*n and n*p"
else:
p_matrix = new_m(len(a), len(b[0]))
for i in range(len(a)):
for j in range(len(b[0])):
for k in range(len(b)):
p_matrix[i][j] += a[i][k]*b[k][j]
return p_matrix

def split(matrix): # split matrix into quarters
a = matrix
b = matrix
c = matrix
d = matrix
while(len(a) > len(matrix)/2):
a = a[:len(a)//2]
b = b[:len(b)//2]
c = c[len(c)//2:]
d = d[len(d)//2:]
while(len(a[0]) > len(matrix[0])/2):
for i in range(len(a[0])//2):
a[i] = a[i][:len(a[i])//2]
b[i] = b[i][len(b[i])//2:]
c[i] = c[i][:len(c[i])//2]
d[i] = d[i][len(d[i])//2:]
return a,b,c,d

def add_m(a, b):
if type(a) == int:
d = a + b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] + b[i][j])
d.append(c)
return d

def sub_m(a, b):
if type(a) == int:
d = a - b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] - b[i][j])
d.append(c)
return d


def strassen(a, b, q):
# base case: 1x1 matrix
if q == 1:
d = [[0]]
d[0][0] = a[0][0] * b[0][0]
return d
else:
#split matrices into quarters
a11, a12, a21, a22 = split(a)
b11, b12, b21, b22 = split(b)

# p1 = (a11+a22) * (b11+b22)
p1 = strassen(add_m(a11,a22), add_m(b11,b22), q/2)

# p2 = (a21+a22) * b11
p2 = strassen(add_m(a21,a22), b11, q/2)

# p3 = a11 * (b12-b22)
p3 = strassen(a11, sub_m(b12,b22), q/2)

# p4 = a22 * (b12-b11)
p4 = strassen(a22, sub_m(b12,b11), q/2)

# p5 = (a11+a12) * b22
p5 = strassen(add_m(a11,a12), b22, q/2)

# p6 = (a21-a11) * (b11+b12)
p6 = strassen(sub_m(a21,a11), add_m(b11,b12), q/2)

# p7 = (a12-a22) * (b21+b22)
p7 = strassen(sub_m(a12,a22), add_m(b21,b22), q/2)


# c11 = p1 + p4 - p5 + p7
c11 = add_m(sub_m(add_m(p1, p4), p5), p7)

# c12 = p3 + p5
c12 = add_m(p3, p5)

# c21 = p2 + p4
c21 = add_m(p2, p4)

# c22 = p1 + p3 - p2 + p6
c22 = add_m(sub_m(add_m(p1, p3), p2), p6)

c = new_m(len(c11)*2,len(c11)*2)
for i in range(len(c11)):
for j in range(len(c11)):
c[i][j] = c11[i][j]
c[i][j+len(c11)] = c12[i][j]
c[i+len(c11)][j] = c21[i][j]
c[i+len(c11)][j+len(c11)] = c22[i][j]

return c

print "Strassen Outputs:"
print strassen(a, b, 4)
print "Should be:"
print straight(a, b)

我包含了直接矩阵乘法,以引用正确的所需输出。基本上会发生这种情况:

施特拉森输出:

[[10, 14, 22, 26], [32, 36, 48, 52], [58, 66, 70, 78], [80, 88, 96, 104]]

应该是:

[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]

我不确定问题的根源是什么,这意味着我无法解决它!

最佳答案

这不应该是:

# p4 = a22 * (b12-b11)
p4 = strassen(a22, sub_m(b12,b11), q/2)

是:

# p4 = a22 * (b21-b11)
p4 = strassen(a22, sub_m(b21,b11), q/2)

相反?

~/coding$ python -i strass.py
Strassen Outputs:
[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
Should be:
[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]]
>>> import numpy
>>> def check():
... for i in range(100):
... a = numpy.random.randint(0, 10,size=(4,4)).tolist()
... b = numpy.random.randint(0, 10,size=(4,4)).tolist()
... assert strassen(a,b,4) == straight(a,b)
... assert (numpy.array(strassen(a,b,4)) == numpy.dot(a,b)).all()
... print 'hooray!'
...
>>> check()
hooray!

关于python - Strassen 矩阵乘法——接近,但仍然存在错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/12867099/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com