Flat.hs (6941B)
1 {-# OPTIONS_GHC -fno-warn-type-defaults #-} 2 {-# LANGUAGE BangPatterns #-} 3 {-# LANGUAGE OverloadedStrings #-} 4 {-# LANGUAGE RecordWildCards #-} 5 6 -- | 7 -- Module: Numeric.MCMC.Flat 8 -- Copyright: (c) 2016 Jared Tobin 9 -- License: MIT 10 -- 11 -- Maintainer: Jared Tobin <jared@jtobin.ca> 12 -- Stability: unstable 13 -- Portability: ghc 14 -- 15 -- This is the 'affine invariant ensemble' or AIEMCMC algorithm described in 16 -- Goodman and Weare, 2010. It is a reasonably efficient and hassle-free 17 -- sampler, requiring no mucking with tuning parameters or local proposal 18 -- distributions. 19 -- 20 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere, 21 -- while the `flat` transition can be used for more flexible purposes, 22 -- such as working with samples in memory. 23 -- 24 -- See <http://msp.org/camcos/2010/5-1/camcos-v5-n1-p04-p.pdf> for the definitive 25 -- reference of the algorithm. 26 27 module Numeric.MCMC.Flat ( 28 mcmc 29 , flat 30 , Particle 31 , Ensemble 32 , Chain 33 34 , module Sampling.Types 35 , MWC.create 36 , MWC.createSystemRandom 37 , MWC.withSystemRandom 38 , MWC.asGenIO 39 40 , VE.ensemble 41 , VE.particle 42 ) where 43 44 import Control.Monad (replicateM) 45 import Control.Monad.IO.Class (MonadIO, liftIO) 46 import Control.Monad.Par (NFData) 47 import Control.Monad.Par.Combinator (parMap) 48 import Control.Monad.Par.Scheds.Sparks hiding (get) 49 import Control.Monad.Primitive (PrimMonad, PrimState) 50 import Control.Monad.Trans.State.Strict (get, put, execStateT) 51 import Data.Sampling.Types as Sampling.Types hiding (Chain(..)) 52 import qualified Data.Text as T 53 import qualified Data.Text.IO as T (putStrLn) 54 import Data.Vector (Vector) 55 import qualified Data.Vector as V 56 import qualified Data.Vector.Extended as VE (ensemble, particle) 57 import qualified Data.Vector.Unboxed as U 58 import Formatting ((%)) 59 import qualified Formatting as F 60 import Pipes (Producer, lift, yield, runEffect, (>->)) 
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability as MWC

-- | The state of a Markov chain over ensembles: the target being sampled
--   together with the current ensemble of particles.
data Chain = Chain {
    chainTarget   :: Target Particle
  , chainPosition :: !Ensemble
  }

-- | Render a Chain as a text value: one particle per line, each particle
--   written as its comma-separated coordinates.
render :: Chain -> T.Text
render Chain {..} = renderEnsemble chainPosition
{-# INLINE render #-}

-- Render a particle as its coordinates separated by commas.
--
-- Built with T.intercalate so the text is assembled in one pass instead of
-- re-copying a growing accumulator for every coordinate.
renderParticle :: Particle -> T.Text
renderParticle = T.intercalate "," . map (F.sformat F.float) . U.toList
{-# INLINE renderParticle #-}

-- Render an ensemble with one rendered particle per line (no trailing
-- newline). Like 'renderParticle', this joins the pieces in one pass.
renderEnsemble :: Ensemble -> T.Text
renderEnsemble = T.intercalate "\n" . map renderParticle . V.toList
{-# INLINE renderEnsemble #-}

-- | A particle is an n-dimensional point in Euclidean space.
--
-- You can create a particle by using the 'particle' helper function, or just
-- use Data.Vector.Unboxed.fromList.
type Particle = U.Vector Double

-- | An ensemble is a collection of particles.
--
-- The Markov chain we're interested in will run over the space of ensembles,
-- so you'll want to build an ensemble out of a reasonable number of
-- particles to kick off the chain.
--
-- You can create an ensemble by using the 'ensemble' helper function, or just
-- use Data.Vector.fromList.
type Ensemble = Vector Particle

-- Draw a stretch factor for the Goodman-Weare stretch move.
--
-- With u uniform on the unit interval, z = (u + 1)^2 / 2 lies in [1/2, 2]
-- with density proportional to 1 / sqrt z, i.e. the scale distribution
-- g(z) of Goodman & Weare with stretch parameter a = 2.
symmetric :: PrimMonad m => Prob m Double
symmetric = fmap transform uniform where
  transform z = 0.5 * (z + 1) ^ (2 :: Int)
{-# INLINE symmetric #-}

-- Stretch-move proposal, componentwise z * x + (1 - z) * y, i.e.
-- p1 + z * (p0 - p1): p0 is the walker being moved, p1 its partner.
stretch :: Particle -> Particle -> Double -> Particle
stretch p0 p1 z = U.zipWith str p0 p1 where
  str x y = z * x + (1 - z) * y
{-# INLINE stretch #-}

-- Log of the acceptance ratio for a stretch move:
-- (d - 1) * log z + log pi(proposal) - log pi(particle),
-- where d is the dimension of the particle.
acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double
acceptProb target particle proposal z =
    lTarget target proposal
  - lTarget target particle
  + log z * (fromIntegral (U.length particle) - 1)
{-# INLINE acceptProb #-}

-- Move a single walker p0 towards/away from partner p1 by stretch factor z,
-- accepting the proposal when the uniform variate zc does not exceed the
-- acceptance probability; otherwise p0 is kept.
move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle
move target !p0 p1 z zc =
  let !proposal = stretch p0 p1 z
      pAccept   = acceptProb target p0 proposal z
  in  if zc <= min 1 (exp pAccept)
        then proposal
        else p0
{-# INLINE move #-}

-- Update the first n walkers of e0, pairing each with a partner drawn
-- uniformly (with replacement) from the complementary ensemble e1.
--
-- All randomness is drawn up front in 'Prob', so the parallel section is
-- pure. e1 may have a different length from n (odd-sized ensembles); the
-- partner indices range over all of e1. If e1 is empty there is nothing to
-- stretch towards, so the first n walkers of e0 are returned unchanged.
-- Assumes n <= V.length e0.
execute
  :: PrimMonad m
  => Target Particle
  -> Ensemble
  -> Ensemble
  -> Int
  -> Prob m Ensemble
execute target e0 e1 n
  | V.null e1 = return (V.take n e0)
  | otherwise = do
      zs  <- replicateM n symmetric
      zcs <- replicateM n uniform
      -- 1-based indices of each walker's partner in e1.
      js  <- U.replicateM n (uniformR (1, V.length e1))

      -- Work is split into roughly two chunks, but never chunks of size 0:
      -- n `div` 2 is 0 for n = 1, which would make the chunking diverge.
      let granularity = max 1 (n `div` 2)

          walker k  = e0 `V.unsafeIndex` pred k
          partner k = e1 `V.unsafeIndex` pred (js `U.unsafeIndex` pred k)

          worker (k, z, zc) = move target (walker k) (partner k) z zc
          !result = runPar $
            parMapChunk granularity worker (zip3 [1 .. n] zs zcs)

      return $! V.fromList result
{-# INLINE execute #-}

-- | The 'flat' transition operator for driving a Markov chain over a space
--   of ensembles.
--
-- The ensemble is split into two halves, the second taking the extra
-- particle when the size is odd, so no particle is ever dropped. Walkers in
-- the first half are moved using partners from the second half, then walkers
-- in the second half are moved using partners from the updated first half.
-- Ensembles with fewer than two particles are left unchanged, since a lone
-- walker has no partner to stretch towards.
flat
  :: PrimMonad m
  => Transition m Chain
flat = do
  Chain {..} <- get
  let size = V.length chainPosition
      half = size `div` 2
      rest = size - half
      e0   = V.unsafeSlice 0 half chainPosition
      e1   = V.unsafeSlice half rest chainPosition
  if size < 2
    then return ()
    else do
      result0 <- lift (execute chainTarget e0 e1 half)
      result1 <- lift (execute chainTarget e1 result0 rest)
      let !ensemble = V.concat [result0, result1]
      put $! Chain chainTarget ensemble
{-# INLINE flat #-}

-- Stream the successive states of a chain, starting from (but not
-- yielding) the given state and stepping with 'flat'. Never terminates on
-- its own; consumers are expected to take as many states as they need.
chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m ()
chain = loop where
  loop current prng = do
    next <- lift (MWC.sample (execStateT flat current) prng)
    yield next
    loop next prng
{-# INLINE chain #-}

-- | Trace @n@ iterations of a Markov chain and stream them to stdout.
--
-- Note that the Markov chain is defined over the space of ensembles, so
-- you'll need to provide an ensemble of particles for the start location.
-- Each iteration prints every particle of the current ensemble on its own
-- line as comma-separated coordinates. The output below is illustrative;
-- the actual values depend on the random seed.
--
-- >>> import Numeric.MCMC.Flat
-- >>> import Data.Vector.Unboxed (toList)
-- >>> :{
-- >>> let rosenbrock xs = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
--             where [x0, x1] = toList xs
-- >>> :}
-- >>> :{
-- >>> let origin = ensemble [
-- >>>       particle [negate 1.0, negate 1.0]
-- >>>     , particle [negate 1.0, 1.0]
-- >>>     , particle [1.0, negate 1.0]
-- >>>     , particle [1.0, 1.0]
-- >>>     ]
-- >>> :}
-- >>> withSystemRandom . asGenIO $ mcmc 2 origin rosenbrock
-- -1.0,-1.0
-- -1.0,1.0
-- 1.0,-1.0
-- 0.7049046915549257,0.7049046915549257
-- -0.843493377618159,-0.843493377618159
-- -1.1655594505975082,1.1655594505975082
-- 0.5466534497342876,-0.9615123448709006
-- 0.7049046915549257,0.7049046915549257
mcmc
  :: (MonadIO m, PrimMonad m)
  => Int
  -> Ensemble
  -> (Particle -> Double)
  -> Gen (PrimState m)
  -> m ()
mcmc n chainPosition target gen = runEffect $
      chain Chain {..} gen
  >-> Pipes.take n
  >-> Pipes.mapM_ (liftIO . T.putStrLn . render)
  where
    chainTarget = Target target Nothing
{-# INLINE mcmc #-}

-- A parallel map with the specified granularity: the input is cut into
-- chunks of n elements (the last may be shorter), the chunks are mapped
-- with monad-par's parMap, and the results are concatenated in order.
-- A non-positive chunk size is treated as 1, because splitAt 0 consumes
-- nothing and the chunking would otherwise never terminate.
parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
parMapChunk n f xs = concat <$> parMap (map f) (chunk (max 1 n) xs) where
  chunk _ [] = []
  chunk m ys =
    let (front, back) = splitAt m ys
    in  front : chunk m back
{-# INLINE parMapChunk #-}