      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE RecordWildCards #-}
      3 {-# LANGUAGE FlexibleContexts #-}
      5 -- |
      6 -- Module: Numeric.MCMC
      7 -- Copyright: (c) 2015 Jared Tobin
      8 -- License: MIT
      9 --
     10 -- Maintainer: Jared Tobin <>
     11 -- Stability: unstable
     12 -- Portability: ghc
     13 --
     14 -- This module presents a simple combinator language for Markov transition
     15 -- operators that are useful in MCMC.
     16 --
-- Any transition operators sharing the same stationary distribution and
-- obeying the Markov and reversibility properties can be combined in a couple
-- of ways such that the resulting operator preserves that stationary
-- distribution, along with other properties desirable for MCMC.  (Random
-- mixtures of reversible operators are again reversible, whereas a
-- deterministic concatenation of them generally is not.)
     21 --
-- We can deterministically concatenate operators end-to-end, or sample from
-- a collection of them according to some probability distribution.  See
-- Geyer (2005) for details.
     25 --
     26 -- The result is a simple grammar for building composite, property-preserving
     27 -- transition operators from existing ones:
     28 --
     29 -- @
-- transition ::= primitive \<transition\>
     31 --              | concatT transition transition
     32 --              | sampleT transition transition
     33 -- @
     34 --
     35 -- In addition to the above, this module provides a number of combinators for
     36 -- building composite transition operators.  It re-exports a number of
     37 -- production-quality transition operators from the /mighty-metropolis/,
     38 -- /speedy-slice/, and /hasty-hamiltonian/ libraries.
     39 --
     40 -- Markov chains can then be run over arbitrary 'Target's using whatever
     41 -- transition operator is desired.
     42 --
     43 -- > import Numeric.MCMC
     44 -- > import Data.Sampling.Types
     45 -- >
     46 -- > target :: [Double] -> Double
-- > target [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
     48 -- >
     49 -- > rosenbrock :: Target [Double]
     50 -- > rosenbrock = Target target Nothing
     51 -- >
     52 -- > transition :: Transition IO (Chain [Double] b)
     53 -- > transition =
     54 -- >   concatT
     55 -- >     (sampleT (metropolis 0.5) (metropolis 1.0))
     56 -- >     (sampleT (slice 2.0) (slice 3.0))
     57 -- >
     58 -- > main :: IO ()
     59 -- > main = withSystemRandom . asGenIO $ mcmc 10000 [0, 0] transition rosenbrock
     60 --
     61 -- See the attached test suite for other examples.
     63 module Numeric.MCMC (
     64     concatT
     65   , concatAllT
     66   , sampleT
     67   , sampleAllT
     68   , bernoulliT
     69   , frequency
     70   , anneal
     71   , mcmc
     72   , chain
     74   -- * Re-exported
     75   , module Data.Sampling.Types
     77   , metropolis
     78   , hamiltonian
     79   , slice
     81   , MWC.create
     82   , MWC.createSystemRandom
     83   , MWC.withSystemRandom
     84   , MWC.asGenIO
     86   , PrimMonad
     87   , PrimState
     88   , RealWorld
     89   ) where
     91 import Control.Monad (replicateM)
     92 import Control.Monad.Codensity (lowerCodensity)
     93 import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
     94 import Control.Monad.Trans.State.Strict (execStateT)
     95 import Data.Sampling.Types
     96 import Numeric.MCMC.Anneal
     97 import qualified Numeric.MCMC.Metropolis as M (metropolis)
     98 import Numeric.MCMC.Hamiltonian (hamiltonian)
     99 import Numeric.MCMC.Slice (slice)
    100 import Pipes hiding (next)
    101 import qualified Pipes.Prelude as Pipes
    102 import System.Random.MWC.Probability (Gen)
    103 import qualified System.Random.MWC.Probability as MWC
    105 -- | Deterministically concat transition operators together.
    106 concatT :: Monad m => Transition m a -> Transition m a -> Transition m a
    107 concatT = (>>)
    109 -- | Deterministically concat a list of transition operators together.
    110 concatAllT :: Monad m => [Transition m a] -> Transition m a
    111 concatAllT = foldl1 (>>)
    113 -- | Probabilistically concat transition operators together.
    114 sampleT :: PrimMonad m => Transition m a -> Transition m a -> Transition m a
    115 sampleT = bernoulliT 0.5
    117 -- | Probabilistically concat transition operators together using a Bernoulli
    118 --   distribution with the supplied success probability.
    119 --
    120 --   This is just a generalization of sampleT.
    121 bernoulliT
    122   :: PrimMonad m
    123   => Double
    124   -> Transition m a
    125   -> Transition m a
    126   -> Transition m a
    127 bernoulliT p t0 t1 = do
    128   heads <- lift (MWC.bernoulli p)
    129   if heads then t0 else t1
    131 -- | Probabilistically concat transition operators together via a uniform
    132 --   distribution.
    133 sampleAllT :: PrimMonad m => [Transition m a] -> Transition m a
    134 sampleAllT ts = do
    135   j <- lift (MWC.uniformR (0, length ts - 1))
    136   ts !! j
    138 -- | Probabilistically concat transition operators together using the supplied
    139 --   frequency distribution.
    140 --
    141 --   This function is more-or-less an exact copy of 'QuickCheck.frequency',
    142 --   except here applied to transition operators.
    143 frequency :: PrimMonad m => [(Int, Transition m a)] -> Transition m a
    144 frequency xs = lift (MWC.uniformR (1, tot)) >>= (`pick` xs) where
    145   tot = sum . map fst $ xs
    146   pick n ((k, v):vs)
    147     | n <= k = v
    148     | otherwise = pick (n - k) vs
    149   pick _ _ = error "frequency: no distribution specified"
    151 -- | Trace 'n' iterations of a Markov chain and stream them to stdout.
    152 --
    153 -- >>> withSystemRandom . asGenIO $ mcmc 3 [0, 0] (metropolis 0.5) rosenbrock
    154 -- -0.48939312153007863,0.13290702689491818
    155 -- 1.4541485365128892e-2,-0.4859905564050404
    156 -- 0.22487398491619448,-0.29769783186855125
    157 mcmc
    158   :: (MonadIO m, PrimMonad m, Show (t a))
    159   => Int
    160   -> t a
    161   -> Transition m (Chain (t a) b)
    162   -> Target (t a)
    163   -> Gen (PrimState m)
    164   -> m ()
    165 mcmc n chainPosition transition chainTarget gen = runEffect $
    166         drive transition Chain {..} gen
    167     >-> Pipes.take n
    168     >-> Pipes.mapM_ (liftIO . print)
    169   where
    170     chainScore    = lTarget chainTarget chainPosition
    171     chainTunables = Nothing
    173 -- | Trace 'n' iterations of a Markov chain and collect them in a list.
    174 --
    175 -- >>> results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock
    176 chain
    177   :: (MonadIO m, PrimMonad m)
    178   => Int
    179   -> t a
    180   -> Transition m (Chain (t a) b)
    181   -> Target (t a)
    182   -> Gen (PrimState m)
    183   -> m [Chain (t a) b]
    184 chain n chainPosition transition chainTarget gen = runEffect $
    185         drive transition Chain {..} gen
    186     >-> collect n
    187   where
    188     chainScore    = lTarget chainTarget chainPosition
    189     chainTunables = Nothing
    191     collect :: Monad m => Int -> Consumer a m [a]
    192     collect size = lowerCodensity $
    193       replicateM size (lift Pipes.await)
    195 -- A Markov chain driven by an arbitrary transition operator.
    196 drive
    197   :: PrimMonad m
    198   => Transition m b
    199   -> b
    200   -> Gen (PrimState m)
    201   -> Producer b m a
    202 drive transition = loop where
    203   loop state prng = do
    204     next <- lift (MWC.sample (execStateT transition state) prng)
    205     yield next
    206     loop next prng
    208 -- | A generic Metropolis transition operator.
    209 metropolis
    210   :: (Traversable f, PrimMonad m)
    211   => Double
    212   -> Transition m (Chain (f Double) b)
    213 metropolis radial = M.metropolis radial Nothing