Slice.hs (6246B)
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module: Numeric.MCMC.Slice
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This implementation performs slice sampling one coordinate at a time. For
-- each coordinate it first finds a bracket around the slice containing the
-- current point (by repeatedly expanding the bracket's bounds outward), and
-- then does rejection sampling within that bracket. The result is a reliable
-- and computationally inexpensive sampling routine.
--
-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
-- and 'chain' collects a trace into a list in memory, while the 'slice'
-- transition can be used for more flexible purposes.
--
-- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the
-- definitive reference of the algorithm.

module Numeric.MCMC.Slice (
    mcmc
  , chain
  , slice

  -- * Re-exported
  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO
  ) where

import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Trans.State.Strict (put, get, execStateT)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Lens hiding (index)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Pipes hiding (next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen, Variate)
import qualified System.Random.MWC.Probability as MWC

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- Each iteration is one full 'slice' sweep over every coordinate of the
-- position, and each resulting chain state is written with 'print'.
--
-- Example (the generator is seeded by 'withSystemRandom', so the exact
-- output differs between runs):
--
-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
-- -3.854097694213343e-2,0.16688601288358407
-- -9.310661272172682e-2,0.2562387977415508
-- -0.48500122500661846,0.46245400501919076
mcmc
  :: (MonadIO m, PrimMonad m,
      Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => Int                -- ^ Number of iterations to print.
  -> IxValue (t a)      -- ^ Initial step used when bracketing each slice.
  -> t a                -- ^ Starting position.
  -> (t a -> Double)    -- ^ Log density of the target.
  -> Gen (PrimState m)  -- ^ Pseudo-random number generator.
  -> m ()
mcmc n radial chainPosition target gen = runEffect $
        drive radial Chain {..} gen
    >-> Pipes.take n
    >-> Pipes.mapM_ (liftIO . print)
  where
    chainScore    = lTarget chainTarget chainPosition
    chainTunables = Nothing
    chainTarget   = Target target Nothing

-- | Trace @n@ iterations of a Markov chain and collect them in a list.
--
-- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock
chain
  :: (PrimMonad m, FoldableWithIndex (Index (f a)) f, Ixed (f a)
     , Variate (IxValue (f a)), Num (IxValue (f a)))
  => Int                -- ^ Number of iterations to collect.
  -> IxValue (f a)      -- ^ Initial step used when bracketing each slice.
  -> f a                -- ^ Starting position.
  -> (f a -> Double)    -- ^ Log density of the target.
  -> Gen (PrimState m)  -- ^ Pseudo-random number generator.
  -> m [Chain (f a) b]
chain n radial position target gen = runEffect $
        drive radial origin gen
    >-> collect n
  where
    ctarget = Target target Nothing

    -- The starting state: the given position, scored under the target.
    origin = Chain {
        chainScore    = lTarget ctarget position
      , chainTunables = Nothing
      , chainTarget   = ctarget
      , chainPosition = position
      }

-- Await exactly @size@ values from upstream and return them in arrival
-- order. The 'replicateM' runs inside 'Codensity' and is lowered back to a
-- 'Consumer'; NOTE(review): presumably this is to avoid the cost of
-- left-nested binds from 'replicateM' directly in 'Proxy' -- confirm.
collect :: Monad m => Int -> Consumer a m [a]
collect size = lowerCodensity $
  replicateM size (lift Pipes.await)

-- A Markov chain driven by the slice transition operator.
drive
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)      -- ^ Initial step used when bracketing each slice.
  -> Chain (t a) b      -- ^ Starting chain state.
  -> Gen (PrimState m)  -- ^ Generator shared by every iteration.
  -> Producer (Chain (t a) b) m c
drive radial = loop where
  -- Run one full 'slice' sweep from the current state, yield the resulting
  -- state downstream, and continue from it. This producer never terminates
  -- on its own; downstream consumers decide how many states to take.
  loop state prng = do
    next <- lift (MWC.sample (execStateT (slice radial) state) prng)
    yield next
    loop next prng

-- | A slice sampling transition operator.
--
-- Performs one sweep over the coordinates of the current position (in the
-- order visited by 'ifor_'). For each coordinate it:
--
--   1. draws a slice height below the current log density,
--   2. finds a bracket around the current value with 'findBracket', and
--   3. draws a new value for that coordinate inside the bracket with
--      'rejection',
--
-- then stores the new position and its score before moving on.
--
-- The height is drawn directly on the log scale as @score + log u@, with
-- @u@ drawn by @uniformR (0, 1)@. This has the same distribution as
-- @log (uniformR (0, exp score))@ and consumes the same single draw. Unlike
-- that form, it does not overflow (@exp score = Infinity@) or underflow
-- (@exp score = 0@) for log densities of large magnitude. Either of those
-- yields an infinite height, which makes the bracket search or the
-- rejection step loop forever.
slice
  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
      Num (IxValue (t a)), Variate (IxValue (t a)))
  => IxValue (t a)
  -> Transition m (Chain (t a) b)
slice step = do
  Chain {chainPosition = position} <- get
  ifor_ position $ \index _ -> do
    Chain {..} <- get
    let density = lTarget chainTarget
        score   = density chainPosition
    u <- lift (MWC.uniformR (0, 1))
    let height  = score + log u
        bracket = findBracket density index step height chainPosition
    perturbed <- lift (rejection density index bracket height chainPosition)
    put Chain {
        chainTarget   = chainTarget
      , chainScore    = density perturbed
      , chainPosition = perturbed
      , chainTunables = chainTunables
      }

-- Find a bracket around the current value of coordinate @index@.
--
-- The search starts from the degenerate bracket [x, x]. Each round, every
-- bound that is still on the slice (target >= height) is pushed outward by
-- the current increment, and the increment then doubles: step, 2 * step,
-- 4 * step, ... The search stops once both bounds are off the slice and
-- returns the bounds' coordinate values as (left, right). Each bound's
-- target value is evaluated once per round.
findBracket
  :: (Ord a, Ixed s, Num (IxValue s))
  => (s -> a)
  -> Index s
  -> IxValue s
  -> a
  -> s
  -> (IxValue s, IxValue s)
findBracket target index step height xs = go step xs xs where
  err = error "findBracket: invalid index -- please report this as a bug!"
  coordinate ys = fromMaybe err (ys ^? ix index)
  go !e !bl !br = case (target bl < height, target br < height) of
    (True,  True ) -> (coordinate bl, coordinate br)
    (True,  False) -> go (2 * e) bl (expandBracketRight index e br)
    (False, True ) -> go (2 * e) (expandBracketLeft index e bl) br
    (False, False) ->
      go (2 * e) (expandBracketLeft index e bl) (expandBracketRight index e br)

-- Move coordinate @index@ down by the given amount.
expandBracketLeft
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketLeft = expandBracketBy (-)

-- Move coordinate @index@ up by the given amount.
expandBracketRight
  :: (Ixed s, Num (IxValue s))
  => Index s
  -> IxValue s
  -> s
  -> s
expandBracketRight = expandBracketBy (+)

-- Replace coordinate @index@ with @f value step@. The structure is returned
-- unchanged if @index@ is not present.
expandBracketBy
  :: Ixed s
  => (IxValue s -> t -> IxValue s)
  -> Index s
  -> t
  -> s
  -> s
expandBracketBy f index step xs = xs & ix index %~ (`f` step)

-- Perform rejection sampling within the supplied bracket.
--
-- Propose values for coordinate @dimension@ uniformly from @bracket@ until
-- the resulting point is on the slice (target >= height), then return that
-- point. Rejected proposals do not shrink the bracket.
rejection
  :: (Ord a, PrimMonad m, Ixed b, Variate (IxValue b))
  => (b -> a)
  -> Index b
  -> (IxValue b, IxValue b)
  -> a
  -> b
  -> Prob m b
rejection target dimension bracket height = go where
  go zs = do
    u <- MWC.uniformR bracket
    let updated = zs & ix dimension .~ u
    if target updated < height
      then go updated
      else return updated