Spec.hs (4204B)
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}

-- Statistical tests for the Metropolis sampler in Numeric.MCMC.Metropolis.
--
-- Two chains are run against an Exp(1) target: a plain chain, and one that
-- also records a tunable function of its position. Sample moments of the
-- thinned traces are then checked against the known moments of Exp(1).

import qualified Control.Foldl as L
import Data.Functor.Identity
import Data.Maybe (mapMaybe)
import Data.Sampling.Types
import Numeric.MCMC.Metropolis (chain, chain')
import System.Random.MWC
import Test.Hspec

-- | @withinPercent b n a@ holds when @a@ lies strictly within @n@ percent of
-- the reference value @b@, i.e. when @|a - b| / |b| < n / 100@.
--
-- A reference of exactly zero admits only an exact match, because a relative
-- tolerance around zero is meaningless.
withinPercent :: Double -> Double -> Double -> Bool
withinPercent b n a
  | b == 0 = a == 0
  -- Divide by |b|: dividing by a negative reference would make the quotient
  -- non-positive, so every comparison would succeed.
  | otherwise = d / abs b < n / 100
  where
    d = abs (a - b)

-- | Arithmetic mean of a list.
mean :: [Double] -> Double
mean = L.fold L.mean

-- | Sample variance: the sum of squared deviations from the mean, divided by
-- @n - 1@. A single-element list yields NaN (0 / 0).
variance :: [Double] -> Double
variance xs = L.fold alg xs
  where
    -- Note: '<$>'/'<*>' are infixl 4 and '-' is infixl 6, so the denominator
    -- fold is @L.genericLength - 1@.
    alg = (/) <$> L.premap csq L.sum <*> L.genericLength - 1
    csq = (** 2.0) . subtract m
    m = mean xs

-- | Sample standard deviation (square root of 'variance').
stdDev :: [Double] -> Double
stdDev = sqrt . variance

-- | Standard error of the mean: 'stdDev' divided by the square root of the
-- sample size.
stdErr :: [Double] -> Double
stdErr xs = stdDev xs / sqrt n
  where
    n = fromIntegral (length xs)

-- | Keep the first element and then every @n@-th element after it. For
-- @n <= 1@ nothing is dropped.
thin :: Int -> [a] -> [a]
thin n xs = case xs of
  (h:t) -> h : thin n (drop (pred n) t)
  _ -> mempty

-- | Configuration shared by both sampler runs.
data Params = Params {
    pepochs :: Int
    -- ^ number of iterations handed to the sampler
  , pradial :: Double
    -- ^ radius parameter handed to the sampler (presumably the proposal
    -- step size; confirm against Numeric.MCMC.Metropolis)
  , porigin :: Identity Double
    -- ^ position the chain starts from
  , ptunable :: Maybe (Identity Double -> Double)
    -- ^ optional function of the position, recorded per state and read
    -- back via 'chainTunables'
  , pltarget :: Identity Double -> Double
    -- ^ log target density
  , pthin :: Int
    -- ^ thinning interval applied to the recorded traces
  }

-- | Parameters for sampling Exp(1). The log density is @-x@ on @x > 0@ and
-- @-Infinity@ elsewhere. The tunable @x^3@ has expectation @3! = 6@ under
-- that target.
testParams :: Params
testParams = Params {
    pepochs = 1000000
  , pradial = 0.2
  , porigin = Identity 1.0
  , ptunable = Just (\(Identity x) -> x ** 3.0)
  , pltarget = \(Identity x) -> if x > 0 then negate x else negate 1 / 0
  , pthin = 1000
  }

-- | Run the plain sampler and return the thinned trace of positions.
vanillaTrace :: IO [Double]
vanillaTrace = do
  let Params {..} = testParams

  boxed <- withSystemRandom . asGenIO $
    chain pepochs pradial porigin pltarget

  let positions = fmap (runIdentity . chainPosition) boxed
  pure (thin pthin positions)

-- | Run the sampler with 'ptunable' attached and return the thinned trace of
-- recorded tunable values. States without a tunable value are skipped.
tunedTrace :: IO [Double]
tunedTrace = do
  let Params {..} = testParams

  boxed <- withSystemRandom . asGenIO $
    chain' pepochs pradial porigin pltarget ptunable

  let positions = mapMaybe chainTunables boxed
  pure (thin pthin positions)

testWithinPercent :: SpecWith ()
testWithinPercent = describe "withinPercent" $
  it "works as expected" $ do
    106 `shouldNotSatisfy` withinPercent 100 5
    105 `shouldNotSatisfy` withinPercent 100 5
    104 `shouldSatisfy` withinPercent 100 5
    96 `shouldSatisfy` withinPercent 100 5
    95 `shouldNotSatisfy` withinPercent 100 5
    94 `shouldNotSatisfy` withinPercent 100 5
    -- Negative references use the magnitude of the reference.
    (-106) `shouldNotSatisfy` withinPercent (-100) 5
    (-104) `shouldSatisfy` withinPercent (-100) 5
    (-96) `shouldSatisfy` withinPercent (-100) 5
    (-94) `shouldNotSatisfy` withinPercent (-100) 5
    -- A zero reference admits only an exact match.
    0 `shouldSatisfy` withinPercent 0 5
    1 `shouldNotSatisfy` withinPercent 0 5

testMean :: SpecWith ()
testMean = describe "mean" $
  it "works as expected" $ do
    mean [1, 2, 3] `shouldSatisfy` withinPercent 2 1e-3
    mean [1..100] `shouldSatisfy` withinPercent 50.5 1e-3
    mean [1..1000000] `shouldSatisfy` withinPercent 500000.5 1e-3

testVariance :: SpecWith ()
testVariance = describe "variance" $
  it "works as expected" $ do
    variance [0, 1] `shouldSatisfy` withinPercent 0.5 1e-3
    variance [1, 1, 1] `shouldSatisfy` withinPercent 0 1e-3
    variance [1..100] `shouldSatisfy` withinPercent 841.66666666 1e-3

testStdErr :: SpecWith ()
testStdErr = describe "stdErr" $
  it "works as expected" $ do
    stdErr [1..100] `shouldSatisfy` withinPercent 2.901149 1e-3
    stdErr [1..1000] `shouldSatisfy` withinPercent 9.133273 1e-3

testHelperFunctions :: SpecWith ()
testHelperFunctions = describe "helper functions" $ do
  testWithinPercent
  testMean
  testVariance
  testStdErr

-- | Check that a position trace has the mean and variance of Exp(1), both
-- equal to 1, within three standard errors.
testSamples :: [Double] -> SpecWith ()
testSamples xs = describe "sampled trace over exp(1)" $ do
  let meanStdErr = stdErr xs
      -- Squared deviations from the true mean 1; their standard error
      -- bounds the variance estimate.
      varStdErr = stdErr (fmap (\x -> (x - 1) ** 2.0) xs)

  context "within three standard errors" $ do
    it "has the expected mean" $ do
      mean xs `shouldSatisfy` (< 1 + 3 * meanStdErr)
      mean xs `shouldSatisfy` (> 1 - 3 * meanStdErr)

    it "has the expected variance" $ do
      variance xs `shouldSatisfy` (< 1 + 3 * varStdErr)
      variance xs `shouldSatisfy` (> 1 - 3 * varStdErr)

-- | Check that a trace of @x^3@ values has mean @E[X^3] = 6@ under Exp(1),
-- within three standard errors.
testTunables :: [Double] -> SpecWith ()
testTunables ts = describe "sampled tunables over exp(1)" $ do
  let meanStdErr = stdErr ts

  context "within three standard errors" $
    it "has the expected third moment" $ do
      mean ts `shouldSatisfy` (< 6 + 3 * meanStdErr)
      mean ts `shouldSatisfy` (> 6 - 3 * meanStdErr)

-- | Produce both traces up front, then run the full spec.
main :: IO ()
main = do
  xs <- vanillaTrace
  ts <- tunedTrace

  hspec $ do
    testHelperFunctions
    testSamples xs
    testTunables ts