cryptopals

Matasano's cryptopals challenges (cryptopals.com).
git clone git://git.jtobin.io/cryptopals.git
Log | Files | Refs | README | LICENSE

MT19937.hs (3798B)


      1 module Cryptopals.Stream.RNG.MT19937 (
      2     Gen
      3   , seed
      4   , extract
      5   , tap
      6 
      7   , clone
      8   ) where
      9 
     10 import qualified Control.Monad.ST as ST
     11 import Data.Bits ((.&.))
     12 import qualified Data.Bits as B
     13 import qualified Data.Vector.Unboxed as VU
     14 import qualified Data.Vector.Unboxed.Mutable as VUM
     15 import GHC.Word (Word32)
     16 
     17 fi :: (Integral a, Num b) => a -> b
     18 fi = fromIntegral
     19 
     20 -- following notation in https://en.wikipedia.org/wiki/Mersenne_Twister
     21 
     22 w, n, m, r, a, u, s, b, t, c, l :: Word32
     23 w = 32            -- word size
     24 n = 624           -- degree of recurrence
     25 m = 397           -- 'middle term'
     26 r = 31            -- word separation index
     27 a = 0x9908B0DF    -- rational normal form twist matrix coefficients
     28 u = 11            -- tempering parameter
     29 s = 7             -- tempering parameter (shift)
     30 b = 0x9D2C5680    -- tempering parameter (mask)
     31 t = 15            -- tempering parameter (shift)
     32 c = 0xEFC60000    -- tempering parameter (mask)
     33 l = 18            -- tempering parameter
     34 
     35 f :: Word32
     36 f = 1812433253
     37 
     38 lm :: Word32
     39 lm = B.shiftL 1 (fi r) - 1 -- 0b0111 1111 1111 1111 1111 1111 1111 1111
     40 
     41 um :: Word32
     42 um = B.complement lm       -- 0b1000 0000 0000 0000 0000 0000 0000 0000
     43 
     44 data Gen = Gen !Word32 !(VU.Vector Word32)
     45   deriving Eq
     46 
     47 instance Show Gen where
     48   show Gen {} = "<MT19937.Gen>"
     49 
     50 tap :: Int -> Gen -> ([Word32], Gen)
     51 tap = loop mempty where
     52   loop !acc j gen
     53     | j == 0    = (reverse acc, gen)
     54     | otherwise =
     55         let (w, g) = extract gen
     56         in  loop (w : acc) (pred j) g
     57 
     58 seed :: Word32 -> Gen
     59 seed s = Gen n (loop 0 mempty) where
     60   loop j !acc
     61     | j == n    = VU.fromList (reverse acc)
     62     | otherwise = case acc of
     63         []    -> loop (succ j) (pure s)
     64         (h:_) ->
     65           let v = f * (h `B.xor` (B.shiftR h (fi w - 2))) + j
     66           in  loop (succ j) (v : acc)
     67 
     68 extract :: Gen -> (Word32, Gen)
     69 extract gen@(Gen idx _) =
     70   let Gen i g = if   idx >= n
     71                 then twist gen
     72                 else gen
     73 
     74       y = g `VU.unsafeIndex` fi i
     75 
     76   in  (temper y, Gen (succ i) g)
     77 
     78 temper :: Word32 -> Word32
     79 temper = e4 . e3 . e2 . e1 where
     80   e1 = rs u
     81   e2 = ls s b
     82   e3 = ls t c
     83   e4 = rs l
     84 
     85 untemper :: Word32 -> Word32
     86 untemper = n1 . n2 . n3 . n4 where
     87   n1 = rsinv u
     88   n2 = lsinv s b
     89   n3 = lsinv t c
     90   n4 = rsinv l
     91 
     92 mask :: B.Bits b => Int -> Int -> b
     93 mask l h = loop l B.zeroBits where
     94   loop j !b
     95     | j > h = b
     96     | otherwise =
     97         loop (succ j) (B.setBit b j)
     98 
     99 ls :: Word32 -> Word32 -> Word32 -> Word32
    100 ls s m a = a `B.xor` (B.shiftL a (fi s) .&. m)
    101 
    102 lsinv :: Word32 -> Word32 -> Word32 -> Word32
    103 lsinv s bm = loop 0 where
    104   loop j !b
    105     | j >= fi w = b
    106     | otherwise =
    107         let m = mask j (min (fi w - 1) (j + fi s - 1))
    108             x = ((m .&. b) `B.shiftL` fi s) .&. bm
    109         in  loop (j + fi s) (b `B.xor` x)
    110 
    111 rs :: Word32 -> Word32 -> Word32
    112 rs s a = a `B.xor` B.shiftR a (fi s)
    113 
    114 rsinv :: Word32 -> Word32 -> Word32
    115 rsinv s = loop (fi w - 1) where
    116   loop j !b
    117     | j <= 0    = b
    118     | otherwise =
    119         let m = mask (max 0 (j - fi s + 1)) j
    120             x = (m .&. b) `B.shiftR` fi s
    121         in  loop (j - fi s) (b `B.xor` x)
    122 
    123 twist :: Gen -> Gen
    124 twist (Gen i gen) = ST.runST $ do
    125     g <- VU.thaw gen
    126 
    127     let loop j
    128           | j == fi n = pure ()
    129           | otherwise = do
    130               x0 <- g `VUM.unsafeRead` j
    131               x1 <- g `VUM.unsafeRead` ((succ j) `mod` fi n)
    132 
    133               let x  = (x0 .&. um) + (x1 .&. lm)
    134                   xa = B.shiftR x 1
    135                   xA | x `mod` 2 /= 0 = xa `B.xor` a
    136                      | otherwise      = xa
    137 
    138               v <- g `VUM.unsafeRead` ((j + fi m) `mod` fi n)
    139 
    140               VUM.write g j $ v `B.xor` xA
    141               loop (succ j)
    142 
    143     loop 0
    144 
    145     fen <- VU.freeze g
    146     pure (Gen 0 fen)
    147 
    148 clone :: [Word32] -> Gen
    149 clone = Gen n . VU.fromList . fmap untemper
    150