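I based my code on the Strassen example at https://www.geeksforgeeks.org/strassens-matrix-multiplication/. If a matrix has an odd dimension, I pad it with one row and one column of zeros. With a 5x5 matrix, which the code pads to 6x6 with zeros, I get this error: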
```
Traceback (most recent call last):
  File "/home/surfacepro/Downloads/strassen.py", line 82, in <module>
    main()
  File "/home/surfacepro/Downloads/strassen.py", line 77, in main
    print(strassen(matrixA, matrixB))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 33, in strassen
    p1 = strassen(a, f - h)
         ^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 35, in strassen
    p3 = strassen(c + d, e)
         ^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 49, in strassen
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/numpy/core/shape_base.py", line 289, in vstack
    return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 1 and the array at index 1 has size 0
```
matrix_A.txt: https://pastebin.com/9vMKEjJd
matrix_B.txt: https://pastebin.com/F7Xs4Ciz
```python
# Version 3.6
import numpy as np
import re, math


def split(matrix):
    """
    Splits a given matrix into quarters.
    Input: n x n matrix
    Output: tuple containing 4 n/2 x n/2 matrices corresponding to a, b, c, d
    """
    row, col = matrix.shape
    row2, col2 = row//2, col//2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]


def strassen(x, y):
    """
    Computes matrix product by divide and conquer approach, recursively.
    Input: n x n matrices x and y
    Output: n x n matrix, product of x and y
    """
    # Base case when size of matrices is 1x1
    if len(x) == 1:
        return x * y

    # Splitting the matrices into quadrants. This will be done recursively
    # until the base case is reached.
    a, b, c, d = split(x)
    e, f, g, h = split(y)

    # Computing the 7 products, recursively (p1, p2...p7)
    p1 = strassen(a, f - h)
    p2 = strassen(a + b, h)
    p3 = strassen(c + d, e)
    p4 = strassen(d, g - e)
    p5 = strassen(a + d, e + h)
    p6 = strassen(b - d, g + h)
    p7 = strassen(a - c, e + f)

    # Computing the values of the 4 quadrants of the final matrix c
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7

    # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))
    return c


def main():
    with open('matrix_A.txt', 'r') as f:
        flatMatrix = f.read()
    with open('matrix_B.txt', 'r') as f:
        flatMatrix2 = f.read()
    # Raw string so that \d is not treated as a string escape
    numbers = re.compile(r"-?\d+")
    result = list(map(int, numbers.findall(flatMatrix)))
    result2 = list(map(int, numbers.findall(flatMatrix2)))
    matrix_dimension = int(math.sqrt(len(result)))
    matrixA = np.array(result).reshape(matrix_dimension, matrix_dimension)
    matrixB = np.array(result2).reshape(matrix_dimension, matrix_dimension)
    if matrix_dimension % 2 != 0:
        matrix_dimension = matrix_dimension + 1
        matrixA = np.pad(matrixA, [(0, 1), (0, 1)], mode='constant', constant_values=0)
        matrixB = np.pad(matrixB, [(0, 1), (0, 1)], mode='constant', constant_values=0)
    print(matrixA)
    print(matrixB)
    print(strassen(matrixA, matrixB))
    print("The sum is: ", strassen(matrixA, matrixB).sum())


if __name__ == '__main__':
    main()
```
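One way to see where the shapes diverge is to print the quadrant shapes that `split` produces at each size. The sketch below re-declares the question's `split` so it runs on its own; `quadrant_shapes` is a hypothetical helper added only for illustration.

```python
import numpy as np


def split(matrix):
    # Same quartering rule as the question's split(): cut at row//2 and col//2.
    row, col = matrix.shape
    row2, col2 = row // 2, col // 2
    return (matrix[:row2, :col2], matrix[:row2, col2:],
            matrix[row2:, :col2], matrix[row2:, col2:])


def quadrant_shapes(n):
    """Hypothetical helper: shapes of the four quadrants split() gives for n x n."""
    return [q.shape for q in split(np.zeros((n, n)))]


# A 6x6 matrix (5x5 padded by one) splits evenly once...
print(quadrant_shapes(6))  # [(3, 3), (3, 3), (3, 3), (3, 3)]
# ...but each 3x3 quadrant then splits unevenly.
print(quadrant_shapes(3))  # [(1, 1), (1, 2), (2, 1), (2, 2)]
# A power of two splits evenly at every level.
print(quadrant_shapes(8))  # [(4, 4), (4, 4), (4, 4), (4, 4)]

# The f - h subtraction from p1 on a 3x3 input: (1, 2) minus (2, 2)
# broadcasts to (2, 2) instead of raising an error.
_, f, _, h = split(np.ones((3, 3)))
print((f - h).shape)  # (2, 2)
```

Because NumPy broadcasting silently accepts several of these mismatched operands, the problem only surfaces later, when `np.vstack` tries to stack quadrants whose widths differ.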
1 Answer
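The fix is to find the next power of 2 and pad the matrices up to that size. I believe this gives you the answer you want. Padding 5x5 to 6x6 is not enough: after the first split the quadrants are 3x3, and `split` cuts an odd size at `row//2`, which produces quadrants of unequal size (1 and 2 rows or columns), so the recursive products no longer line up when they are stacked. Padding to a power of 2 (8x8 for a 5x5 input) keeps every level of the recursion evenly divisible.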
With your test data, this produces 185.
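The answer's own code is not reproduced here. The following is a minimal sketch of the same idea, assuming square n x n inputs: pad both matrices with zeros to the next power of two, run the question's Strassen recursion, then crop the result back to n x n. The names `next_power_of_two` and `strassen_padded` are hypothetical helpers, not part of NumPy or the original post.

```python
import numpy as np


def next_power_of_two(n):
    """Hypothetical helper: smallest power of two >= n (assumes n >= 1)."""
    return 1 << (n - 1).bit_length()


def split(m):
    # Quarter a square matrix whose size is even.
    k = m.shape[0] // 2
    return m[:k, :k], m[:k, k:], m[k:, :k], m[k:, k:]


def strassen(x, y):
    """Strassen product of two n x n matrices, n a power of two (question's formulas)."""
    if x.shape[0] == 1:
        return x * y
    a, b, c, d = split(x)
    e, f, g, h = split(y)
    p1 = strassen(a, f - h)
    p2 = strassen(a + b, h)
    p3 = strassen(c + d, e)
    p4 = strassen(d, g - e)
    p5 = strassen(a + d, e + h)
    p6 = strassen(b - d, g + h)
    p7 = strassen(a - c, e + f)
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7
    return np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))


def strassen_padded(x, y):
    """Hypothetical wrapper: zero-pad to a power of two, multiply, crop back."""
    n = x.shape[0]
    size = next_power_of_two(n)
    pad = [(0, size - n), (0, size - n)]
    xp = np.pad(x, pad, mode='constant', constant_values=0)
    yp = np.pad(y, pad, mode='constant', constant_values=0)
    return strassen(xp, yp)[:n, :n]


# Usage: compare against NumPy's own product on random 5x5 integer matrices.
rng = np.random.default_rng(0)
A = rng.integers(-9, 10, size=(5, 5))
B = rng.integers(-9, 10, size=(5, 5))
print(np.array_equal(strassen_padded(A, B), A @ B))  # True
```

Cropping only drops the all-zero rows and columns introduced by the padding, so it does not change `.sum()`; comparing against `A @ B` is a cheap way to check the recursion before relying on a specific total.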