commit 2d643548cb4b67baa46d316eaef0df5e14b292b5
parent 2e2f27ed7b81b172420a52bb42167dbf7c23ce7b
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 30 Jul 2023 05:27:20 -0230
MT19937 fixes.
Diffstat:
2 files changed, 50 insertions(+), 27 deletions(-)
diff --git a/docs/s3.md b/docs/s3.md
@@ -108,18 +108,17 @@ gist of it:
#### 3.21
-The only annoying thing about this problem is finding a test vector to
-check the implementation on. `Cryptopals.Stream.RNG.MT19937` implements
-the Mersenne Twister (MT19937) PRNG in standard return-the-generator
-fashion:
+`Cryptopals.Stream.RNG.MT19937` implements the Mersenne Twister
+(MT19937) PRNG in standard return-the-generator fashion:
> let gen = seed 42
- > :{
- ghci| let loop g = let (x, g1) = extract g
- ghci| (y, g2) = extract g1
- ghci| (z, g3) = extract g2
- ghci| in ([x,y,z], g3)
- ghci| :}
- > fst $ loop gen
- [1880174161,2921269485,2329092166]
+ > bytes 3 gen
+ > ([1608637542,3421126067,4083286876],<MT19937.Gen>)
+
+The only annoying thing about this problem was finding a test vector
+to check the implementation against. I used the outputs on [this
+guy's](https://create.stephan-brumme.com/mersenne-twister/) page;
+the implementations he cites return signed 32-bit integers, but I
+use (unsigned) Word32. One can convert results to e.g. Int32 with
+fromIntegral to verify.
diff --git a/lib/Cryptopals/Stream/RNG/MT19937.hs b/lib/Cryptopals/Stream/RNG/MT19937.hs
@@ -2,11 +2,14 @@ module Cryptopals.Stream.RNG.MT19937 (
Gen
, seed
, extract
+ , bytes
) where
+import qualified Control.Monad.ST as ST
import Data.Bits ((.&.))
import qualified Data.Bits as B
import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.Word (Word32)
fi :: (Integral a, Num b) => a -> b
@@ -38,7 +41,10 @@ um :: Word32
um = B.complement lm -- 0x1000 0000 0000 0000 0000 0000 0000 0000
data Gen = Gen !Word32 !(VU.Vector Word32)
- deriving (Eq, Show)
+ deriving Eq
+
+instance Show Gen where
+ show Gen {} = "<MT19937.Gen>"
seed :: Word32 -> Gen
seed s = Gen n (loop 0 mempty) where
@@ -47,12 +53,12 @@ seed s = Gen n (loop 0 mempty) where
| otherwise = case acc of
[] -> loop (succ j) (pure s)
(h:_) ->
- let v = f * (h `B.xor` (B.shiftR h (fi w - 2))) + j -- XX can overflow?
+ let v = f * (h `B.xor` (B.shiftR h (fi w - 2))) + j
in loop (succ j) (v : acc)
extract :: Gen -> (Word32, Gen)
extract gen@(Gen idx _) =
- let Gen i g = if idx == n
+ let Gen i g = if idx >= n
then twist gen
else gen
@@ -60,22 +66,40 @@ extract gen@(Gen idx _) =
y1 = y0 `B.xor` ((B.shiftR y0 (fi u)) .&. d)
y2 = y1 `B.xor` ((B.shiftL y1 (fi s)) .&. b)
y3 = y2 `B.xor` ((B.shiftL y2 (fi t)) .&. c)
- y4 = y3 `B.xor` (B.shiftR y3 1)
+ y4 = y3 `B.xor` (B.shiftR y3 18)
in (y4, Gen (succ i) g)
twist :: Gen -> Gen
-twist (Gen i g) = loop 0 mempty where
- loop j !acc
- | j == fi n = Gen 0 (g VU.// acc)
- | otherwise =
- let x = ((g `VU.unsafeIndex` j) .&. um) +
- ((g `VU.unsafeIndex` (succ j `mod` fi n)) .&. lm) -- XX check
- xa = B.shiftR x 1
- xA | x `mod` 2 /= 0 = xa `B.xor` a
- | otherwise = xa
+twist (Gen i gen) = ST.runST $ do
+ g <- VU.thaw gen
+
+ let loop j
+ | j == fi n = pure ()
+ | otherwise = do
+ x0 <- g `VUM.unsafeRead` j
+ x1 <- g `VUM.unsafeRead` ((succ j) `mod` fi n)
+
+ let x = (x0 .&. um) + (x1 .&. lm)
+ xa = B.shiftR x 1
+ xA | x `mod` 2 /= 0 = xa `B.xor` a
+ | otherwise = xa
- v = (g `VU.unsafeIndex` ((j + fi m) `mod` fi n)) `B.xor` xA
+ v <- g `VUM.unsafeRead` ((j + fi m) `mod` fi n)
- in loop (succ j) ((j, v) : acc)
+ VUM.write g j $ v `B.xor` xA
+ loop (succ j)
+
+ loop 0
+
+ fen <- VU.freeze g
+ pure (Gen 0 fen)
+
+bytes :: Int -> Gen -> ([Word32], Gen)
+bytes = loop mempty where
+ loop !acc j gen
+ | j == 0 = (reverse acc, gen)
+ | otherwise =
+ let (w, g) = extract gen
+ in loop (w : acc) (pred j) g