Skip to content

Commit

Permalink
Simplify and speed up prime sieve for size we care about.
Browse files Browse the repository at this point in the history
Note that the peak memory usage is modestly (50%) higher, but for the current value of _MAX_DIMENSION, that'll be around 50kb instead of 33 kb.

PiperOrigin-RevId: 563201131
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Sep 6, 2023
1 parent e3c3244 commit b23a82b
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,24 +362,16 @@ def _base_expansion_size(num, bases):


def _primes_less_than(n):
# Based on
# https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
"""Returns sorted array of primes such that `2 <= prime < n`."""
small_primes = np.array((2, 3, 5))
if n <= 6:
return small_primes[small_primes < n]
sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool_)
sieve[0] = False
m = int(n ** 0.5) // 3 + 1
for i in range(m):
if not sieve[i]:
continue
k = 3 * i + 1 | 1
sieve[k ** 2 // 3::2 * k] = False
sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False
return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1]
primes = np.ones((n + 1) // 2, dtype=bool)
j = 3
while j * j <= n:
if primes[j//2]:
primes[j*j//2::j] = False
j += 2
ret = 2 * np.where(primes)[0] + 1
ret[0] = 2 # :(
return ret

_PRIMES = _primes_less_than(104729 + 1)


assert len(_PRIMES) == _MAX_DIMENSION

0 comments on commit b23a82b

Please sign in to comment.