MT19937.hs (3798B)
{-# LANGUAGE BangPatterns #-}

-- | MT19937: the 32-bit Mersenne Twister pseudo-random number generator,
-- together with 'clone', which rebuilds a generator from its observed
-- outputs by inverting the output tempering.
--
-- Notation follows <https://en.wikipedia.org/wiki/Mersenne_Twister>.
module Cryptopals.Stream.RNG.MT19937 (
    Gen
  , seed
  , extract
  , tap

  , clone
  ) where

import qualified Control.Monad.ST as ST
import Data.Bits ((.&.))
import qualified Data.Bits as B
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.Word (Word32)

-- | Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral

-- following notation in https://en.wikipedia.org/wiki/Mersenne_Twister

w, n, m, r, a, u, s, b, t, c, l :: Word32
w = 32          -- word size
n = 624         -- degree of recurrence
m = 397         -- 'middle term'
r = 31          -- word separation index
a = 0x9908B0DF  -- rational normal form twist matrix coefficients
u = 11          -- tempering parameter (shift)
s = 7           -- tempering parameter (shift)
b = 0x9D2C5680  -- tempering parameter (mask)
t = 15          -- tempering parameter (shift)
c = 0xEFC60000  -- tempering parameter (mask)
l = 18          -- tempering parameter (shift)

-- | Multiplier of the seeding recurrence (see 'seed').
f :: Word32
f = 1812433253

-- | Mask of the low @r@ bits: 0x7FFFFFFF.
lm :: Word32
lm = B.shiftL 1 (fi r) - 1

-- | Mask of the high @w - r@ bits (the complement of 'lm'): 0x80000000.
um :: Word32
um = B.complement lm

-- | Generator state: the index of the next state word to output, and the
-- state vector. An index of @n@ means every word has been used and the
-- state must be twisted before the next output.
--
-- Invariant relied on by the unchecked indexing in 'extract' and 'twist':
-- the vector holds exactly @n@ words. The constructor is not exported, and
-- every function in this module that builds a 'Gen' preserves the invariant.
data Gen = Gen !Word32 !(VU.Vector Word32)
  deriving Eq

instance Show Gen where
  show Gen {} = "<MT19937.Gen>"

-- | Draw @k@ outputs, returned in the order they were produced, together
-- with the advanced generator. A count of zero or less yields no outputs
-- and returns the generator unchanged.
tap :: Int -> Gen -> ([Word32], Gen)
tap = loop [] where
  loop !acc k gen
    -- '<=' rather than '==' so a negative count terminates.
    | k <= 0    = (reverse acc, gen)
    | otherwise =
        let (out, gen') = extract gen
        in  loop (out : acc) (pred k) gen'

-- | Initialise a generator from a 32-bit seed.
--
-- State word 0 is the seed; word @i@ (for @1 <= i < n@) is
-- @f * (x `xor` (x >> (w - 2))) + i@, where @x@ is word @i - 1@.
-- The index starts at @n@, so the first 'extract' twists the seeded state.
seed :: Word32 -> Gen
seed seedWord = Gen n (VU.fromListN (fi n) (scanl step seedWord [1 .. n - 1]))
  where
    step prev i = f * (prev `B.xor` B.shiftR prev (fi w - 2)) + i

-- | Produce one tempered 32-bit output and the advanced generator. The
-- state is twisted first if all @n@ words have already been consumed.
extract :: Gen -> (Word32, Gen)
extract gen@(Gen idx _) =
  let Gen i g = if idx >= n then twist gen else gen
      -- In bounds: i < n here and the vector has n words (see 'Gen').
      y       = g `VU.unsafeIndex` fi i
  in  (temper y, Gen (succ i) g)

-- | The MT19937 output tempering: two right-shift steps and two masked
-- left-shift steps, applied to a state word before it is returned.
temper :: Word32 -> Word32
temper = e4 . e3 . e2 . e1 where
  e1 = rs u
  e2 = ls s b
  e3 = ls t c
  e4 = rs l

-- | Inverse of 'temper': the inverse of each tempering step, applied in
-- reverse order.
untemper :: Word32 -> Word32
untemper = n1 . n2 . n3 . n4 where
  n1 = rsinv u
  n2 = lsinv s b
  n3 = lsinv t c
  n4 = rsinv l

-- | A value with bits @lo@ through @hi@ (inclusive) set and all other bits
-- clear; 'B.zeroBits' when @lo > hi@.
mask :: B.Bits x => Int -> Int -> x
mask lo hi = loop lo B.zeroBits where
  loop j !acc
    | j > hi    = acc
    | otherwise = loop (succ j) (B.setBit acc j)

-- | Masked left-shift tempering step: @y `xor` ((y << sh) .&. mk)@.
ls :: Word32 -> Word32 -> Word32 -> Word32
ls sh mk y = y `B.xor` (B.shiftL y (fi sh) .&. mk)

-- | Inverse of @'ls' sh mk@.
--
-- The low @sh@ bits pass through 'ls' unchanged. The input can therefore
-- be recovered @sh@ bits at a time from the bottom up: each recovered chunk
-- is shifted and masked to cancel the contribution it made to the next one.
lsinv :: Word32 -> Word32 -> Word32 -> Word32
lsinv sh mk = loop 0 where
  loop j !acc
    | j >= fi w = acc
    | otherwise =
        let chunk = mask j (min (fi w - 1) (j + fi sh - 1))
            delta = ((chunk .&. acc) `B.shiftL` fi sh) .&. mk
        in  loop (j + fi sh) (acc `B.xor` delta)

-- | Right-shift tempering step: @y `xor` (y >> sh)@.
rs :: Word32 -> Word32 -> Word32
rs sh y = y `B.xor` B.shiftR y (fi sh)

-- | Inverse of @'rs' sh@.
--
-- The high @sh@ bits pass through 'rs' unchanged. The input is therefore
-- recovered @sh@ bits at a time from the top down, each recovered chunk
-- cancelling its contribution to the chunk below.
rsinv :: Word32 -> Word32 -> Word32
rsinv sh = loop (fi w - 1) where
  loop j !acc
    | j <= 0    = acc
    | otherwise =
        let chunk = mask (max 0 (j - fi sh + 1)) j
            delta = (chunk .&. acc) `B.shiftR` fi sh
        in  loop (j - fi sh) (acc `B.xor` delta)

-- | Regenerate all @n@ state words in place and reset the index to 0.
--
-- Words are updated in order 0 .. n-1. Reads of indices j+1 and j+m
-- (mod n) that wrap around see words already regenerated in this pass,
-- as in the reference algorithm.
twist :: Gen -> Gen
twist (Gen _ st) = ST.runST $ do
  g <- VU.thaw st

  let loop j
        | j == fi n = pure ()
        | otherwise = do
            x0 <- g `VUM.unsafeRead` j
            x1 <- g `VUM.unsafeRead` (succ j `mod` fi n)

            -- Upper bit of word j joined with the lower r bits of word j+1.
            let x  = (x0 .&. um) + (x1 .&. lm)
                xa = B.shiftR x 1
                xA | B.testBit x 0 = xa `B.xor` a
                   | otherwise     = xa

            v <- g `VUM.unsafeRead` ((j + fi m) `mod` fi n)

            VUM.write g j (v `B.xor` xA)
            loop (succ j)

  loop 0

  -- 'g' is a private copy made by 'thaw' and is not used again, so
  -- freezing it without another copy is safe.
  Gen 0 <$> VU.unsafeFreeze g

-- | Rebuild a generator from observed outputs so that it produces the
-- outputs that followed them.
--
-- Untempering each output recovers the state word it came from. The
-- recovered words become the clone's state, with the index at @n@, so the
-- clone's first 'extract' twists them into the next words.
--
-- The outputs are assumed to be consecutive. They need not start on a
-- twist boundary of the source generator, because twisting any @n@
-- consecutive state words yields the next @n@.
--
-- Only the first @n@ (624) outputs are used, and any further elements are
-- never forced, so an infinite list is accepted. Fewer than 624 outputs
-- cannot determine the state and cause an 'error'.
clone :: [Word32] -> Gen
clone outs
  | VU.length state < fi n =
      error "MT19937.clone: at least 624 outputs are required"
  | otherwise = Gen n state
  where
    state = VU.fromListN (fi n) (fmap untemper outs)