슈트라센 행렬 곱셈

  • 단순 참고용으로만 공부함 (행렬 계산은 어려워…)
    def strassen(A, B, threshold=2):
        """Multiply two square matrices with Strassen's algorithm.

        Both operands are lists of row lists; the size ``n`` is taken from
        ``len(A)`` and both matrices are assumed to be ``n x n`` (not checked).
        Each level splits the operands into quadrants with ``divide``, forms
        the seven Strassen products M1..M7 recursively, and recombines them
        with ``conquer``.

        Args:
            A: left operand, an ``n x n`` matrix.
            B: right operand, an ``n x n`` matrix.
            threshold: sizes ``n <= threshold`` are multiplied directly with
                ``matrixmult`` instead of recursing. Passed unchanged to every
                recursive call.

        Returns:
            A new ``n x n`` matrix equal to the product ``A * B``.
        """
        n = len(A)
        # ``divide`` can only split an even size (for odd n it drops the last
        # row/column), so odd sizes also take the direct path. ``max(..., 1)``
        # stops n == 0 from recursing forever when threshold < 1.
        if n <= max(threshold, 1) or n % 2 == 1:
            return matrixmult(A, B)
        A11, A12, A21, A22 = divide(A)
        B11, B12, B21, B22 = divide(B)
        # Strassen's seven products. M4, M6 and M7 use *differences*.
        M1 = strassen(madd(A11, A22), madd(B11, B22), threshold)
        M2 = strassen(madd(A21, A22), B11, threshold)
        M3 = strassen(A11, msub(B12, B22), threshold)
        M4 = strassen(A22, msub(B21, B11), threshold)
        M5 = strassen(madd(A11, A12), B22, threshold)
        M6 = strassen(msub(A21, A11), madd(B11, B12), threshold)
        M7 = strassen(msub(A12, A22), madd(B21, B22), threshold)
        return conquer(M1, M2, M3, M4, M5, M6, M7)

    def divide(A):
        """Split a square matrix into four equally sized quadrants.

        Args:
            A: a square matrix given as a list of row lists.

        Returns:
            The tuple ``(A11, A12, A21, A22)`` -- top-left, top-right,
            bottom-left and bottom-right -- each a freshly built
            ``m x m`` list of lists with ``m = len(A) // 2``. When ``len(A)``
            is odd, the last row and column are not copied into any quadrant.
        """
        half = len(A) // 2

        def quadrant(row_offset, col_offset):
            # Copy the half x half window whose top-left corner sits at
            # (row_offset, col_offset).
            return [
                [A[row_offset + r][col_offset + c] for c in range(half)]
                for r in range(half)
            ]

        return (
            quadrant(0, 0),
            quadrant(0, half),
            quadrant(half, 0),
            quadrant(half, half),
        )

    def conquer(M1, M2, M3, M4, M5, M6, M7):
        """Assemble the product matrix from Strassen's seven sub-products.

        The four result quadrants are

            C11 = M1 + M4 - M5 + M7
            C12 = M3 + M5
            C21 = M2 + M4
            C22 = M1 - M2 + M3 + M6

        and they are stitched into one matrix of twice the size.

        Args:
            M1..M7: equally sized square matrices (lists of row lists).

        Returns:
            A new ``2m x 2m`` matrix, where ``m = len(M1)``.
        """
        # Each quadrant is computed in one element-wise pass (same
        # left-to-right order as chaining madd/msub) without building
        # intermediate matrices.
        C11 = [[a + d - e + g for a, d, e, g in zip(r1, r4, r5, r7)]
               for r1, r4, r5, r7 in zip(M1, M4, M5, M7)]
        C12 = [[c + e for c, e in zip(r3, r5)] for r3, r5 in zip(M3, M5)]
        C21 = [[b + d for b, d in zip(r2, r4)] for r2, r4 in zip(M2, M4)]
        C22 = [[a + c - b + f for a, c, b, f in zip(r1, r3, r2, r6)]
               for r1, r3, r2, r6 in zip(M1, M3, M2, M6)]
        # Top half rows are C11|C12, bottom half rows are C21|C22.
        top = [left + right for left, right in zip(C11, C12)]
        bottom = [left + right for left, right in zip(C21, C22)]
        return top + bottom

    def madd(A, B):
        """Return the element-wise sum ``A + B`` as a new square matrix.

        The size is taken from ``len(A)``; both operands are indexed over
        ``len(A) x len(A)`` cells.
        """
        size = len(A)
        return [[A[r][c] + B[r][c] for c in range(size)] for r in range(size)]

    def msub(A, B):
        """Return the element-wise difference ``A - B`` as a new square matrix.

        The size is taken from ``len(A)``; both operands are indexed over
        ``len(A) x len(A)`` cells.
        """
        size = len(A)
        return [[A[r][c] - B[r][c] for c in range(size)] for r in range(size)]

    def matrixmult(A, B):
        """Multiply two square matrices with the schoolbook triple loop.

        The size ``n`` is taken from ``len(A)``. Each cell ``(i, j)`` is the
        dot product of row ``i`` of ``A`` with column ``j`` of ``B``,
        accumulated from 0. This is the base case used by ``strassen``.

        Returns:
            A new ``n x n`` list-of-lists matrix.
        """
        n = len(A)
        product = []
        for i in range(n):
            row = A[i]
            product.append(
                [sum(row[k] * B[k][j] for k in range(n)) for j in range(n)]
            )
        return product
    자바스크립트에서 행렬 계산 (참고용)
  • map과 reduce를 사용해 배열을 순회하며 행렬을 계산
  • 덧셈은 첫 번째 행렬을 map으로 돌면서 콜백의 두 번째 인수로 받은 인덱스로 다른 행렬의 같은 위치 값을 읽어, 이중 for문처럼 사용

    /**
     * Element-wise sum of two matrices given as arrays of row arrays.
     * The result has arr1's shape; each cell reads the same [row][col]
     * position from arr2.
     */
    function sumMatrix (arr1, arr2) {
      // For row r of arr1, add the matching cells of arr2's row r.
      const addRow = (row, r) => row.map((value, c) => value + arr2[r][c]);
      return arr1.map(addRow);
    }

    /**
     * Matrix product arr1 (m x n) * arr2 (n x p) -> m x p.
     * Matrices are arrays of row arrays.
     */
    function multiplyMatrix(arr1, arr2) {
      // The result column count is arr2's column count. Iterating over a
      // row of arr1 instead would only be correct for square inputs.
      const cols = arr2.length > 0 ? arr2[0].length : 0;
      return arr1.map(row =>
        Array.from({ length: cols }, (_, i) =>
          // Dot product of this row with column i of arr2; reduce's
          // callback receives (accumulator, currentValue, currentIndex).
          row.reduce((sum, cell, j) => sum + cell * arr2[j][i], 0)));
    }