1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/python3
 
# SI 335: Computer Algorithms
# Unit 4
 
import random
 
def add(X, Y, B):
    """Add two little-endian base-B digit lists.

    X and Y hold digits least-significant first, and X must be at least as
    long as Y. Returns a new list of len(X) + 1 digits whose last entry is
    the final carry out of the top position.
    """
    assert(len(X) >= len(Y))
    shorter = len(Y)
    total = []
    carry = 0
    for pos, x_digit in enumerate(X):
        # Past the end of Y only X's digit and the running carry contribute.
        y_digit = Y[pos] if pos < shorter else 0
        carry, digit = divmod(x_digit + y_digit + carry, B)
        total.append(digit)
    total.append(carry)
    return total
 
def sub(X, Y, B):
    assert(len(X) >= len(Y))
    carry = 0
    A = [0] * len(X) 
    for i in range(0, len(Y)):
        carry, A[i] = divmod(X[i] - Y[i] + carry, B)
    for i in range(len(Y), len(X)):
        carry, A[i] = divmod(X[i] + carry, B)
    assert(carry == 0)
    return A
 
def smul(X, Y, B):
    """Schoolbook multiplication of two equal-length little-endian base-B
    digit lists; returns a list of 2n digits.

    For each digit of Y, forms the partial product X * Y[shift] (n digits
    plus a final carry) and adds it into the running total shifted by
    `shift` places.
    """
    assert(len(X) == len(Y))
    n = len(X)
    total = [0] * (2*n)
    for shift in range(n):
        multiplier = Y[shift]
        partial = []
        carry = 0
        for x_digit in X:
            carry, digit = divmod(x_digit * multiplier + carry, B)
            partial.append(digit)
        # add() returns n+1 digits, so this also rewrites total[shift+n];
        # the carry out of the partial product then lands on that slot.
        total[shift : shift+n+1] = add(total[shift : shift+n], partial, B)
        total[shift+n] += carry
    return total
 
def kmul(X, Y, B):
    """Multiply two equal-length little-endian base-B digit lists with
    Karatsuba's algorithm; returns a list of 2n digits.

    Each operand is split into a low half (its first n // 2 digits) and a
    high half. Only three recursive products are made,
        X0*Y0, X1*Y1 and (X0+X1)*(Y0+Y1),
    and the middle term is recovered as (X0+X1)*(Y0+Y1) - X0*Y0 - X1*Y1.
    Inputs of length 3 or less go straight to the schoolbook smul.
    """
    assert(len(X) == len(Y))
    n = len(X)
    if n <= 3:
        return smul(X, Y, B)

    half = n // 2
    x_lo, x_hi = X[0 : half], X[half : n]
    y_lo, y_hi = Y[0 : half], Y[half : n]
    x_sum = add(x_hi, x_lo, B)
    y_sum = add(y_hi, y_lo, B)

    lo_prod = kmul(x_lo, y_lo, B)
    hi_prod = kmul(x_hi, y_hi, B)
    cross = kmul(x_sum, y_sum, B)

    # Low product at weight B^0 (2*half digits), high product right after it
    # at weight B^(2*half), plus one spare top slot for the carry out of the
    # additions below: 2n + 1 digits in total.
    result = lo_prod + hi_prod + [0]

    # Fold in the middle term at weight B^half: + cross - lo_prod - hi_prod.
    upper = add(result[half : 2*n], cross, B)
    upper = sub(upper, lo_prod, B)
    upper = sub(upper, hi_prod, B)
    result[half : 2*n+1] = upper

    # The true product fits in 2n digits, so the spare slot must be empty.
    assert(result[2*n] == 0)
    return result[0 : 2*n]
 
def fib(n):
    """Return the n-th Fibonacci number by the plain doubly-recursive
    definition (exponential time; kept for comparison with fib_memo).

    Any n <= 1 is returned unchanged.
    """
    if n > 1:
        return fib(n - 1) + fib(n - 2)
    return n
 
fib_table = {}  # memo for fib_memo: n -> F(n), filled for n >= 2
def fib_memo(n):
    """Return the n-th Fibonacci number, memoizing results in fib_table.

    A value already in fib_table is returned first. Otherwise n <= 1 is
    returned as-is (not cached), and larger n are computed recursively and
    stored in fib_table before being returned.
    """
    try:
        return fib_table[n]
    except KeyError:
        pass
    if n <= 1:
        return n
    fib_table[n] = fib_memo(n - 1) + fib_memo(n - 2)
    return fib_table[n]
 
def mm(*D):
    """Return the fewest scalar multiplications needed to multiply a chain
    of matrices, by plain (exponential-time) recursion.

    The * collects the positional arguments into a tuple D of dimensions;
    a tuple is basically a list that can't be changed. Matrix k (1-based)
    has shape D[k-1] x D[k], so there are len(D) - 1 matrices. For example,
    mm(5, 2, 6, 3) describes a 5x2, a 2x6 and a 6x3 matrix and returns 66.

    A chain of zero or one matrices needs no multiplications, so the
    result is 0.
    """
    n = len(D) - 1
    if n <= 1:
        # n == 0 (a single dimension) must also be 0 here: falling through
        # to the loop below would return float('inf').
        return 0
    fewest = float('inf')  # any real cost beats this starting value
    for i in range(1, n):
        # Split after matrix i: solve D[0..i] and D[i..n] separately, then
        # combine a D[0] x D[i] result with a D[i] x D[n] result.
        t = mm(*D[0 : i+1]) + D[0]*D[i]*D[n] + mm(*D[i : n+1])
        if t < fewest:
            fewest = t
    return fewest
 
mm_table = {}  # memo for mmm: tuple of dimensions -> fewest multiplications
def mmm(*D):
    """Memoized matrix-chain cost: return the fewest scalar multiplications
    needed to multiply the chain of matrices described by D.

    D is the tuple of dimensions; matrix k (1-based) has shape
    D[k-1] x D[k], so there are len(D) - 1 matrices. Each result is cached
    in the module-level mm_table, keyed by the tuple D, so repeated
    sub-chains are solved only once.

    A chain of zero or one matrices needs no multiplications, so the
    result is 0.
    """
    n = len(D) - 1
    if D not in mm_table:
        if n <= 1:
            # Zero or one matrix: nothing to multiply. Using <= also stops a
            # single-dimension call from caching float('inf').
            mm_table[D] = 0
        else:
            fewest = float('inf')
            for i in range(1, n):
                # Split after matrix i and combine the two sub-results.
                t = mmm(*D[0 : i+1]) + D[0]*D[i]*D[n] + mmm(*D[i : n+1])
                if t < fewest:
                    fewest = t
            mm_table[D] = fewest
    return mm_table[D]
 
def dmm(*D, show=True):
    """Return the fewest scalar multiplications needed to multiply a chain
    of matrices, using bottom-up dynamic programming.

    D is the tuple of dimensions: matrix k (1-based) has shape
    D[k-1] x D[k], so there are n = len(D) - 1 matrices. A[row][col] is
    filled with the optimal cost for the sub-chain between dimension
    indices row and col. Cells are filled one diagonal (sub-chain length)
    at a time, so every cell a computation reads is already final.

    show: keyword-only. When true (the default), the finished table is
    printed with printTable before returning. Pass show=False to get only
    the result, with no output.
    """
    n = len(D) - 1
    # A will be a (n+1) by (n+1) array
    A = [[0] * (n+1) for i in range(n+1)]
    for diag in range(1, n+1):
        for row in range(0, n-diag+1):
            col = diag + row
            # This part is just like the recursive version!
            if diag == 1:
                A[row][col] = 0
            else:
                A[row][col] = float('inf')
                for i in range(row+1, col):
                    t = A[row][i] + D[row]*D[i]*D[col] + A[i][col]
                    if t < A[row][col]:
                        A[row][col] = t
    if show:
        printTable(A)
    return A[0][n]
 
 
# The rest is just for testing purposes
 
def toDigits(n, B=10):
    """Return the digits of the non-negative integer n in base B,
    least-significant digit first (the order add/sub/smul/kmul use).

    toDigits(0) is [0]. B defaults to 10, so existing decimal callers see
    the same result as before.

    Raises ValueError if n is negative or B < 2.
    """
    if B < 2:
        raise ValueError("base must be at least 2")
    if n < 0:
        raise ValueError("cannot convert a negative number to digits")
    digits = []
    while True:
        n, d = divmod(n, B)
        digits.append(d)
        # Checked after appending so that n == 0 still yields [0].
        if n == 0:
            return digits
 
def fromDigits(X, B=10):
    """Return the integer whose base-B digits, least-significant first,
    are X (the inverse of toDigits for the same B).

    Evaluated with Horner's rule rather than string concatenation, so it
    also works for bases above 10, where a digit can be 10 or more.
    Leading (high-order) zeros are harmless, and an empty list gives 0.
    """
    value = 0
    for d in reversed(X):
        value = value * B + d
    return value
 
def printTable(A):
    """Print a 2-D table with its row and column indices.

    The header lists column indices 0 .. len(A[0]) - 1. Each following line
    shows the row index, then that row's entries in brackets. Every cell is
    right-aligned to width 4 and followed by one space.
    """
    def cells(values):
        # Render one row (or the header) with the fixed-width cell format.
        return "".join("{:>4} ".format(v) for v in values)

    print(" " * (3 + 2), end="")
    print(cells(range(len(A[0]))))
    for index, row in enumerate(A):
        print("{:>3} [".format(index) + cells(row) + "]")