-
Notifications
You must be signed in to change notification settings - Fork 0
/
roots.py
222 lines (202 loc) · 6.41 KB
/
roots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#!/usr/bin/python3
#
# EECS 388 Project 1 functions
#
# This source code is taken from mpmath by Fredrik Johansson
# https://raw.githubusercontent.com/fredrik-johansson/mpmath/43b5c9bffeaa3ec2909b9bd10175c140c056f89a/mpmath/libmp/libintmath.py
# Used under fair-use.
from bisect import bisect as _bisect
from codecs import decode
# Table of the powers of two, 2**0 through 2**299 (300 entries).
powers = [2**k for k in range(300)]
def _trailing(n):
    """
    Return the number of trailing zero bits of abs(n).

    By convention the result for n == 0 is 0.
    """
    if not n:
        return 0
    # n & -n isolates the lowest set bit (Python ints behave as infinite
    # two's complement, so this also holds for negative n). The index of
    # that bit is its bit length minus one.
    lowest_bit = n & -n
    return lowest_bit.bit_length() - 1
def _bitcount(n):
    """
    Calculate bit size of the nonnegative integer n.

    Returns n.bit_length() for n > 0, and 0 for n <= 0.
    """
    if n <= 0:
        return 0
    # int.bit_length() is exact for integers of any size, so no lookup
    # table or floating-point logarithm is needed (and nothing outside
    # builtins has to be in scope).
    return n.bit_length()
# trailtable[i]: number of trailing zero bits of i for every byte value i
# (with the convention trailtable[0] == 0).
# NOTE(review): trailtable is not referenced by the code shown here.
trailtable = list(map(_trailing, range(256)))
# bctable[i]: bit size of i for 0 <= i < 1024.
bctable = list(map(_bitcount, range(1024)))
# Powers of two used as size cut-offs by the integer square-root routines.
_1_800 = 2**800
_1_600 = 2**600
_1_400 = 2**400
_1_200 = 2**200
_1_100 = 2**100
_1_50 = 2**50
def _isqrt_small_python(x):
    """
    Correctly (floor) rounded integer square root, using
    division. Fast up to ~200 digits.

    x must be a nonnegative int; the result is floor(sqrt(x)).
    """
    if not x:
        return x
    if x < _1_800:
        # Exact with IEEE double precision arithmetic
        if x < _1_50:
            return int(x**0.5)
        # Initial estimate can be any integer >= the true root; round up
        r = int(x**0.5 * 1.00000000000001) + 1
    else:
        # Too large for a float square root: take the root of the leading
        # ~100 bits and shift it back into place. int.bit_length() gives the
        # exact bit size for any x.
        n = x.bit_length()//2
        r = int((x >> (2*n - 100))**0.5 + 2) << (n - 50)  # +2 is to round up
    # The following iteration now precisely computes floor(sqrt(x))
    # See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
    # Perspective"
    while True:
        y = (r + x//r) >> 1
        if y >= r:
            return r
        r = y
def _giant_steps(start, target, n=2):
    """
    Return an increasing precision schedule that ends exactly at target.

    The list is roughly [start, n*start, ..., target/n**2, target/n, target],
    rounded so that each element is less than n times the one before it.
    With n == 2 this suits a quadratically convergent Newton iteration, where
    each step roughly doubles the number of correct bits.

    Assumes start*n is large enough for the loop to shrink (start >= 2 when
    n == 2); _isqrt_fast_python passes start == 50.
    """
    steps = [target]
    while steps[-1] > start*n:
        steps.append(steps[-1]//n + 2)
    steps.reverse()
    return steps
def _isqrt_fast_python(x):
    """
    Fast approximate integer square root, computed using division-free
    Newton iteration for large x. For random integers the result is almost
    always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
    0.1% probability. If x is very close to an exact square, the answer is
    1 ulp wrong with high probability.

    With 0 guard bits, the largest error over a set of 10^5 random
    inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
    almost certainly guarantees a max 1 ulp error.

    x must be a nonnegative int. Callers that need the exact floor root
    must correct the result (as _sqrtrem_python does).
    """
    # Use direct division-based iteration if sqrt(x) < 2^400
    # Assume floating-point square root accurate to within 1 ulp, then:
    # 0 Newton iterations good to 52 bits
    # 1 Newton iterations good to 104 bits
    # 2 Newton iterations good to 208 bits
    # 3 Newton iterations good to 416 bits
    if x < _1_800:
        y = int(x**0.5)
        if x >= _1_100:
            y = (y + x//y) >> 1
            if x >= _1_200:
                y = (y + x//y) >> 1
                if x >= _1_400:
                    y = (y + x//y) >> 1
        return y
    # bit_length() is exact for ints of any size (no float logarithm needed).
    bc = x.bit_length()
    guard_bits = 10
    x <<= 2*guard_bits
    bc += 2*guard_bits
    bc += (bc & 1)  # make bc even so that hbc is exactly half of it
    hbc = bc//2
    startprec = min(50, hbc)
    # Newton iteration for 1/sqrt(x), with floating-point starting value
    r = int(2.0**(2*startprec) * (x >> (bc - 2*startprec)) ** -0.5)
    pp = startprec
    for p in _giant_steps(startprec, hbc):
        # r**2, scaled from real size 2**(-bc) to 2**p
        r2 = (r*r) >> (2*pp - p)
        # x*r**2, scaled from real size ~1.0 to 2**p
        xr2 = ((x >> (bc - p)) * r2) >> p
        # New value of r, scaled from real size 2**(-bc/2) to 2**p
        r = (r * ((3 << p) - xr2)) >> (pp + 1)
        pp = p
    # (1/sqrt(x))*x = sqrt(x); the schedule always ends at p == hbc
    return (r*(x >> hbc)) >> (p + guard_bits)
def _sqrtrem_python(x):
    """
    Correctly rounded integer (floor) square root with remainder.

    For a nonnegative int x, returns (y, rem) where y == floor(sqrt(x))
    and rem == x - y*y.
    """
    # to check cutoff:
    # plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
    if x < _1_600:
        y = _isqrt_small_python(x)
        return y, x - y*y
    # _isqrt_fast_python may be 1 ulp too small; +1 compensates, so the
    # estimate is usually exact or one too large.
    y = _isqrt_fast_python(x) + 1
    rem = x - y*y
    # Correct the estimate while maintaining the invariant rem == x - y*y.
    # Step down while y*y > x.
    while rem < 0:
        y -= 1
        rem += 2*y + 1
    # Step up while (y + 1)**2 <= x, i.e. while rem >= 2*y + 1.
    while rem > 2*y:
        rem -= 2*y + 1
        y += 1
    return y, rem
# The below code is taken from SymPy
# Used under fair-use.
# https://raw.githubusercontent.com/sympy/sympy/733da515a7638bba4e08be366bf24c996ad84a61/sympy/core/power.py
from math import log as _log
def integer_nthroot(y, n):
    """
    Compute the integer n-th root of a nonnegative integer.

    Returns a pair (x, exact) where x == floor(y**(1/n)) and exact is True
    precisely when x**n == y.

    >>> integer_nthroot(16, 2)
    (4, True)
    >>> integer_nthroot(26, 2)
    (5, False)

    Raises ValueError if y or n is not an int, if y is negative, or if n
    is less than 1.
    """
    if not isinstance(y, int):
        raise ValueError("y must be an integer")
    if not isinstance(n, int):
        raise ValueError("n must be an integer")
    y = int(y)
    n = int(n)
    if y < 0:
        raise ValueError("y must be nonnegative")
    if n < 1:
        raise ValueError("n must be positive")
    # 0 and 1 are their own roots, and the first root is the identity.
    if y in (0, 1) or n == 1:
        return y, True
    if n == 2:
        root, remainder = _sqrtrem_python(y)
        return int(root), not remainder
    # Here y >= 2, so n > y implies 1 <= y**(1/n) < 2.
    if n > y:
        return 1, False
    # Float starting estimate for Newton's method; y may be too large to
    # convert to float, in which case fall back to a log2-based estimate.
    try:
        guess = int(y ** (1.0 / n) + 0.5)
    except OverflowError:
        exponent = _log(y, 2) / n
        if exponent <= 53:
            guess = int(2.0 ** exponent)
        else:
            shift = int(exponent - 53)
            guess = int(2.0 ** (exponent - shift) + 1) << shift
    x = guess
    if guess > 2**50:
        # The float estimate lacks precision here: refine with integer
        # Newton steps until successive iterates differ by less than 2.
        while True:
            previous = x
            x = ((n - 1) * x + y // x ** (n - 1)) // n
            if abs(x - previous) < 2:
                break
    # Nudge x until x**n <= y < (x + 1)**n.
    power = x ** n
    while power < y:
        x += 1
        power = x ** n
    while power > y:
        x -= 1
        power = x ** n
    return x, power == y
# Helper functions, for your convenience:
from base64 import b64encode, b64decode
def bytes_to_base64(b):
    '''Encode the bytes-like object b as base64 and return the text as a str.'''
    encoded = b64encode(b)
    # base64 output is pure ASCII, so an ASCII decode yields the same str.
    return encoded.decode('ascii')
def base64_to_bytes(s):
    '''Decode the base64 text s (a str) back into the bytes it encodes.'''
    raw = s.encode()
    return b64decode(raw)
def integer_to_bytes(n, size):
    '''
    Return the int n as exactly size bytes, most significant byte first.

    Raises OverflowError if n is negative or does not fit in size bytes.
    '''
    return int.to_bytes(n, size, 'big')
def bytes_to_integer(b):
    '''Interpret b as an unsigned big-endian integer of arbitrary length.'''
    return int.from_bytes(b, 'big')
# Import-time self-check: 2**64 - 1 must survive the round trip
# int -> 256 big-endian bytes -> base64 text -> bytes -> int.
assert bytes_to_integer(
    base64_to_bytes(bytes_to_base64(integer_to_bytes(2**64 - 1, 256)))
) == 2**64 - 1