declarative

DIY Markov Chains
git clone git://git.jtobin.io/declarative.git
Log | Files | Refs | README | LICENSE

MCMC.hs (6529B)


      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE RecordWildCards #-}
      3 {-# LANGUAGE FlexibleContexts #-}
      4 
      5 -- |
      6 -- Module: Numeric.MCMC
      7 -- Copyright: (c) 2015 Jared Tobin
      8 -- License: MIT
      9 --
     10 -- Maintainer: Jared Tobin <jared@jtobin.ca>
     11 -- Stability: unstable
     12 -- Portability: ghc
     13 --
     14 -- This module presents a simple combinator language for Markov transition
     15 -- operators that are useful in MCMC.
     16 --
     17 -- Any transition operators sharing the same stationary distribution and
     18 -- obeying the Markov and reversibility properties can be combined in a couple
     19 -- of ways, such that the resulting operator preserves the stationary
     20 -- distribution and desirable properties amenable for MCMC.
     21 --
     22 -- We can deterministically concatenate operators end-to-end, or sample from
     23 -- a collection of them according to some probability distribution.  See
     24 -- <http://www.stat.umn.edu/geyer/f05/8931/n1998.pdf Geyer, 2005> for details.
     25 --
     26 -- The result is a simple grammar for building composite, property-preserving
     27 -- transition operators from existing ones:
     28 --
     29 -- @
     30 -- transition ::= primitive <transition>
     31 --              | concatT transition transition
     32 --              | sampleT transition transition
     33 -- @
     34 --
     35 -- In addition to the above, this module provides a number of combinators for
     36 -- building composite transition operators.  It re-exports a number of
     37 -- production-quality transition operators from the /mighty-metropolis/,
     38 -- /speedy-slice/, and /hasty-hamiltonian/ libraries.
     39 --
     40 -- Markov chains can then be run over arbitrary 'Target's using whatever
     41 -- transition operator is desired.
     42 --
     43 -- > import Numeric.MCMC
     44 -- > import Data.Sampling.Types
     45 -- >
     46 -- > target :: [Double] -> Double
     47 -- > target [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
     48 -- >
     49 -- > rosenbrock :: Target [Double]
     50 -- > rosenbrock = Target target Nothing
     51 -- >
     52 -- > transition :: Transition IO (Chain [Double] b)
     53 -- > transition =
     54 -- >   concatT
     55 -- >     (sampleT (metropolis 0.5) (metropolis 1.0))
     56 -- >     (sampleT (slice 2.0) (slice 3.0))
     57 -- >
     58 -- > main :: IO ()
     59 -- > main = withSystemRandom . asGenIO $ mcmc 10000 [0, 0] transition rosenbrock
     60 --
     61 -- See the attached test suite for other examples.
     62 
     63 module Numeric.MCMC (
     64     concatT
     65   , concatAllT
     66   , sampleT
     67   , sampleAllT
     68   , bernoulliT
     69   , frequency
     70   , anneal
     71   , mcmc
     72   , chain
     73 
     74   -- * Re-exported
     75   , module Data.Sampling.Types
     76 
     77   , metropolis
     78   , hamiltonian
     79   , slice
     80 
     81   , MWC.create
     82   , MWC.createSystemRandom
     83   , MWC.withSystemRandom
     84   , MWC.asGenIO
     85 
     86   , PrimMonad
     87   , PrimState
     88   , RealWorld
     89   ) where
     90 
     91 import Control.Monad (replicateM)
     92 import Control.Monad.Codensity (lowerCodensity)
     93 import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
     94 import Control.Monad.Trans.State.Strict (execStateT)
     95 import Data.Sampling.Types
     96 import Numeric.MCMC.Anneal
     97 import qualified Numeric.MCMC.Metropolis as M (metropolis)
     98 import Numeric.MCMC.Hamiltonian (hamiltonian)
     99 import Numeric.MCMC.Slice (slice)
    100 import Pipes hiding (next)
    101 import qualified Pipes.Prelude as Pipes
    102 import System.Random.MWC.Probability (Gen)
    103 import qualified System.Random.MWC.Probability as MWC
    104 
    105 -- | Deterministically concat transition operators together.
    106 concatT :: Monad m => Transition m a -> Transition m a -> Transition m a
    107 concatT = (>>)
    108 
    109 -- | Deterministically concat a list of transition operators together.
    110 concatAllT :: Monad m => [Transition m a] -> Transition m a
    111 concatAllT = foldl1 (>>)
    112 
    113 -- | Probabilistically concat transition operators together.
    114 sampleT :: PrimMonad m => Transition m a -> Transition m a -> Transition m a
    115 sampleT = bernoulliT 0.5
    116 
    117 -- | Probabilistically concat transition operators together using a Bernoulli
    118 --   distribution with the supplied success probability.
    119 --
    120 --   This is just a generalization of sampleT.
    121 bernoulliT
    122   :: PrimMonad m
    123   => Double
    124   -> Transition m a
    125   -> Transition m a
    126   -> Transition m a
    127 bernoulliT p t0 t1 = do
    128   heads <- lift (MWC.bernoulli p)
    129   if heads then t0 else t1
    130 
    131 -- | Probabilistically concat transition operators together via a uniform
    132 --   distribution.
    133 sampleAllT :: PrimMonad m => [Transition m a] -> Transition m a
    134 sampleAllT ts = do
    135   j <- lift (MWC.uniformR (0, length ts - 1))
    136   ts !! j
    137 
    138 -- | Probabilistically concat transition operators together using the supplied
    139 --   frequency distribution.
    140 --
    141 --   This function is more-or-less an exact copy of 'QuickCheck.frequency',
    142 --   except here applied to transition operators.
    143 frequency :: PrimMonad m => [(Int, Transition m a)] -> Transition m a
    144 frequency xs = lift (MWC.uniformR (1, tot)) >>= (`pick` xs) where
    145   tot = sum . map fst $ xs
    146   pick n ((k, v):vs)
    147     | n <= k = v
    148     | otherwise = pick (n - k) vs
    149   pick _ _ = error "frequency: no distribution specified"
    150 
    151 -- | Trace 'n' iterations of a Markov chain and stream them to stdout.
    152 --
    153 -- >>> withSystemRandom . asGenIO $ mcmc 3 [0, 0] (metropolis 0.5) rosenbrock
    154 -- -0.48939312153007863,0.13290702689491818
    155 -- 1.4541485365128892e-2,-0.4859905564050404
    156 -- 0.22487398491619448,-0.29769783186855125
    157 mcmc
    158   :: (MonadIO m, PrimMonad m, Show (t a))
    159   => Int
    160   -> t a
    161   -> Transition m (Chain (t a) b)
    162   -> Target (t a)
    163   -> Gen (PrimState m)
    164   -> m ()
    165 mcmc n chainPosition transition chainTarget gen = runEffect $
    166         drive transition Chain {..} gen
    167     >-> Pipes.take n
    168     >-> Pipes.mapM_ (liftIO . print)
    169   where
    170     chainScore    = lTarget chainTarget chainPosition
    171     chainTunables = Nothing
    172 
    173 -- | Trace 'n' iterations of a Markov chain and collect them in a list.
    174 --
    175 -- >>> results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock
    176 chain
    177   :: (MonadIO m, PrimMonad m)
    178   => Int
    179   -> t a
    180   -> Transition m (Chain (t a) b)
    181   -> Target (t a)
    182   -> Gen (PrimState m)
    183   -> m [Chain (t a) b]
    184 chain n chainPosition transition chainTarget gen = runEffect $
    185         drive transition Chain {..} gen
    186     >-> collect n
    187   where
    188     chainScore    = lTarget chainTarget chainPosition
    189     chainTunables = Nothing
    190 
    191     collect :: Monad m => Int -> Consumer a m [a]
    192     collect size = lowerCodensity $
    193       replicateM size (lift Pipes.await)
    194 
    195 -- A Markov chain driven by an arbitrary transition operator.
    196 drive
    197   :: PrimMonad m
    198   => Transition m b
    199   -> b
    200   -> Gen (PrimState m)
    201   -> Producer b m a
    202 drive transition = loop where
    203   loop state prng = do
    204     next <- lift (MWC.sample (execStateT transition state) prng)
    205     yield next
    206     loop next prng
    207 
    208 -- | A generic Metropolis transition operator.
    209 metropolis
    210   :: (Traversable f, PrimMonad m)
    211   => Double
    212   -> Transition m (Chain (f Double) b)
    213 metropolis radial = M.metropolis radial Nothing