# Strassen matrix multiplication for square matrices stored as lists of rows.
def strassen(A, B):
    """Multiply two square matrices using Strassen's seven-product recursion.

    Each matrix is split into four quadrants, seven quadrant products are
    computed recursively, and ``conquer`` assembles them into the result.

    Args:
        A: Left operand, a square matrix given as a list of rows.
        B: Right operand, expected to have the same dimensions as ``A``.

    Returns:
        The product ``A * B`` as a new list of rows.

    Note:
        ``threshold`` is read as a module-level name that is not defined in
        this function; it must be defined elsewhere in the module. Inputs of
        size ``threshold`` or smaller go straight to ``matrixmult``.
    """
    n = len(A)
    # Use the schoolbook product for small inputs and for odd sizes:
    # divide() halves with floor division, so an odd n would silently lose
    # the last row and column. The n < 2 guard also stops the recursion
    # from descending into empty quadrants when threshold is below 1.
    if n <= threshold or n < 2 or n % 2:
        return matrixmult(A, B)

    A11, A12, A21, A22 = divide(A)
    B11, B12, B21, B22 = divide(B)

    # Strassen's seven products. M4, M6 and M7 each involve a *difference*
    # of quadrants; using sums there yields an incorrect product.
    M1 = strassen(madd(A11, A22), madd(B11, B22))
    M2 = strassen(madd(A21, A22), B11)
    M3 = strassen(A11, msub(B12, B22))
    M4 = strassen(A22, msub(B21, B11))
    M5 = strassen(madd(A11, A12), B22)
    M6 = strassen(msub(A21, A11), madd(B11, B12))
    M7 = strassen(msub(A12, A22), madd(B21, B22))

    return conquer(M1, M2, M3, M4, M5, M6, M7)
def divide(A):
    """Split a square matrix into its four half-size quadrants.

    Args:
        A: Square matrix given as a list of rows.

    Returns:
        A tuple ``(A11, A12, A21, A22)`` of new ``m x m`` matrices, where
        ``m = len(A) // 2``: top-left, top-right, bottom-left, bottom-right.
        When ``len(A)`` is odd, the last row and column are not included in
        any quadrant.
    """
    m = len(A) // 2
    # Bound the second half at 2 * m so an odd-sized input yields m x m
    # quadrants rather than ragged ones.
    upper = A[:m]
    lower = A[m:2 * m]
    A11 = [row[:m] for row in upper]
    A12 = [row[m:2 * m] for row in upper]
    A21 = [row[:m] for row in lower]
    A22 = [row[m:2 * m] for row in lower]
    return A11, A12, A21, A22


def conquer(M1, M2, M3, M4, M5, M6, M7):
    """Assemble the product matrix from Strassen's seven quadrant products.

    The result quadrants are::

        C11 = M1 + M4 - M5 + M7        C12 = M3 + M5
        C21 = M2 + M4                  C22 = M1 + M3 - M2 + M6

    Args:
        M1, M2, M3, M4, M5, M6, M7: The seven intermediate products, all of
            the same ``m x m`` size, each given as a list of rows.

    Returns:
        A new ``2m x 2m`` matrix (list of rows) with C11/C12 in the top half
        and C21/C22 in the bottom half.
    """
    top_half = []
    bottom_half = []
    # Walk the seven products row by row and build each output row directly,
    # avoiding intermediate matrices for the partial sums.
    for r1, r2, r3, r4, r5, r6, r7 in zip(M1, M2, M3, M4, M5, M6, M7):
        c11 = [p1 + p4 - p5 + p7 for p1, p4, p5, p7 in zip(r1, r4, r5, r7)]
        c12 = [p3 + p5 for p3, p5 in zip(r3, r5)]
        c21 = [p2 + p4 for p2, p4 in zip(r2, r4)]
        c22 = [p1 + p3 - p2 + p6 for p1, p3, p2, p6 in zip(r1, r3, r2, r6)]
        top_half.append(c11 + c12)
        bottom_half.append(c21 + c22)
    return top_half + bottom_half
def madd(A, B):
    """Return the element-wise sum of two square matrices.

    Both operands are indexed as ``n x n`` with ``n = len(A)``; the inputs
    are left unmodified and a new list of rows is returned.
    """
    size = len(A)
    return [
        [A[row][col] + B[row][col] for col in range(size)]
        for row in range(size)
    ]
def msub(A, B):
    """Return the element-wise difference ``A - B`` of two square matrices.

    Both operands are indexed as ``n x n`` with ``n = len(A)``; the inputs
    are left unmodified and a new list of rows is returned.
    """
    size = len(A)
    return [
        [A[row][col] - B[row][col] for col in range(size)]
        for row in range(size)
    ]
def matrixmult(A, B):
    """Return the product of two square matrices via the schoolbook method.

    Both operands are indexed as ``n x n`` with ``n = len(A)``. Every output
    entry is the running sum, starting from 0, of ``A[i][k] * B[k][j]`` over
    ``k`` in increasing order.
    """
    size = len(A)
    product = []
    for row in range(size):
        out_row = []
        for col in range(size):
            # Accumulate explicitly so the summation order is fixed.
            acc = 0
            for k in range(size):
                acc += A[row][k] * B[k][col]
            out_row.append(acc)
        product.append(out_row)
    return product
|