commit fad0d48adf4977443f82f49f9ab2ded6a9888839
parent 744b88b537e2fe27aa60224f61518c46183d9fa6
Author: Alex Zarebski <aezarebski@gmail.com>
Date: Thu, 14 May 2020 14:52:12 +0100
include basic testing of chain function
Diffstat:
2 files changed, 40 insertions(+), 4 deletions(-)
diff --git a/mighty-metropolis.cabal b/mighty-metropolis.cabal
@@ -83,4 +83,6 @@ Test-suite tests
, containers >= 0.5 && < 0.6
, mighty-metropolis
, mwc-probability >= 1.0.1
- , hspec
-\ No newline at end of file
+ , hspec
+ , mwc-random
+ , mcmc-types
+\ No newline at end of file
diff --git a/test/test/Spec.hs b/test/test/Spec.hs
@@ -1,4 +1,7 @@
import Test.Hspec
+import Data.Sampling.Types
+import Numeric.MCMC.Metropolis (chain)
+import System.Random.MWC
withinPercent :: Double -> Double -> Double -> Bool
withinPercent a 0 _ = a == 0
@@ -47,6 +50,37 @@ testHelperFunctions = describe "Testing helper functions" $ do
withinPercent (stdErr [1..100]) 2.901149 1e-3 `shouldBe` True
withinPercent (stdErr [1..1000]) 9.133273 1e-3 `shouldBe` True
+thin :: Int -> [a] -> [a]
+thin _ [] = []
+thin n (x:xs) = x : thin n (drop (n - 1) xs)
+
+getChainResults :: IO [Double]
+getChainResults =
+ let numIters = 1000000
+ radialSize = 0.2
+ x0 = [1.0]
+ lnObj [x] =
+ if x > 0
+ then -x
+ else -1 / 0
+ thinning = 1000
+ in do boxedXs <-
+ withSystemRandom . asGenIO $ chain numIters radialSize x0 lnObj
+ return $ thin thinning $ head . chainPosition <$> boxedXs
+
+testChainResults :: [Double] -> SpecWith ()
+testChainResults xs =
+ describe "Testing chain on samples from exponential distribution with rate 1" $ do
+ it "test mean is estimated correctly" $ do
+ mean xs < 1 + 2 * stdErr xs `shouldBe` True
+ mean xs > 1 - 2 * stdErr xs `shouldBe` True
+ it "test variance is estimated correctly" $ do
+ variance xs < 1 + 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True
+ variance xs > 1 - 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True
+
main :: IO ()
-main = hspec $ do
- testHelperFunctions
+main = do
+ xs <- getChainResults
+ hspec $ do
+ testHelperFunctions
+ testChainResults xs