2-4. Strassen의 행렬 곱셈

  • 행렬 곱셈의 정의에 충실하게 구현할 경우, 시간 복잡도 Θ(𝑛 ^ 3 )
A = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 1, 2, 3],
    [4, 5, 6, 7],
]

B = [
    [8, 9, 1, 2],
    [3, 4, 5, 6],
    [7, 8, 9, 1],
    [2, 3, 4, 5],
]

def matrixmult(A, B):
    N = len(A)
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i][j] += A[i][k] * B[k][j]
    return C

result = matrixmult(A, B)

for line in result:
    print(*line)
  • 쉬트라센의 방법을 사용할 경우, 시간 복잡도 Θ(𝑛 ^ 2.81)

A = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 1, 2, 3],
    [4, 5, 6, 7],
]

B = [
    [8, 9, 1, 2],
    [3, 4, 5, 6],
    [7, 8, 9, 1],
    [2, 3, 4, 5],
]

def matrixmult (A, B):
    N = len(A)
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i][j] += A[i][k] * B[k][j]
    return C

def divide(A):
    N = len(A)
    M = N // 2
    A11 = [[0] * M for _ in range(M)]
    A12 = [[0] * M for _ in range(M)]
    A21 = [[0] * M for _ in range(M)]
    A22 = [[0] * M for _ in range(M)]
    for i in range(M):
        for j in range(M):
            A11[i][j] = A[i][j]
            A12[i][j] = A[i][M + j]
            A21[i][j] = A[M + i][j]
            A22[i][j] = A[M + i][M + j]
    return A11, A12, A21, A22

def madd(A, B):
    N = len(A)
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            C[i][j] = A[i][j] + B[i][j]
    return C

def msub(A, B):
    N = len(A)
    C = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            C[i][j] = A[i][j] - B[i][j]
    return C

def conquer(M1, M2, M3, M4, M5, M6, M7):
    C11 = madd(msub(madd(M1, M4), M5), M7)
    C12 = madd(M3, M5)
    C21 = madd(M2, M4)
    C22 = madd(msub(madd(M1, M3), M2), M6)
    M = len(C11)
    N = M * 2
    C = [[0] * N for _ in range(N)]
    for i in range(M):
        for j in range(M):
            C[i][j] = C11[i][j]
            C[i][M + j] = C12[i][j]
            C[M + i][j] = C21[i][j]
            C[M + i][M + j] = C22[i][j]
    return C

def strassen(A, B):
    N = len(A)
    if N <= threshold:
        return matrixmult(A, B)
    A11, A12, A21, A22 = divide(A)
    B11, B12, B21, B22 = divide(B)
    M1 = strassen(madd(A11, A22), madd(B11, B22))
    M2 = strassen(madd(A21, A22), B11)
    M3 = strassen(A11, msub(B12, B22))
    M4 = strassen(A22, msub(B21, B11))
    M5 = strassen(madd(A11, A12), B22)
    M6 = strassen(msub(A21, A11), madd(B11, B12))
    M7 = strassen(msub(A12, A22), madd(B21, B22))
    return conquer(M1, M2, M3, M4, M5, M6, M7)

threshold = 2
result = strassen(A, B)

for line in result:
    print(*line)

2-5. 큰 정수의 산술

  • 특정 컴퓨터/언어가 표현할 수 없는 큰 정수의 산술 연산
  • 10진수를 소프트웨어적으로 표현하는 방법: 리스트를 이용하여 각 자리수를 하나의 원소로 저장
    ex) 567,832 > lst = [2, 3, 8, 7, 6, 5]

1) 덧셈: n개의 자리수를 각각 더하면서 올림수를 고려

def largeadd(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    result = []
    carry = 0
    for k in range(N):
        i = A[k] if k < len(A) else 0
        j = B[k] if k < len(B) else 0
        value = i + j + carry
        carry = value // 10
        result.append(value % 10)
    if carry > 0:
        result.append(carry)
    return result

A = [3, 2, 1]
B = [5, 4]

result = largeadd(A, B)
print(result[::-1])

2) 곱셈 --- 분할정복으로 구현

  • n개의 자리수로 된 숫자를 n/2개의 자리수로 분할(n이 홀수일 경우 분할된 값의 자리수가 다를 수 있음
    ex) 567,832 = 567 * 10 ^ 3 + 832
          9,423,723 = 9,423 * 10 ^ 3 + 723 --- 10의 지수 m은 n/2
  • 두 개의 정수를 분할하여 곱함
      A = (x * 10 ^ m + y)(w * 10 ^ m + z) = xw * 10 ^ 2m + (zx + yw) * 10 ^ m + yz
  • 재귀호출을 4번해서 시간 복잡도가 전통 방식과 같음(n ^ 2)
def largeadd(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    result = []
    carry = 0
    for k in range(N):
        i = A[k] if k < len(A) else 0
        j = B[k] if k < len(B) else 0
        value = i + j + carry
        carry = value // 10
        result.append(value % 10)
    if carry > 0:
        result.append(carry)
    return result

def largemult(A, B): # 전통적인 방법으로 곱셈
    i = A[0] if 0 < len(A) else 0
    j = B[0] if 0 < len(B) else 0
    value = i * j
    carry = value // 10
    result = []
    result.append(value % 10)
    if carry > 0:
        result.append(carry)
    return result

def exp(A, M): # A * (10 ** M)
    if A == 0:
        return [0]
    else:
        return M * [0] + A

def div(A, M): # A // (10 ** M)
    if len(A) < M:
        A.append(0)
    return A[M:]

def rem(A, M): # A % (10 ** M)
    if len(A) < M:
        A.append(0)
    return A[:M]

def prod(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    if len(A) == 0 or len(B) == 0:
        return [0]
    elif N <= threshold:
        return largemult(A, B)
    else:
        M = N // 2
        X = div(A, M)
        Y = rem(A, M)
        W = div(B, M)
        Z = rem(B, M)
        p1 = prod(X, W)
        p2 = largeadd(prod(X, Z), prod(W, Y))
        p3 = prod(Y, Z)
        return largeadd(largeadd(exp(p1, 2*M), exp(p2, M)), p3)

# 임계값, 특정 자리수까지
threshold = 1

A = [4, 3, 2, 1]
B = [9, 8, 7, 6, 5]

print(exp(A, 3))
print(div(A, 3))
print(rem(A, 3))
print(prod(A, B)[::-1])

3) 곱셈 --- 효율성 개선

  • 재귀호출의 횟수를 3번으로 줄임
    A = xw * 10 ^ 2m + (zx + yw) * 10 ^ m + yz
    R = (x + y)(w + z) = xw + (zx + yw) + yz
    (zx + yw) = R - xw - yz
def largeadd(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    result = []
    carry = 0
    for k in range(N):
        i = A[k] if k < len(A) else 0
        j = B[k] if k < len(B) else 0
        value = i + j + carry
        carry = value // 10
        result.append(value % 10)
    if carry > 0:
        result.append(carry)
    return result

def largesub(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    result = []
    borrow = 0
    for k in range(N):
        i = A[k] if k < len(A) else 0
        j = B[k] if k < len(B) else 0
        value = i - j + borrow
        if value < 0:
            value += 10
            borrow = -1
        else:
            borrow = 0
        result.append(value % 10)
    if borrow < 0:
        print('음의 정수는 처리할 수 없음')
    return result

def largemult(A, B): # 전통적인 방법으로 곱셈
    i = A[0] if 0 < len(A) else 0
    j = B[0] if 0 < len(B) else 0
    value = i * j
    carry = value // 10
    result = []
    result.append(value % 10)
    if carry > 0:
        result.append(carry)
    return result

def exp(A, M): # A * (10 ** M)
    if A == 0:
        return [0]
    else:
        return M * [0] + A

def div(A, M): # A // (10 ** M)
    if len(A) < M:
        A.append(0)
    return A[M:]

def rem(A, M): # A % (10 ** M)
    if len(A) < M:
        A.append(0)
    return A[:M]

def prod(A, B):
    N = len(A) if len(A) > len(B) else len(B)
    if len(A) == 0 or len(B) == 0:
        return [0]
    elif N <= threshold:
        return largemult(A, B)
    else:
        M = N // 2
        X = div(A, M)
        Y = rem(A, M)
        W = div(B, M)
        Z = rem(B, M)
        R = prod(largeadd(X, Y), largeadd(W, Z))
        p1 = prod(X, W)
        p3 = prod(Y, Z)
        p2 = largesub(R, largeadd(p1, p3))
        return largeadd(largeadd(exp(p1, 2*M), exp(p2, M)), p3)

# 임계값, 특정 자리수까지
threshold = 1

A = [4, 3, 2, 1]
B = [9, 8, 7, 6, 5]

print(prod(A, B)[::-1])

 

 

 

※ 강의 링크: https://inf.run/v8Rn

+ Recent posts