mighty-metropolis

The classic Metropolis algorithm.
git clone git://git.jtobin.io/mighty-metropolis.git
Log | Files | Refs | README | LICENSE

Spec.hs (4204B)


      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE RecordWildCards #-}
      3 
      4 import qualified Control.Foldl as L
      5 import Data.Functor.Identity
      6 import Data.Maybe (mapMaybe)
      7 import Data.Sampling.Types
      8 import Numeric.MCMC.Metropolis (chain, chain')
      9 import System.Random.MWC
     10 import Test.Hspec
     11 
     12 withinPercent :: Double -> Double -> Double -> Bool
     13 withinPercent b n a
     14     | b == 0    = a == 0
     15     | otherwise = d / b < n / 100
     16   where
     17     d = abs (a - b)
     18 
     19 mean :: [Double] -> Double
     20 mean = L.fold L.mean
     21 
     22 variance :: [Double] -> Double
     23 variance xs = L.fold alg xs where
     24   alg = (/) <$> L.premap csq L.sum <*> L.genericLength - 1
     25   csq = (** 2.0) . subtract m
     26   m   = mean xs
     27 
     28 stdDev :: [Double] -> Double
     29 stdDev = sqrt . variance
     30 
     31 stdErr :: [Double] -> Double
     32 stdErr xs = stdDev xs / sqrt n where
     33   n = fromIntegral (length xs)
     34 
     35 thin :: Int -> [a] -> [a]
     36 thin n xs = case xs of
     37   (h:t) -> h : thin n (drop (pred n) t)
     38   _     -> mempty
     39 
     40 data Params = Params {
     41     pepochs  :: Int
     42   , pradial  :: Double
     43   , porigin  :: Identity Double
     44   , ptunable :: Maybe (Identity Double -> Double)
     45   , pltarget :: Identity Double -> Double
     46   , pthin    :: Int
     47   }
     48 
     49 testParams :: Params
     50 testParams = Params {
     51     pepochs  = 1000000
     52   , pradial  = 0.2
     53   , porigin  = Identity 1.0
     54   , ptunable = Just (\(Identity x) -> x ** 3.0)
     55   , pltarget = \(Identity x) -> if x > 0 then negate x else negate 1 / 0
     56   , pthin    = 1000
     57   }
     58 
     59 vanillaTrace :: IO [Double]
     60 vanillaTrace = do
     61   let Params {..} = testParams
     62 
     63   boxed <- withSystemRandom . asGenIO $
     64     chain pepochs pradial porigin pltarget
     65 
     66   let positions = fmap (runIdentity . chainPosition) boxed
     67   pure (thin pthin positions)
     68 
     69 tunedTrace :: IO [Double]
     70 tunedTrace = do
     71   let Params {..} = testParams
     72 
     73   boxed <- withSystemRandom . asGenIO $
     74     chain' pepochs pradial porigin pltarget ptunable
     75 
     76   let positions = mapMaybe chainTunables boxed
     77   pure (thin pthin positions)
     78 
     79 testWithinPercent :: SpecWith ()
     80 testWithinPercent = describe "withinPercent" $
     81   it "works as expected" $ do
     82     106 `shouldNotSatisfy` withinPercent 100 5
     83     105 `shouldNotSatisfy` withinPercent 100 5
     84     104 `shouldSatisfy`    withinPercent 100 5
     85     96  `shouldSatisfy`    withinPercent 100 5
     86     95  `shouldNotSatisfy` withinPercent 100 5
     87     94  `shouldNotSatisfy` withinPercent 100 5
     88 
     89 testMean :: SpecWith ()
     90 testMean = describe "mean" $
     91   it "works as expected" $ do
     92     mean [1, 2, 3]    `shouldSatisfy` withinPercent 2 1e-3
     93     mean [1..100]     `shouldSatisfy` withinPercent 50.5 1e-3
     94     mean [1..1000000] `shouldSatisfy` withinPercent 500000.5 1e-3
     95 
     96 testVariance :: SpecWith ()
     97 testVariance = describe "variance" $
     98   it "works as expected" $ do
     99     variance [0, 1]    `shouldSatisfy` withinPercent 0.5 1e-3
    100     variance [1, 1, 1] `shouldSatisfy` withinPercent 0 1e-3
    101     variance [1..100]  `shouldSatisfy` withinPercent 841.66666666 1e-3
    102 
    103 testStdErr :: SpecWith ()
    104 testStdErr = describe "stdErr" $
    105   it "works as expected" $ do
    106     stdErr [1..100]  `shouldSatisfy` withinPercent 2.901149 1e-3
    107     stdErr [1..1000] `shouldSatisfy` withinPercent 9.133273 1e-3
    108 
    109 testHelperFunctions :: SpecWith ()
    110 testHelperFunctions = describe "helper functions" $ do
    111   testWithinPercent
    112   testMean
    113   testVariance
    114   testStdErr
    115 
    116 testSamples :: [Double] -> SpecWith ()
    117 testSamples xs = describe "sampled trace over exp(1)" $ do
    118   let meanStdErr = stdErr xs
    119       varStdErr  = stdErr (fmap (\x -> pred x ** 2.0) xs)
    120 
    121   context "within three standard errors" $ do
    122     it "has the expected mean" $ do
    123       mean xs `shouldSatisfy` (< 1 + 3 * meanStdErr)
    124       mean xs `shouldSatisfy` (> 1 - 3 * meanStdErr)
    125 
    126     it "has the expected variance" $ do
    127       variance xs `shouldSatisfy` (< 1 + 3 * varStdErr)
    128       variance xs `shouldSatisfy` (> 1 - 3 * varStdErr)
    129 
    130 testTunables :: [Double] -> SpecWith ()
    131 testTunables ts = describe "sampled tunables over exp(1)" $ do
    132   let meanStdErr = stdErr ts
    133 
    134   context "within three standard errors" $
    135     it "has the expected third moment" $ do
    136       mean ts `shouldSatisfy` (< 6 + 3 * meanStdErr)
    137       mean ts `shouldSatisfy` (> 6 - 3 * meanStdErr)
    138 
    139 main :: IO ()
    140 main = do
    141   xs <- vanillaTrace
    142   ts <- tunedTrace
    143 
    144   hspec $ do
    145     testHelperFunctions
    146     testSamples xs
    147     testTunables ts