Metropolis.hs (5360B)
1 {-# OPTIONS_GHC -Wall #-} 2 {-# LANGUAGE CPP #-} 3 {-# LANGUAGE RecordWildCards #-} 4 5 -- | 6 -- Module: Numeric.MCMC.Metropolis 7 -- Copyright: (c) 2015 Jared Tobin 8 -- License: MIT 9 -- 10 -- Maintainer: Jared Tobin <jared@jtobin.ca> 11 -- Stability: unstable 12 -- Portability: ghc 13 -- 14 -- This implementation uses spherical Gaussian proposals to implement a 15 -- reliable and computationally inexpensive sampling routine. It can be used 16 -- as a baseline from which to benchmark other algorithms for a given problem. 17 -- 18 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere, 19 -- while the `metropolis` transition can be used for more flexible purposes, 20 -- such as working with samples in memory. 21 22 module Numeric.MCMC.Metropolis ( 23 mcmc 24 , chain 25 , chain' 26 , metropolis 27 28 -- * Re-exported 29 , module Data.Sampling.Types 30 , MWC.create 31 , MWC.createSystemRandom 32 , MWC.withSystemRandom 33 , MWC.asGenIO 34 ) where 35 36 import Control.Monad (when, replicateM) 37 import Control.Monad.Codensity (lowerCodensity) 38 import Control.Monad.Primitive (PrimMonad, PrimState) 39 import Control.Monad.Trans.Class (lift) 40 import Control.Monad.Trans.State.Strict (execStateT, get, put) 41 import Control.Monad.IO.Class (MonadIO, liftIO) 42 import Data.Sampling.Types (Target(..), Chain(..), Transition) 43 #if __GLASGOW_HASKELL__ < 710 44 import Data.Traversable (Traversable, traverse) 45 #endif 46 import Pipes (Producer, Consumer, yield, (>->), runEffect, await) 47 import qualified Pipes.Prelude as Pipes (mapM_, take) 48 import System.Random.MWC.Probability (Gen, Prob) 49 import qualified System.Random.MWC.Probability as MWC 50 51 -- Propose a state transition according to a Gaussian proposal distribution 52 -- with the specified standard deviation. 
-- Propose a new position by perturbing every coordinate independently with
-- Gaussian noise centred on its current value, using the given standard
-- deviation.
propose
  :: (PrimMonad m, Traversable f)
  => Double
  -> f Double
  -> Prob m (f Double)
propose radial = traverse (`MWC.normal` radial)

-- | A generic Metropolis transition operator.
--
-- A candidate position is drawn with spherical Gaussian noise around the
-- current one. It is accepted with probability
-- @min 1 (exp (candidateScore - currentScore))@, treated as zero when that
-- quantity is NaN. On acceptance the state is replaced by the candidate, its
-- score, and (when a tuning function is supplied) tunables recomputed at the
-- candidate. On rejection the state is left unchanged.
metropolis
  :: (Traversable f, PrimMonad m)
  => Double                 -- ^ Proposal standard deviation
  -> Maybe (f Double -> b)  -- ^ Optional function computing tunables from a position
  -> Transition m (Chain (f Double) b)
metropolis radial tunable = do
  current <- get
  let target = chainTarget current
  candidate <- lift (propose radial (chainPosition current))
  let candidateScore = lTarget target candidate
      -- Clamp at 0 so the probability never exceeds 1 (log scale).
      logRatio       = min 0 (candidateScore - chainScore current)
  accepted <- lift (MWC.bernoulli (whenNaN 0 (exp logRatio)))
  when accepted $
    put $ Chain
      { chainTarget   = target
      , chainScore    = candidateScore
      , chainPosition = candidate
      , chainTunables = fmap ($ candidate) tunable
      }

-- Drive a Markov chain via the Metropolis transition operator, yielding each
-- successive state forever. The starting state itself is never yielded.
drive
  :: (Traversable f, PrimMonad m)
  => Double
  -> Maybe (f Double -> b)
  -> Chain (f Double) b
  -> Gen (PrimState m)
  -> Producer (Chain (f Double) b) m c
drive radial tunable start prng = go start where
  transition = metropolis radial tunable
  go current = do
    successor <- lift $ MWC.sample (execStateT transition current) prng
    yield successor
    go successor
-- Assemble the starting state of a chain. The log-density is wrapped as a
-- 'Target', scored at the starting position, and the tunables are computed
-- there when a tuning function is supplied ('Nothing' otherwise). Shared by
-- 'chain'' and 'mcmc' so both start from an identically built state.
initialChain
  :: (f Double -> Double)
  -> Maybe (f Double -> b)
  -> f Double
  -> Chain (f Double) b
initialChain target tunable position = Chain
  { chainTarget   = ctarget
  , chainScore    = lTarget ctarget position
  , chainPosition = position
  , chainTunables = tunable <*> Just position
  }
  where
    ctarget = Target target Nothing

-- | Return a list of @Chain@ values potentially with tunable values computed
-- from each position.
--
-- The list holds the @n@ states that follow the starting state. The starting
-- state itself is not included, and a rejected proposal repeats the previous
-- state.
chain'
  :: (PrimMonad m, Traversable f)
  => Int                    -- ^ Number of iterations
  -> Double                 -- ^ Step standard deviation
  -> f Double               -- ^ Starting position
  -> (f Double -> Double)   -- ^ Log-density (to additive constant)
  -> Maybe (f Double -> b)  -- ^ Optional function computing tunables from a position
  -> Gen (PrimState m)      -- ^ PRNG
  -> m [Chain (f Double) b]
chain' n radial position target tunable gen =
    runEffect $ drive radial tunable origin gen >-> collect n
  where
    origin = initialChain target tunable position

    -- Await exactly @size@ values from upstream. The replicateM runs in
    -- 'Codensity', presumably to keep the Proxy binds right-associated for
    -- performance. NOTE(review): confirm the motivation.
    collect :: Monad m => Int -> Consumer a m [a]
    collect size = lowerCodensity $ replicateM size (lift Pipes.await)

-- | Trace @n@ iterations of a Markov chain and collect the results in a list.
--
-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock
-- >>> mapM_ print results
-- 0.0,0.0
-- 1.4754117657794871e-2,0.5033208261760778
-- 3.8379699517007895e-3,0.24627131099479127
chain
  :: (PrimMonad m, Traversable f)
  => Int                   -- ^ Number of iterations
  -> Double                -- ^ Step standard deviation
  -> f Double              -- ^ Starting position
  -> (f Double -> Double)  -- ^ Log-density (to additive constant)
  -> Gen (PrimState m)     -- ^ PRNG
  -> m [Chain (f Double) b]
chain n radial position target = chain' n radial position target Nothing

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
-- 0.5000462419822702,0.5693944056267897
-- 0.5000462419822702,0.5693944056267897
-- -0.7525995304580824,1.2240725505283248
mcmc
  :: (MonadIO m, PrimMonad m, Traversable f, Show (f Double))
  => Int                   -- ^ Number of iterations
  -> Double                -- ^ Step standard deviation
  -> f Double              -- ^ Starting position
  -> (f Double -> Double)  -- ^ Log-density (to additive constant)
  -> Gen (PrimState m)     -- ^ PRNG
  -> m ()
mcmc n radial position target gen = runEffect $
        drive radial Nothing origin gen
    >-> Pipes.take n
    >-> Pipes.mapM_ (liftIO . print)
  where
    -- Built via the shared helper rather than a RecordWildCards construction.
    -- Binding locals named after the imported record selectors would shadow
    -- them and trigger -Wname-shadowing under this module's -Wall. With no
    -- tuning function the tunables are always 'Nothing'.
    origin = initialChain target Nothing position

-- Use a provided default value when the argument is NaN.
whenNaN :: RealFloat a => a -> a -> a
whenNaN val x
  | isNaN x   = val
  | otherwise = x