commit d9dec7542e2635e4b24c7ad6a31d35120933703c
parent fad0d48adf4977443f82f49f9ab2ded6a9888839
Author: Alex Zarebski <aezarebski@gmail.com>
Date: Thu, 14 May 2020 15:40:21 +0100
implement handling of tunables
Diffstat:
2 files changed, 65 insertions(+), 27 deletions(-)
diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs
@@ -22,6 +22,7 @@
module Numeric.MCMC.Metropolis (
mcmc
, chain
+ , chain'
, metropolis
-- * Re-exported
@@ -61,29 +62,55 @@ propose radial = traverse perturb where
metropolis
:: (Traversable f, PrimMonad m)
=> Double
+ -> Maybe (f Double -> b)
-> Transition m (Chain (f Double) b)
-metropolis radial = do
+metropolis radial tunable = do
Chain {..} <- get
proposal <- lift (propose radial chainPosition)
let proposalScore = lTarget chainTarget proposal
acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
accept <- lift (MWC.bernoulli acceptProb)
- when accept (put (Chain chainTarget proposalScore proposal chainTunables))
+ when accept (put (Chain chainTarget proposalScore proposal (tunable <*> Just proposal)))
-- Drive a Markov chain via the Metropolis transition operator.
drive
:: (Traversable f, PrimMonad m)
=> Double
+ -> Maybe (f Double -> b)
-> Chain (f Double) b
-> Gen (PrimState m)
-> Producer (Chain (f Double) b) m c
-drive radial = loop where
+drive radial tunable = loop where
loop state prng = do
- next <- lift (MWC.sample (execStateT (metropolis radial) state) prng)
+ next <- lift (MWC.sample (execStateT (metropolis radial tunable) state) prng)
yield next
loop next prng
+-- | Return a list of @Chain@ values potentially with tunable values.
+chain' ::
+ (PrimMonad m, Traversable f)
+ => Int
+ -> Double
+ -> f Double
+ -> (f Double -> Double)
+ -> Maybe (f Double -> b)
+ -> Gen (PrimState m)
+ -> m [Chain (f Double) b]
+chain' n radial position target tunable gen =
+ runEffect $ drive radial tunable origin gen >-> collect n
+ where
+ ctarget = Target target Nothing
+ origin =
+ Chain
+ { chainScore = lTarget ctarget position
+ , chainTunables = tunable <*> Just position
+ , chainTarget = ctarget
+ , chainPosition = position
+ }
+ collect :: Monad m => Int -> Consumer a m [a]
+ collect size = lowerCodensity $ replicateM size (lift Pipes.await)
+
-- | Trace 'n' iterations of a Markov chain and collect the results in a list.
--
-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
@@ -94,28 +121,14 @@ drive radial = loop where
-- 3.8379699517007895e-3,0.24627131099479127
chain
:: (PrimMonad m, Traversable f)
- => Int
- -> Double
- -> f Double
- -> (f Double -> Double)
+ => Int -- ^ Number of iterations
+ -> Double -- ^ Step standard deviation
+ -> f Double -- ^ Starting position
+ -> (f Double -> Double) -- ^ The log-density (up to additive constant)
-> Gen (PrimState m)
-> m [Chain (f Double) b]
-chain n radial position target gen = runEffect $
- drive radial origin gen
- >-> collect n
- where
- ctarget = Target target Nothing
-
- origin = Chain {
- chainScore = lTarget ctarget position
- , chainTunables = Nothing
- , chainTarget = ctarget
- , chainPosition = position
- }
-
- collect :: Monad m => Int -> Consumer a m [a]
- collect size = lowerCodensity $
- replicateM size (lift Pipes.await)
+chain n radial position target =
+ chain' n radial position target Nothing
-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--
@@ -133,7 +146,7 @@ mcmc
-> Gen (PrimState m)
-> m ()
mcmc n radial chainPosition target gen = runEffect $
- drive radial Chain {..} gen
+ drive radial Nothing Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
diff --git a/test/test/Spec.hs b/test/test/Spec.hs
@@ -1,6 +1,7 @@
import Test.Hspec
import Data.Sampling.Types
-import Numeric.MCMC.Metropolis (chain)
+import Data.Maybe (fromJust)
+import Numeric.MCMC.Metropolis (chain,chain')
import System.Random.MWC
withinPercent :: Double -> Double -> Double -> Bool
@@ -70,7 +71,7 @@ getChainResults =
testChainResults :: [Double] -> SpecWith ()
testChainResults xs =
- describe "Testing chain on samples from exponential distribution with rate 1" $ do
+ describe "Testing samples from exponential distribution with rate 1" $ do
it "test mean is estimated correctly" $ do
mean xs < 1 + 2 * stdErr xs `shouldBe` True
mean xs > 1 - 2 * stdErr xs `shouldBe` True
@@ -78,9 +79,33 @@ testChainResults xs =
variance xs < 1 + 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True
variance xs > 1 - 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True
+getTunableResults :: IO [Double]
+getTunableResults =
+ let numIters = 1000000
+ radialSize = 0.2
+ x0 = [1.0]
+ lnObj [x] =
+ if x > 0
+ then -x
+ else -1 / 0
+ tunable [x] = x ** 3.0
+ thinning = 1000
+ in do boxedXs <-
+ withSystemRandom . asGenIO $ chain' numIters radialSize x0 lnObj (Just tunable)
+ return $ thin thinning $ fromJust . chainTunables <$> boxedXs
+
+testTunableResults :: [Double] -> SpecWith ()
+testTunableResults ts =
+ describe "Testing third moment of exponential distribution with rate 1" $ do
+ it "test third moment (which is 6) is estimated correctly" $ do
+ mean ts < 6 + 2 * stdErr ts `shouldBe` True
+ mean ts > 6 - 2 * stdErr ts `shouldBe` True
+
main :: IO ()
main = do
xs <- getChainResults
+ ts <- getTunableResults
hspec $ do
testHelperFunctions
testChainResults xs
+ testTunableResults ts