2-4. Strassen의 행렬 곱셈
- 행렬 곱셈의 정의에 충실하게 구현할 경우, 시간 복잡도 Θ(𝑛 ^ 3 )
A = [
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 1, 2, 3],
[4, 5, 6, 7],
]
B = [
[8, 9, 1, 2],
[3, 4, 5, 6],
[7, 8, 9, 1],
[2, 3, 4, 5],
]
def matrixmult(A, B):
N = len(A)
C = [[0] * N for _ in range(N)]
for i in range(N):
for j in range(N):
for k in range(N):
C[i][j] += A[i][k] * B[k][j]
return C
result = matrixmult(A, B)
for line in result:
print(*line)
- 쉬트라센의 방법을 사용할 경우, 시간 복잡도 Θ(𝑛 ^ 2.81)

A = [
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 1, 2, 3],
[4, 5, 6, 7],
]
B = [
[8, 9, 1, 2],
[3, 4, 5, 6],
[7, 8, 9, 1],
[2, 3, 4, 5],
]
def matrixmult (A, B):
N = len(A)
C = [[0] * N for _ in range(N)]
for i in range(N):
for j in range(N):
for k in range(N):
C[i][j] += A[i][k] * B[k][j]
return C
def divide(A):
N = len(A)
M = N // 2
A11 = [[0] * M for _ in range(M)]
A12 = [[0] * M for _ in range(M)]
A21 = [[0] * M for _ in range(M)]
A22 = [[0] * M for _ in range(M)]
for i in range(M):
for j in range(M):
A11[i][j] = A[i][j]
A12[i][j] = A[i][M + j]
A21[i][j] = A[M + i][j]
A22[i][j] = A[M + i][M + j]
return A11, A12, A21, A22
def madd(A, B):
N = len(A)
C = [[0] * N for _ in range(N)]
for i in range(N):
for j in range(N):
C[i][j] = A[i][j] + B[i][j]
return C
def msub(A, B):
N = len(A)
C = [[0] * N for _ in range(N)]
for i in range(N):
for j in range(N):
C[i][j] = A[i][j] - B[i][j]
return C
def conquer(M1, M2, M3, M4, M5, M6, M7):
C11 = madd(msub(madd(M1, M4), M5), M7)
C12 = madd(M3, M5)
C21 = madd(M2, M4)
C22 = madd(msub(madd(M1, M3), M2), M6)
M = len(C11)
N = M * 2
C = [[0] * N for _ in range(N)]
for i in range(M):
for j in range(M):
C[i][j] = C11[i][j]
C[i][M + j] = C12[i][j]
C[M + i][j] = C21[i][j]
C[M + i][M + j] = C22[i][j]
return C
def strassen(A, B):
N = len(A)
if N <= threshold:
return matrixmult(A, B)
A11, A12, A21, A22 = divide(A)
B11, B12, B21, B22 = divide(B)
M1 = strassen(madd(A11, A22), madd(B11, B22))
M2 = strassen(madd(A21, A22), B11)
M3 = strassen(A11, msub(B12, B22))
M4 = strassen(A22, msub(B21, B11))
M5 = strassen(madd(A11, A12), B22)
M6 = strassen(msub(A21, A11), madd(B11, B12))
M7 = strassen(msub(A12, A22), madd(B21, B22))
return conquer(M1, M2, M3, M4, M5, M6, M7)
threshold = 2
result = strassen(A, B)
for line in result:
print(*line)
2-5. 큰 정수의 산술
- 특정 컴퓨터/언어가 표현할 수 없는 큰 정수의 산술 연산
- 10진수를 소프트웨어적으로 표현하는 방법: 리스트를 이용하여 각 자리수를 하나의 원소로 저장
ex) 567,832 > lst = [2, 3, 8, 7, 6, 5]
1) 덧셈: n개의 자리수를 각각 더하면서 올림수를 고려
def largeadd(A, B):
N = len(A) if len(A) > len(B) else len(B)
result = []
carry = 0
for k in range(N):
i = A[k] if k < len(A) else 0
j = B[k] if k < len(B) else 0
value = i + j + carry
carry = value // 10
result.append(value % 10)
if carry > 0:
result.append(carry)
return result
A = [3, 2, 1]
B = [5, 4]
result = largeadd(A, B)
print(result[::-1])
2) 곱셈 --- 분할정복으로 구현
- n개의 자리수로 된 숫자를 n/2개의 자리수로 분할(n이 홀수일 경우 분할된 값의 자리수가 다를 수 있음
ex) 567,832 = 567 * 10 ^ 3 + 832
9,423,723 = 9,423 * 10 ^ 3 + 723 --- 10의 지수 m은 n/2 - 두 개의 정수를 분할하여 곱함
A = (x * 10 ^ m + y)(w * 10 ^ m + z) = xw * 10 ^ 2m + (zx + yw) * 10 ^ m + yz - 재귀호출을 4번해서 시간 복잡도가 전통 방식과 같음(n ^ 2)
def largeadd(A, B):
N = len(A) if len(A) > len(B) else len(B)
result = []
carry = 0
for k in range(N):
i = A[k] if k < len(A) else 0
j = B[k] if k < len(B) else 0
value = i + j + carry
carry = value // 10
result.append(value % 10)
if carry > 0:
result.append(carry)
return result
def largemult(A, B): # 전통적인 방법으로 곱셈
i = A[0] if 0 < len(A) else 0
j = B[0] if 0 < len(B) else 0
value = i * j
carry = value // 10
result = []
result.append(value % 10)
if carry > 0:
result.append(carry)
return result
def exp(A, M): # A * (10 ** M)
if A == 0:
return [0]
else:
return M * [0] + A
def div(A, M): # A // (10 ** M)
if len(A) < M:
A.append(0)
return A[M:]
def rem(A, M): # A % (10 ** M)
if len(A) < M:
A.append(0)
return A[:M]
def prod(A, B):
N = len(A) if len(A) > len(B) else len(B)
if len(A) == 0 or len(B) == 0:
return [0]
elif N <= threshold:
return largemult(A, B)
else:
M = N // 2
X = div(A, M)
Y = rem(A, M)
W = div(B, M)
Z = rem(B, M)
p1 = prod(X, W)
p2 = largeadd(prod(X, Z), prod(W, Y))
p3 = prod(Y, Z)
return largeadd(largeadd(exp(p1, 2*M), exp(p2, M)), p3)
# 임계값, 특정 자리수까지
threshold = 1
A = [4, 3, 2, 1]
B = [9, 8, 7, 6, 5]
print(exp(A, 3))
print(div(A, 3))
print(rem(A, 3))
print(prod(A, B)[::-1])
3) 곱셈 --- 효율성 개선
- 재귀호출의 횟수를 3번으로 줄임
A = xw * 10 ^ 2m + (zx + yw) * 10 ^ m + yz
R = (x + y)(w + z) = xw + (zx + yw) + yz
(zx + yw) = R - xw - yz
def largeadd(A, B):
N = len(A) if len(A) > len(B) else len(B)
result = []
carry = 0
for k in range(N):
i = A[k] if k < len(A) else 0
j = B[k] if k < len(B) else 0
value = i + j + carry
carry = value // 10
result.append(value % 10)
if carry > 0:
result.append(carry)
return result
def largesub(A, B):
N = len(A) if len(A) > len(B) else len(B)
result = []
borrow = 0
for k in range(N):
i = A[k] if k < len(A) else 0
j = B[k] if k < len(B) else 0
value = i - j + borrow
if value < 0:
value += 10
borrow = -1
else:
borrow = 0
result.append(value % 10)
if borrow < 0:
print('음의 정수는 처리할 수 없음')
return result
def largemult(A, B): # 전통적인 방법으로 곱셈
i = A[0] if 0 < len(A) else 0
j = B[0] if 0 < len(B) else 0
value = i * j
carry = value // 10
result = []
result.append(value % 10)
if carry > 0:
result.append(carry)
return result
def exp(A, M): # A * (10 ** M)
if A == 0:
return [0]
else:
return M * [0] + A
def div(A, M): # A // (10 ** M)
if len(A) < M:
A.append(0)
return A[M:]
def rem(A, M): # A % (10 ** M)
if len(A) < M:
A.append(0)
return A[:M]
def prod(A, B):
N = len(A) if len(A) > len(B) else len(B)
if len(A) == 0 or len(B) == 0:
return [0]
elif N <= threshold:
return largemult(A, B)
else:
M = N // 2
X = div(A, M)
Y = rem(A, M)
W = div(B, M)
Z = rem(B, M)
R = prod(largeadd(X, Y), largeadd(W, Z))
p1 = prod(X, W)
p3 = prod(Y, Z)
p2 = largesub(R, largeadd(p1, p3))
return largeadd(largeadd(exp(p1, 2*M), exp(p2, M)), p3)
# 임계값, 특정 자리수까지
threshold = 1
A = [4, 3, 2, 1]
B = [9, 8, 7, 6, 5]
print(prod(A, B)[::-1])
※ 강의 링크: https://inf.run/v8Rn
'Computer Science > Algorithm' 카테고리의 다른 글
[파이썬으로 배우는 알고리즘 기초] 3. 동적 계획법 (2) (0) | 2022.06.07 |
---|---|
[파이썬으로 배우는 알고리즘 기초] 3. 동적 계획법 (1) (0) | 2022.06.07 |
[파이썬으로 배우는 알고리즘 기초] 2. 분할 정복법 (3) (0) | 2022.06.07 |
[파이썬으로 배우는 알고리즘 기초] 2. 분할 정복법 (1) (0) | 2022.06.01 |
[파이썬으로 배우는 알고리즘 기초] 1. 알고리즘의 개념 (0) | 2022.05.31 |