Skip to content

Commit

Permalink
Unify subsieves
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodigrim committed Dec 7, 2019
1 parent d58c5f2 commit c260ca6
Showing 1 changed file with 34 additions and 70 deletions.
104 changes: 34 additions & 70 deletions Math/NumberTheory/Primes/Sieve/Atkin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module Math.NumberTheory.Primes.Sieve.Atkin
import Control.Monad
import Control.Monad.ST
import Data.Bit
import Data.Bits
import Data.Maybe
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
Expand All @@ -30,57 +31,16 @@ import Math.NumberTheory.Utils
atkinPrimeList :: PrimeSieve -> [Int]
atkinPrimeList (PrimeSieve low len segments)
| len60 == 0 = []
| otherwise = takeWhile (< high) $ dropWhile (< low) $ 2 : 3 : 5 : merge l0 l1
| otherwise = takeWhile (< high) $ dropWhile (< low) $ 2 : 3 : 5 : map fromWheel30 (listBits segments)
where
list00 = map (\k -> 60 * (low60 + k) + fromWheel30 0) (listBits $ segments V.! 0)
list01 = map (\k -> 60 * (low60 + k) + fromWheel30 1) (listBits $ segments V.! 1)
list02 = map (\k -> 60 * (low60 + k) + fromWheel30 2) (listBits $ segments V.! 2)
list03 = map (\k -> 60 * (low60 + k) + fromWheel30 3) (listBits $ segments V.! 3)
list04 = map (\k -> 60 * (low60 + k) + fromWheel30 4) (listBits $ segments V.! 4)
list05 = map (\k -> 60 * (low60 + k) + fromWheel30 5) (listBits $ segments V.! 5)
list06 = map (\k -> 60 * (low60 + k) + fromWheel30 6) (listBits $ segments V.! 6)
list07 = map (\k -> 60 * (low60 + k) + fromWheel30 7) (listBits $ segments V.! 7)
list08 = map (\k -> 60 * (low60 + k) + fromWheel30 8) (listBits $ segments V.! 8)
list09 = map (\k -> 60 * (low60 + k) + fromWheel30 9) (listBits $ segments V.! 9)
list10 = map (\k -> 60 * (low60 + k) + fromWheel30 10) (listBits $ segments V.! 10)
list11 = map (\k -> 60 * (low60 + k) + fromWheel30 11) (listBits $ segments V.! 11)
list12 = map (\k -> 60 * (low60 + k) + fromWheel30 12) (listBits $ segments V.! 12)
list13 = map (\k -> 60 * (low60 + k) + fromWheel30 13) (listBits $ segments V.! 13)
list14 = map (\k -> 60 * (low60 + k) + fromWheel30 14) (listBits $ segments V.! 14)
list15 = map (\k -> 60 * (low60 + k) + fromWheel30 15) (listBits $ segments V.! 15)

lst0 = merge list00 list01
lst1 = merge list02 list03
lst2 = merge list04 list05
lst3 = merge list06 list07
lst4 = merge list08 list09
lst5 = merge list10 list11
lst6 = merge list12 list13
lst7 = merge list14 list15

ls0 = merge lst0 lst1
ls1 = merge lst2 lst3
ls2 = merge lst4 lst5
ls3 = merge lst6 lst7

l0 = merge ls0 ls1
l1 = merge ls2 ls3

low60 = low `quot` 60
len60 = (low + len + 59) `quot` 60 - low60
high = low + len

merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xs') ys@(y:ys')
| x < y = x : merge xs' ys
| otherwise = y : merge xs ys'

data PrimeSieve = PrimeSieve
{ _psLowBound :: !Int
, _psLength :: !Int
, _psSegments :: V.Vector (U.Vector Bit)
, _psSegments :: !(U.Vector Bit)
} deriving (Show)

atkinSieve
Expand All @@ -91,26 +51,24 @@ atkinSieve low len = PrimeSieve low len segments
where
low60 = low `quot` 60
len60 = (low + len + 59) `quot` 60 - low60
params = V.generate 16 (\i -> SieveParams (fromWheel30 i) low60 len60)
segments = V.map sieveSegment params
segments = sieveSegment low60 len60

data SieveParams = SieveParams
{ spDelta :: !Int
, spLowBound :: !Int
, spLength :: !Int
} deriving (Show)

spHighBound :: SieveParams -> Int
spHighBound sp = spLowBound sp + spLength sp

sieveSegment
:: SieveParams
:: Int
-> Int
-> U.Vector Bit
sieveSegment sp = runST $ do
vec <- MU.new (spLength sp)
U.forM_ (fgs V.! toWheel30 (spDelta sp)) $
traverseLatticePoints sp vec
algo3steps456 sp vec
sieveSegment low60 len60 = runST $ do
vec <- MU.new (len60 `shiftL` 4)
forM_ [0..15] $ \i ->
U.forM_ (fgs V.! i) $
traverseLatticePoints (SieveParams (fromWheel30 i) low60 len60) vec
algo3steps456 low60 len60 vec
U.unsafeFreeze vec

-- | Solutions of k * f^2 + l * g^2 = delta (mod 60)
Expand Down Expand Up @@ -178,7 +136,8 @@ traverseLatticePoints1 !sp vec (!x0, !y0) =
-- Step 6
doActions (!k, !y)
| k < spLength sp
= unsafeFlipBit vec k >> doActions (forwardY (k, y))
= unsafeFlipBit vec (k `shiftL` 4 + toWheel30 (spDelta sp))
>> doActions (forwardY (k, y))
| otherwise
= pure ()

Expand Down Expand Up @@ -223,7 +182,8 @@ traverseLatticePoints2 sp vec (x0, y0) =
-- Step 6
doActions (!k, !y)
| k < spLength sp
= unsafeFlipBit vec k >> doActions (forwardY (k, y))
= unsafeFlipBit vec (k `shiftL` 4 + toWheel30 (spDelta sp))
>> doActions (forwardY (k, y))
| otherwise
= pure ()

Expand Down Expand Up @@ -252,7 +212,8 @@ traverseLatticePoints3 sp vec (x0, y0) =
-- Step 6
doActions (!k, !x, !y)
| k >= 0 && y < x
= unsafeFlipBit vec k >> (let (k', y') = forwardY (k, y) in doActions (k', x, y'))
= unsafeFlipBit vec (k `shiftL` 4 + toWheel30 (spDelta sp))
>> (let (k', y') = forwardY (k, y) in doActions (k', x, y'))
| otherwise
= pure ()

Expand All @@ -271,33 +232,36 @@ traverseLatticePoints3 sp vec (x0, y0) =

-- | Perform steps 4-6 of Algorithm 3.X.
algo3steps456
:: SieveParams
:: Int
-> Int
-> MU.MVector s Bit
-> ST s ()
algo3steps456 sp vec =
algo3steps456 low60 len60 vec =
forM_ ps $ \p ->
crossMultiples sp vec (p * p)
crossMultiples low60 len60 vec (p * p)
where
low = 7
high = integerSquareRoot (60 * spHighBound sp - 1)
high = integerSquareRoot (60 * (low60 + len60) - 1)
ps = takeWhile (<= high) $ dropWhile (< low) $ map unPrime E.primes

-- | Cross out multiples of the first argument
-- in a given sieve.
crossMultiples
:: SieveParams
:: Int
-> Int
-> MU.MVector s Bit
-> Int -- coprime with 60
-> ST s ()
crossMultiples sp vec m =
forM_ [k1, k1 + m .. spHighBound sp - 1] $
\k -> MU.unsafeWrite vec (k - spLowBound sp) (Bit False)
crossMultiples low60 len60 vec m =
forM_ [0..15] $ \i -> do
-- k0 is the smallest non-negative k such that 60k+delta = 0 (mod m)
let k0 = solveCongruence (fromWheel30 i) m
-- k1 = k0 (mod m), k1 >= lowBound
k1 = if r < k0 then q * m + k0 else (q + 1) * m + k0
forM_ [k1, k1 + m .. (low60 + len60) - 1] $
\k -> MU.unsafeWrite vec ((k - low60) `shiftL` 4 + i) (Bit False)
where
-- k0 is the smallest non-negative k such that 60k+delta = 0 (mod m)
k0 = solveCongruence (spDelta sp) m
-- k1 = k0 (mod m), k1 >= lowBound
(q, r) = spLowBound sp `quotRem` m
k1 = if r < k0 then q * m + k0 else (q + 1) * m + k0
(q, r) = low60 `quotRem` m

-- Find the smallest k such that 60k+delta = 0 (mod m)
-- Should be equal to
Expand Down

0 comments on commit c260ca6

Please sign in to comment.