MCMC.hs (6529B)
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module: Numeric.MCMC
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This module presents a simple combinator language for Markov transition
-- operators that are useful in MCMC.
--
-- Any transition operators sharing the same stationary distribution and
-- obeying the Markov and reversibility properties can be combined in a couple
-- of ways, such that the resulting operator preserves the stationary
-- distribution and desirable properties amenable to MCMC.
--
-- We can deterministically concatenate operators end-to-end, or sample from
-- a collection of them according to some probability distribution. See
-- <http://www.stat.umn.edu/geyer/f05/8931/n1998.pdf Geyer, 2005> for details.
--
-- The result is a simple grammar for building composite, property-preserving
-- transition operators from existing ones:
--
-- > transition ::= primitive <transition>
-- >              | concatT transition transition
-- >              | sampleT transition transition
--
-- In addition to the above, this module provides a number of combinators for
-- building composite transition operators. It re-exports a number of
-- production-quality transition operators from the /mighty-metropolis/,
-- /speedy-slice/, and /hasty-hamiltonian/ libraries.
--
-- Markov chains can then be run over arbitrary 'Target's using whatever
-- transition operator is desired.
--
-- > import Numeric.MCMC
-- > import Data.Sampling.Types
-- >
-- > target :: [Double] -> Double
-- > target [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >
-- > rosenbrock :: Target [Double]
-- > rosenbrock = Target target Nothing
-- >
-- > transition :: Transition IO (Chain [Double] b)
-- > transition =
-- >   concatT
-- >     (sampleT (metropolis 0.5) (metropolis 1.0))
-- >     (sampleT (slice 2.0) (slice 3.0))
-- >
-- > main :: IO ()
-- > main = withSystemRandom . asGenIO $ mcmc 10000 [0, 0] transition rosenbrock
--
-- See the attached test suite for other examples.

module Numeric.MCMC (
    concatT
  , concatAllT
  , sampleT
  , sampleAllT
  , bernoulliT
  , frequency
  , anneal
  , mcmc
  , chain

  -- * Re-exported
  , module Data.Sampling.Types

  , metropolis
  , hamiltonian
  , slice

  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO

  , PrimMonad
  , PrimState
  , RealWorld
  ) where

import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
import Control.Monad.Trans.State.Strict (execStateT)
import Data.Sampling.Types
import Numeric.MCMC.Anneal
import qualified Numeric.MCMC.Metropolis as M (metropolis)
import Numeric.MCMC.Hamiltonian (hamiltonian)
import Numeric.MCMC.Slice (slice)
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Gen)
import qualified System.Random.MWC.Probability as MWC

-- | Deterministically concat transition operators together.
--
-- The first operator is run, and the second is then run starting from the
-- state the first one left behind (this is just '>>' on transitions).
concatT :: Monad m => Transition m a -> Transition m a -> Transition m a
concatT = (>>)

-- | Deterministically concat a list of transition operators together.
--
-- The operators are run left to right, each one starting from the state left
-- by the previous one. An empty list yields the identity transition (the
-- chain state is left unchanged), which is the unit of 'concatT', rather than
-- failing at runtime.
concatAllT :: Monad m => [Transition m a] -> Transition m a
concatAllT = sequence_

-- | Probabilistically concat transition operators together.
--
-- Each application runs exactly one of the two operators, each chosen with
-- probability 0.5.
sampleT :: PrimMonad m => Transition m a -> Transition m a -> Transition m a
sampleT = bernoulliT 0.5

-- | Probabilistically concat transition operators together using a Bernoulli
-- distribution with the supplied success probability.
--
-- On success the first operator is run, otherwise the second one is.
--
-- This is just a generalization of 'sampleT'.
bernoulliT
  :: PrimMonad m
  => Double
  -> Transition m a
  -> Transition m a
  -> Transition m a
bernoulliT p t0 t1 = do
  heads <- lift (MWC.bernoulli p)
  if heads then t0 else t1

-- | Probabilistically concat transition operators together via a uniform
-- distribution.
--
-- Each application runs exactly one operator from the list, chosen uniformly
-- at random. The list must be non-empty. An empty list is reported with an
-- explicit error, instead of failing obscurely inside the random index
-- computation (@uniformR (0, -1)@ followed by '!!').
sampleAllT :: PrimMonad m => [Transition m a] -> Transition m a
sampleAllT [] = error "sampleAllT: no transition operators specified"
sampleAllT ts = do
  j <- lift (MWC.uniformR (0, length ts - 1))
  -- j is drawn from [0, length ts - 1], so this index is always in range.
  ts !! j

-- | Probabilistically concat transition operators together using the supplied
-- frequency distribution.
--
-- Each operator is chosen with probability proportional to its weight.
-- Weights must be non-negative and at least one must be positive. Otherwise
-- an error is raised, because the weighted draw below would be made from an
-- empty range and silently pick the wrong operator (or crash).
--
-- This function is more-or-less an exact copy of 'Test.QuickCheck.frequency',
-- except here applied to transition operators.
frequency :: PrimMonad m => [(Int, Transition m a)] -> Transition m a
frequency xs
  | null xs              = error "frequency: no distribution specified"
  | any ((< 0) . fst) xs = error "frequency: negative weight"
  | tot <= 0             = error "frequency: all weights were zero"
  | otherwise            = lift (MWC.uniformR (1, tot)) >>= (`pick` xs)
  where
    tot = sum (map fst xs)

    -- Walk the list, subtracting each weight until the draw falls inside
    -- the current operator's weight band.
    pick n ((k, v):vs)
      | n <= k    = v
      | otherwise = pick (n - k) vs
    pick _ _ = error "frequency: no distribution specified"

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- Each iteration's chain state is printed with 'print' as soon as it is
-- produced. The output below is illustrative only; it depends on the seed.
--
-- >>> withSystemRandom . asGenIO $ mcmc 3 [0, 0] (metropolis 0.5) rosenbrock
-- -0.48939312153007863,0.13290702689491818
-- 1.4541485365128892e-2,-0.4859905564050404
-- 0.22487398491619448,-0.29769783186855125
mcmc
  :: (MonadIO m, PrimMonad m, Show (t a))
  => Int
  -> t a
  -> Transition m (Chain (t a) b)
  -> Target (t a)
  -> Gen (PrimState m)
  -> m ()
mcmc n chainPosition transition chainTarget gen = runEffect $
      drive transition Chain {..} gen
  >-> Pipes.take n
  >-> Pipes.mapM_ (liftIO . print)
  where
    -- Remaining fields of the initial 'Chain' picked up by 'Chain {..}'.
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing

-- | Trace @n@ iterations of a Markov chain and collect them in a list.
--
-- Unlike 'mcmc' this performs no IO of its own, so it only requires a
-- 'PrimMonad' (e.g. it can also be run in @ST s@, not just 'IO').
--
-- >>> results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock
chain
  :: PrimMonad m
  => Int
  -> t a
  -> Transition m (Chain (t a) b)
  -> Target (t a)
  -> Gen (PrimState m)
  -> m [Chain (t a) b]
chain n chainPosition transition chainTarget gen = runEffect $
      drive transition Chain {..} gen
  >-> collect n
  where
    -- Remaining fields of the initial 'Chain' picked up by 'Chain {..}'.
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing

-- Await exactly @size@ values from upstream and return them in arrival order.
-- The 'replicateM' is run through 'Codensity', presumably to avoid
-- left-nested binds in the 'Proxy' monad when @size@ is large.
collect :: Monad m => Int -> Consumer a m [a]
collect size = lowerCodensity $
  replicateM size (lift Pipes.await)

-- A Markov chain driven by an arbitrary transition operator.
--
-- Starting from the given state, repeatedly run the transition with the
-- supplied generator and yield each resulting state. The producer never
-- terminates on its own; downstream (e.g. 'Pipes.take' or 'collect') must
-- bound it. The initial state itself is not yielded.
drive
  :: PrimMonad m
  => Transition m b
  -> b
  -> Gen (PrimState m)
  -> Producer b m a
drive transition = loop where
  loop state prng = do
    next <- lift (MWC.sample (execStateT transition state) prng)
    yield next
    loop next prng

-- | A generic Metropolis transition operator.
--
-- Wraps the Metropolis operator from "Numeric.MCMC.Metropolis" with the given
-- proposal radius, passing 'Nothing' for its optional second argument
-- (NOTE(review): presumably "no tuning"; confirm against mighty-metropolis).
metropolis
  :: (Traversable f, PrimMonad m)
  => Double
  -> Transition m (Chain (f Double) b)
metropolis radial = M.metropolis radial Nothing