flat-mcmc

Painless, efficient, general-purpose sampling from continuous distributions.
git clone git://git.jtobin.io/flat-mcmc.git
Log | Files | Refs | README | LICENSE

Flat.hs (6941B)


      1 {-# OPTIONS_GHC -fno-warn-type-defaults #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE OverloadedStrings #-}
      4 {-# LANGUAGE RecordWildCards #-}
      5 
      6 -- |
      7 -- Module: Numeric.MCMC.Flat
      8 -- Copyright: (c) 2016 Jared Tobin
      9 -- License: MIT
     10 --
     11 -- Maintainer: Jared Tobin <jared@jtobin.ca>
     12 -- Stability: unstable
     13 -- Portability: ghc
     14 --
     15 -- This is the 'affine invariant ensemble' or AIEMCMC algorithm described in
     16 -- Goodman and Weare, 2010.  It is a reasonably efficient and hassle-free
     17 -- sampler, requiring no mucking with tuning parameters or local proposal
     18 -- distributions.
     19 --
     20 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
     21 -- while the `flat` transition can be used for more flexible purposes,
     22 -- such as working with samples in memory.
     23 --
     24 -- See <http://msp.org/camcos/2010/5-1/camcos-v5-n1-p04-p.pdf> for the definitive
     25 -- reference of the algorithm.
     26 
     27 module Numeric.MCMC.Flat (
     28     mcmc
     29   , flat
     30   , Particle
     31   , Ensemble
     32   , Chain
     33 
     34   , module Sampling.Types
     35   , MWC.create
     36   , MWC.createSystemRandom
     37   , MWC.withSystemRandom
     38   , MWC.asGenIO
     39 
     40   , VE.ensemble
     41   , VE.particle
     42   ) where
     43 
     44 import Control.Monad (replicateM)
     45 import Control.Monad.IO.Class (MonadIO, liftIO)
     46 import Control.Monad.Par (NFData)
     47 import Control.Monad.Par.Combinator (parMap)
     48 import Control.Monad.Par.Scheds.Sparks hiding (get)
     49 import Control.Monad.Primitive (PrimMonad, PrimState)
     50 import Control.Monad.Trans.State.Strict (get, put, execStateT)
     51 import Data.Sampling.Types as Sampling.Types hiding (Chain(..))
     52 import qualified Data.Text as T
     53 import qualified Data.Text.IO as T (putStrLn)
     54 import Data.Vector (Vector)
     55 import qualified Data.Vector as V
     56 import qualified Data.Vector.Extended as VE (ensemble, particle)
     57 import qualified Data.Vector.Unboxed as U
     58 import Formatting ((%))
     59 import qualified Formatting as F
     60 import Pipes (Producer, lift, yield, runEffect, (>->))
     61 import qualified Pipes.Prelude as Pipes
     62 import System.Random.MWC.Probability as MWC
     63 
     -- | State threaded through the sampler: the target being sampled together
     --   with the current ensemble of particles.
     data Chain = Chain
       { chainTarget   :: Target Particle
         -- ^ Target whose log-density is consulted when accepting or rejecting
         --   proposals.
       , chainPosition :: !Ensemble
         -- ^ Current ensemble of particles (strict field).
       }
     68 
     -- | Render the ensemble held by a 'Chain' as text; the target is ignored.
     render :: Chain -> T.Text
     render = renderEnsemble . chainPosition
     {-# INLINE render #-}
     73 
     -- | Render a particle as its coordinates separated by commas, each one
     --   formatted with 'F.float'.  An empty particle renders as empty text.
     --
     --   Every coordinate is rendered on its own and the pieces are joined once
     --   with 'T.intercalate'.  This avoids appending to a strict 'T.Text'
     --   accumulator coordinate by coordinate (which recopies the accumulated
     --   text at every step) and the need to strip a leading separator
     --   afterwards.  The rendered text is unchanged.
     renderParticle :: Particle -> T.Text
     renderParticle = T.intercalate "," . fmap (F.sformat F.float) . U.toList
     {-# INLINE renderParticle #-}
     81 
     -- | Render an ensemble as text, one particle per line (particles are
     --   separated by a newline, with no trailing newline).  An empty ensemble
     --   renders as empty text.
     --
     --   The particles are rendered individually and joined once with
     --   'T.intercalate', rather than being appended one at a time to a strict
     --   'T.Text' accumulator whose leading separator then has to be dropped.
     --   The rendered text is unchanged.
     renderEnsemble :: Ensemble -> T.Text
     renderEnsemble = T.intercalate "\n" . fmap renderParticle . V.toList
     {-# INLINE renderEnsemble #-}
     89 
     -- | A particle is an n-dimensional point in Euclidean space, represented as
     --   an unboxed vector of 'Double' coordinates.
     --
     --   You can create a particle by using the 'particle' helper function, or
     --   just use 'Data.Vector.Unboxed.fromList'.
     type Particle = U.Vector Double
     95 
     -- | An ensemble is a collection of particles, represented as a boxed vector
     --   of 'Particle's.
     --
     --   The Markov chain we're interested in will run over the space of
     --   ensembles, so you'll want to build an ensemble out of a reasonable
     --   number of particles to kick off the chain.
     --
     --   You can create an ensemble by using the 'ensemble' helper function, or
     --   just use 'Data.Vector.fromList'.
     type Ensemble = Vector Particle
    105 
    -- | Draw a stretch factor: a 'uniform' variate @u@ is mapped to
    --   @0.5 * (u + 1) ^ 2@.
    symmetric :: PrimMonad m => Prob m Double
    symmetric = rescale <$> uniform
      where
        rescale u = 0.5 * (u + 1) ^ (2 :: Int)
    {-# INLINE symmetric #-}
    110 
    -- | Combine two particles coordinate-wise: for coordinates @x@ of the first
    --   particle and @y@ of the second, the result holds @z * x + (1 - z) * y@.
    stretch :: Particle -> Particle -> Double -> Particle
    stretch p0 p1 z = U.zipWith (\x y -> z * x + (1 - z) * y) p0 p1
    {-# INLINE stretch #-}
    115 
    -- | Logarithm of the acceptance ratio for replacing @particle@ with
    --   @proposal@: the difference of the target's log-densities at the two
    --   points, plus @log z@ scaled by the particle's dimension minus one.
    acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double
    acceptProb target particle proposal z = densityTerm + stretchTerm
      where
        densityTerm = lTarget target proposal - lTarget target particle
        stretchTerm = log z * (fromIntegral (U.length particle) - 1)
    {-# INLINE acceptProb #-}
    122 
    -- | Propose stretching @p0@ relative to @p1@ by the factor @z@ and decide
    --   whether to take the proposal.  The proposal is returned when the
    --   variate @zc@ does not exceed the acceptance probability (the exponential
    --   of 'acceptProb', capped at 1); otherwise @p0@ is returned unchanged.
    move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle
    move target !p0 p1 z zc
      | zc <= min 1 (exp logRatio) = proposal
      | otherwise                  = p0
      where
        !proposal = stretch p0 p1 z
        logRatio  = acceptProb target p0 proposal z
    {-# INLINE move #-}
    131 
    -- | Update the first @n@ particles of @e0@ once each.
    --
    --   For every particle a stretch factor ('symmetric'), an acceptance variate
    --   ('uniform') and the index of a partner among the first @n@ particles of
    --   @e1@ are drawn; 'move' then decides between the proposal and the current
    --   particle.  The moves are evaluated in parallel chunks and returned in
    --   the same order as the particles of @e0@.
    --
    --   Both ensembles are indexed without bounds checks, so each must hold at
    --   least @n@ particles.
    execute
      :: PrimMonad m
      => Target Particle
      -> Ensemble
      -> Ensemble
      -> Int
      -> Prob m Ensemble
    execute target e0 e1 n = do
      zs  <- replicateM n symmetric
      zcs <- replicateM n uniform
      js  <- U.replicateM n (uniformR (1, n))

      -- Keep the chunk size at least 1: for n == 1, n `div` 2 is 0, which is not
      -- a usable chunk size and previously kept the chunked map from ever
      -- finishing.
      let granularity = max 1 (n `div` 2)

          -- Indices are 1-based here, hence the 'pred' before every lookup.
          walker k  = e0 `V.unsafeIndex` pred k
          partner k = e1 `V.unsafeIndex` pred (js `U.unsafeIndex` pred k)

          worker (k, z, zc) = move target (walker k) (partner k) z zc
          !result = runPar $
            parMapChunk granularity worker (zip3 [1..n] zs zcs)

      return $! V.fromList result
    {-# INLINE execute #-}
    155 
    -- | The 'flat' transition operator for driving a Markov chain over a space
    --   of ensembles.
    --
    --   The ensemble is cut into two halves of equal size.  The first half is
    --   updated against the second, then the second half is updated against the
    --   freshly updated first half, and the two results are concatenated to form
    --   the new ensemble.
    --
    --   Note that for an ensemble with an odd number of particles the final
    --   particle belongs to neither half, so the new ensemble has one particle
    --   fewer than the old one.
    flat
      :: PrimMonad m
      => Transition m Chain
    flat = do
      Chain target position <- get
      let half          = V.length position `div` 2
          (front, rest) = V.splitAt half position
          back          = V.take half rest
      front' <- lift (execute target front back half)
      back'  <- lift (execute target back front' half)
      let !ensemble = front' V.++ back'
      put $! Chain target ensemble
    {-# INLINE flat #-}
    172 
    -- | An endless stream of chain states: starting from the given state, the
    --   'flat' transition is run with the supplied generator and each resulting
    --   state is yielded.  The starting state itself is not yielded.
    chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m ()
    chain origin prng = go origin
      where
        go current = do
          next <- lift (MWC.sample (execStateT flat current) prng)
          yield next
          go next
    {-# INLINE chain #-}
    180 
    -- | Run 'n' iterations of the Markov chain, printing the ensemble to stdout
    --   after each one (one particle per line, coordinates comma-separated).
    --
    --   The chain lives on the space of ensembles, so the starting location is
    --   an ensemble of particles rather than a single point.  The target is
    --   given as a function from a particle to a 'Double'; no gradient is
    --   supplied to the underlying 'Target'.
    --
    -- >>> import Numeric.MCMC.Flat
    -- >>> import Data.Vector.Unboxed (toList)
    -- >>> :{
    -- >>>   let rosenbrock xs = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
    --             where [x0, x1] = toList xs
    -- >>> :}
    -- >>> :{
    -- >>> let origin = ensemble [
    -- >>>       particle [negate 1.0, negate 1.0]
    -- >>>     , particle [negate 1.0, 1.0]
    -- >>>     , particle [1.0, negate 1.0]
    -- >>>     , particle [1.0, 1.0]
    -- >>>     ]
    -- >>> :}
    -- >>> withSystemRandom . asGenIO $ mcmc 2 origin rosenbrock
    -- -1.0,-1.0
    -- -1.0,1.0
    -- 1.0,-1.0
    -- 0.7049046915549257,0.7049046915549257
    -- -0.843493377618159,-0.843493377618159
    -- -1.1655594505975082,1.1655594505975082
    -- 0.5466534497342876,-0.9615123448709006
    -- 0.7049046915549257,0.7049046915549257
    mcmc
      :: (MonadIO m, PrimMonad m)
      => Int
      -> Ensemble
      -> (Particle -> Double)
      -> Gen (PrimState m)
      -> m ()
    mcmc n chainPosition target gen =
        runEffect (chain origin gen >-> Pipes.take n >-> Pipes.mapM_ emit)
      where
        origin = Chain (Target target Nothing) chainPosition
        emit c = liftIO (T.putStrLn (render c))
    {-# INLINE mcmc #-}
    223 
    -- | A parallel map with the specified granularity.
    --
    --   The input list is cut into chunks of @n@ elements, each chunk is mapped
    --   as one unit of work through 'parMap', and the per-chunk results are
    --   concatenated back in input order.
    --
    --   A chunk size below 1 is treated as 1.  'splitAt' with a non-positive
    --   size returns an empty prefix and the untouched list, so chunking a
    --   non-empty list that way would never terminate.
    parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
    parMapChunk n f xs = concat <$> parMap (map f) (chunk xs) where
      size = max 1 n
      chunk [] = []
      chunk ys =
        let (front, rest) = splitAt size ys
        in  front : chunk rest
    {-# INLINE parMapChunk #-}
    232