mighty-metropolis

The classic Metropolis algorithm.
Log | Files | Refs | README | LICENSE

commit d9dec7542e2635e4b24c7ad6a31d35120933703c
parent fad0d48adf4977443f82f49f9ab2ded6a9888839
Author: Alex Zarebski <aezarebski@gmail.com>
Date:   Thu, 14 May 2020 15:40:21 +0100

implement handling of tunables

Diffstat:
MNumeric/MCMC/Metropolis.hs | 63++++++++++++++++++++++++++++++++++++++-------------------------
Mtest/test/Spec.hs | 29+++++++++++++++++++++++++++--
2 files changed, 65 insertions(+), 27 deletions(-)

diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs @@ -22,6 +22,7 @@ module Numeric.MCMC.Metropolis ( mcmc , chain + , chain' , metropolis -- * Re-exported @@ -61,29 +62,55 @@ propose radial = traverse perturb where metropolis :: (Traversable f, PrimMonad m) => Double + -> Maybe (f Double -> b) -> Transition m (Chain (f Double) b) -metropolis radial = do +metropolis radial tunable = do Chain {..} <- get proposal <- lift (propose radial chainPosition) let proposalScore = lTarget chainTarget proposal acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore))) accept <- lift (MWC.bernoulli acceptProb) - when accept (put (Chain chainTarget proposalScore proposal chainTunables)) + when accept (put (Chain chainTarget proposalScore proposal (tunable <*> Just proposal))) -- Drive a Markov chain via the Metropolis transition operator. drive :: (Traversable f, PrimMonad m) => Double + -> Maybe (f Double -> b) -> Chain (f Double) b -> Gen (PrimState m) -> Producer (Chain (f Double) b) m c -drive radial = loop where +drive radial tunable = loop where loop state prng = do - next <- lift (MWC.sample (execStateT (metropolis radial) state) prng) + next <- lift (MWC.sample (execStateT (metropolis radial tunable) state) prng) yield next loop next prng +-- | Return a list of @Chain@ values potentially with tunable values. +chain' :: + (PrimMonad m, Traversable f) + => Int + -> Double + -> f Double + -> (f Double -> Double) + -> Maybe (f Double -> b) + -> Gen (PrimState m) + -> m [Chain (f Double) b] +chain' n radial position target tunable gen = + runEffect $ drive radial tunable origin gen >-> collect n + where + ctarget = Target target Nothing + origin = + Chain + { chainScore = lTarget ctarget position + , chainTunables = tunable <*> Just position + , chainTarget = ctarget + , chainPosition = position + } + collect :: Monad m => Int -> Consumer a m [a] + collect size = lowerCodensity $ replicateM size (lift Pipes.await) + -- | Trace 'n' iterations of a Markov chain and collect the results in a list. -- -- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) @@ -94,28 +121,14 @@ drive radial = loop where -- 3.8379699517007895e-3,0.24627131099479127 chain :: (PrimMonad m, Traversable f) - => Int - -> Double - -> f Double - -> (f Double -> Double) + => Int -- ^ Number of iterations + -> Double -- ^ Step standard deviation + -> f Double -- ^ Starting position + -> (f Double -> Double) -- ^ The log-density (up to additive constant) -> Gen (PrimState m) -> m [Chain (f Double) b] -chain n radial position target gen = runEffect $ - drive radial origin gen - >-> collect n - where - ctarget = Target target Nothing - - origin = Chain { - chainScore = lTarget ctarget position - , chainTunables = Nothing - , chainTarget = ctarget - , chainPosition = position - } - - collect :: Monad m => Int -> Consumer a m [a] - collect size = lowerCodensity $ - replicateM size (lift Pipes.await) +chain n radial position target = + chain' n radial position target Nothing -- | Trace 'n' iterations of a Markov chain and stream them to stdout. -- @@ -133,7 +146,7 @@ mcmc -> Gen (PrimState m) -> m () mcmc n radial chainPosition target gen = runEffect $ - drive radial Chain {..} gen + drive radial Nothing Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ (liftIO . print) where diff --git a/test/test/Spec.hs b/test/test/Spec.hs @@ -1,6 +1,7 @@ import Test.Hspec import Data.Sampling.Types -import Numeric.MCMC.Metropolis (chain) +import Data.Maybe (fromJust) +import Numeric.MCMC.Metropolis (chain,chain') import System.Random.MWC withinPercent :: Double -> Double -> Double -> Bool @@ -70,7 +71,7 @@ getChainResults = testChainResults :: [Double] -> SpecWith () testChainResults xs = - describe "Testing chain on samples from exponential distribution with rate 1" $ do + describe "Testing samples from exponential distribution with rate 1" $ do it "test mean is estimated correctly" $ do mean xs < 1 + 2 * stdErr xs `shouldBe` True mean xs > 1 - 2 * stdErr xs `shouldBe` True @@ -78,9 +79,33 @@ testChainResults xs = variance xs < 1 + 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True variance xs > 1 - 2 * stdErr [(x - 1) ** 2.0 | x <- xs] `shouldBe` True +getTunableResults :: IO [Double] +getTunableResults = + let numIters = 1000000 + radialSize = 0.2 + x0 = [1.0] + lnObj [x] = + if x > 0 + then -x + else -1 / 0 + tunable [x] = x ** 3.0 + thinning = 1000 + in do boxedXs <- + withSystemRandom . asGenIO $ chain' numIters radialSize x0 lnObj (Just tunable) + return $ thin thinning $ fromJust . chainTunables <$> boxedXs + +testTunableResults :: [Double] -> SpecWith () +testTunableResults ts = + describe "Testing third moment of exponential distribution with rate 1" $ do + it "test third moment (which is 6) is estimated correctly" $ do + mean ts < 6 + 2 * stdErr ts `shouldBe` True + mean ts > 6 - 2 * stdErr ts `shouldBe` True + main :: IO () main = do xs <- getChainResults + ts <- getTunableResults hspec $ do testHelperFunctions testChainResults xs + testTunableResults ts