numpy: Python Strassen algorithm fails when concatenating blocks

Posted by ryhaxcpt 12 months ago in Python

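I based my code on the Strassen example at https://www.geeksforgeeks.org/strassens-matrix-multiplication/. If a matrix has an odd dimension, I pad it with one extra row and one extra column of zeros. For a 5x5 matrix, which the code pads to 6x6 with zeros, I get this error: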

Traceback (most recent call last):
  File "/home/surfacepro/Downloads/strassen.py", line 82, in <module>
    main()
  File "/home/surfacepro/Downloads/strassen.py", line 77, in main
    print(strassen(matrixA, matrixB))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 33, in strassen
    p1 = strassen(a, f - h)
         ^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 35, in strassen
    p3 = strassen(c + d, e)
         ^^^^^^^^^^^^^^^^^^
  File "/home/surfacepro/Downloads/strassen.py", line 49, in strassen
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/numpy/core/shape_base.py", line 289, in vstack
    return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 1 and the array at index 1 has size 0

matrix_A.txt:https://pastebin.com/9vMKEjJd
matrix_B.txt:https://pastebin.com/F7Xs4Ciz

# Version 3.6

import numpy as np
import re, math

def split(matrix):
    """
    Splits a given matrix into quarters.
    Input: n x n matrix
    Output: tuple containing 4 n/2 x n/2 matrices corresponding to a, b, c, d
    """
    row, col = matrix.shape
    row2, col2 = row//2, col//2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]

def strassen(x, y):
    """
    Computes matrix product by divide and conquer approach, recursively.
    Input: nxn matrices x and y
    Output: nxn matrix, product of x and y
    """

    # Base case when size of matrices is 1x1
    if len(x) == 1:
        return x * y

    # Splitting the matrices into quadrants. This will be done recursively
    # until the base case is reached.
    a, b, c, d = split(x)
    e, f, g, h = split(y)

    # Computing the 7 products, recursively (p1, p2...p7)
    p1 = strassen(a, f - h)
    p2 = strassen(a + b, h) 
    p3 = strassen(c + d, e) 
    p4 = strassen(d, g - e) 
    p5 = strassen(a + d, e + h) 
    p6 = strassen(b - d, g + h)
    p7 = strassen(a - c, e + f)

    # Computing the values of the 4 quadrants of the final matrix c
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2       
    c21 = p3 + p4       
    c22 = p1 + p5 - p3 - p7

    # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))

    return c

def main():
    with open('matrix_A.txt', 'r') as f:
        flatMatrix = f.read()

    with open('matrix_B.txt', 'r') as f:
        flatMatrix2 = f.read()

    numbers = re.compile(r"-?\d+")  # raw string avoids the invalid "\d" escape warning
    result = list(map(int, numbers.findall(flatMatrix)))
    result2 = list(map(int, numbers.findall(flatMatrix2)))

    matrix_dimension = int(math.sqrt(len(result)))

    matrixA = np.array(result).reshape(matrix_dimension, matrix_dimension)
    matrixB = np.array(result2).reshape(matrix_dimension, matrix_dimension)

    if matrix_dimension % 2 != 0:
        matrix_dimension = matrix_dimension + 1
        matrixA = np.pad(matrixA, [(0, 1), (0, 1)], mode='constant', constant_values=0)
        matrixB = np.pad(matrixB, [(0, 1), (0, 1)], mode='constant', constant_values=0)

    print(matrixA)
    print(matrixB)
    
    product = strassen(matrixA, matrixB)  # compute once instead of twice
    print(product)
    print("The sum is: ", product.sum())

if __name__ == '__main__':
    main()
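To see why padding 5x5 to 6x6 is not enough, the sketch below reuses the question's `split` slicing and traces the quadrant sizes. `split_levels` is a hypothetical helper introduced only for illustration. A 6x6 matrix halves to 3x3, and 3 cannot be halved evenly, so the pieces come out as 1 and 2; a 1x1 operand that reaches `split` yields empty pieces, which is consistent with the size-0 array in the error message.

```python
import numpy as np

def split(matrix):
    """Same quadrant slicing as the question's split()."""
    row, col = matrix.shape
    row2, col2 = row // 2, col // 2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]

def split_levels(n):
    """Hypothetical helper: (top, bottom) half sizes at each recursion level,
    following the larger half, until the size reaches 1."""
    levels = []
    while n > 1:
        top, bottom = n // 2, n - n // 2
        levels.append((top, bottom))
        n = bottom
    return levels

print(split_levels(6))  # [(3, 3), (1, 2), (1, 1)]  uneven at the second level
print(split_levels(8))  # [(4, 4), (2, 2), (1, 1)]  even at every level

# A 3x3 block splits into unequal quadrants, and a 1x1 block into empty ones
print([q.shape for q in split(np.zeros((3, 3)))])  # [(1, 1), (1, 2), (2, 1), (2, 2)]
print([q.shape for q in split(np.zeros((1, 1)))])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Because NumPy broadcasts `a + b` for shapes like (1, 1) and (1, 2) instead of raising, the mismatch only surfaces later, when `np.hstack`/`np.vstack` try to join the result blocks.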
Answer #1 (lqfhib0f)

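This code finds the smallest power of 2 that is at least the matrix dimension and zero-pads both matrices up to that size; put it in main() after matrixA and matrixB are built. Padding with a single row and column is not enough because Strassen's method halves the matrix at every level of recursion: a 6x6 matrix splits into 3x3 quadrants, and a 3x3 block splits into pieces of size 1 and 2, so blocks of mismatched (even empty) shape reach np.hstack/np.vstack. With a power-of-2 size, every split is even. I believe this gives you the answer you want.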

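# Find the smallest power of 2 that is >= matrix_dimension
# ("<" rather than "<=", so a size that is already a power of 2 is not doubled)
pow2 = 1
while pow2 < matrix_dimension:
    pow2 <<= 1

# Zero-pad both matrices on the bottom and right up to pow2 x pow2
if matrix_dimension != pow2:
    padmd = pow2 - matrix_dimension
    matrixA = np.pad(matrixA, ((0, padmd), (0, padmd)), 'constant')
    matrixB = np.pad(matrixB, ((0, padmd), (0, padmd)), 'constant')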
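The power-of-2 search above can also be written without a loop using Python's `int.bit_length()`. This is a sketch, not the answer's code: `next_pow2` and `pad_to_pow2` are hypothetical helper names, and the inputs are assumed to be square.

```python
import numpy as np

def next_pow2(n):
    """Smallest power of 2 that is >= n, for n >= 1 (returns n if it already is one)."""
    return 1 << (n - 1).bit_length()

def pad_to_pow2(matrix):
    """Zero-pad a square matrix on the bottom and right up to next_pow2(n) x next_pow2(n)."""
    n = matrix.shape[0]
    extra = next_pow2(n) - n
    return np.pad(matrix, ((0, extra), (0, extra)), mode="constant")

print([next_pow2(n) for n in (1, 2, 3, 4, 5, 8, 9)])  # [1, 2, 4, 4, 8, 8, 16]
print(pad_to_pow2(np.ones((5, 5), dtype=int)).shape)  # (8, 8)
```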

With your test data, this produces 185.
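To check the padded approach without the pastebin files, here is a self-contained sketch: it copies the question's `split` and `strassen` unchanged, pads to the next power of 2, crops the product back to n x n, and compares it with NumPy's `@`. `strassen_any_size` is a hypothetical name, and the random integer matrices merely stand in for matrix_A.txt and matrix_B.txt.

```python
import numpy as np

def split(matrix):
    """Split a matrix into four quadrants (same slicing as the question)."""
    row, col = matrix.shape
    row2, col2 = row // 2, col // 2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]

def strassen(x, y):
    """Question's Strassen recursion; only correct for square power-of-2 sizes."""
    if len(x) == 1:
        return x * y
    a, b, c, d = split(x)
    e, f, g, h = split(y)
    p1 = strassen(a, f - h)
    p2 = strassen(a + b, h)
    p3 = strassen(c + d, e)
    p4 = strassen(d, g - e)
    p5 = strassen(a + d, e + h)
    p6 = strassen(b - d, g + h)
    p7 = strassen(a - c, e + f)
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7
    return np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))

def strassen_any_size(x, y):
    """Hypothetical wrapper: pad n x n inputs to a power of 2, multiply, crop to n x n."""
    n = x.shape[0]
    size = 1 << (n - 1).bit_length()
    pad = ((0, size - n), (0, size - n))
    product = strassen(np.pad(x, pad, mode="constant"), np.pad(y, pad, mode="constant"))
    return product[:n, :n]

rng = np.random.default_rng(0)
A = rng.integers(-9, 10, size=(5, 5))
B = rng.integers(-9, 10, size=(5, 5))
print(np.array_equal(strassen_any_size(A, B), A @ B))  # True
```

Because the padding is all zeros, the extra rows and columns of the padded product are zero as well, so cropping to n x n does not change `.sum()`; printing the cropped product just hides the padding.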
