mighty-metropolis

The classic Metropolis algorithm.
git clone git://git.jtobin.io/mighty-metropolis.git
Log | Files | Refs | README | LICENSE

Metropolis.hs (5360B)


      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE CPP #-}
      3 {-# LANGUAGE RecordWildCards #-}
      4 
      5 -- |
      6 -- Module: Numeric.MCMC.Metropolis
      7 -- Copyright: (c) 2015 Jared Tobin
      8 -- License: MIT
      9 --
     10 -- Maintainer: Jared Tobin <jared@jtobin.ca>
     11 -- Stability: unstable
     12 -- Portability: ghc
     13 --
     14 -- This implementation uses spherical Gaussian proposals to implement a
     15 -- reliable and computationally inexpensive sampling routine.  It can be used
     16 -- as a baseline from which to benchmark other algorithms for a given problem.
     17 --
     18 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
     19 -- while the 'metropolis' transition can be used for more flexible purposes,
     20 -- such as working with samples in memory.
     21 
     22 module Numeric.MCMC.Metropolis (
     23     mcmc
     24   , chain
     25   , chain'
     26   , metropolis
     27 
     28   -- * Re-exported
     29   , module Data.Sampling.Types
     30   , MWC.create
     31   , MWC.createSystemRandom
     32   , MWC.withSystemRandom
     33   , MWC.asGenIO
     34   ) where
     35 
     36 import Control.Monad (when, replicateM)
     37 import Control.Monad.Codensity (lowerCodensity)
     38 import Control.Monad.Primitive (PrimMonad, PrimState)
     39 import Control.Monad.Trans.Class (lift)
     40 import Control.Monad.Trans.State.Strict (execStateT, get, put)
     41 import Control.Monad.IO.Class (MonadIO, liftIO)
     42 import Data.Sampling.Types (Target(..), Chain(..), Transition)
     43 #if __GLASGOW_HASKELL__ < 710
     44 import Data.Traversable (Traversable, traverse)
     45 #endif
     46 import Pipes (Producer, Consumer, yield, (>->), runEffect, await)
     47 import qualified Pipes.Prelude as Pipes (mapM_, take)
     48 import System.Random.MWC.Probability (Gen, Prob)
     49 import qualified System.Random.MWC.Probability as MWC
     50 
     51 -- Propose a state transition according to a Gaussian proposal distribution
     52 -- with the specified standard deviation.
     53 propose
     54   :: (PrimMonad m, Traversable f)
     55   => Double
     56   -> f Double
     57   -> Prob m (f Double)
     58 propose radial = traverse perturb where
     59   perturb m = MWC.normal m radial
     60 
     61 -- | A generic Metropolis transition operator.
     62 metropolis
     63   :: (Traversable f, PrimMonad m)
     64   => Double
     65   -> Maybe (f Double -> b)
     66   -> Transition m (Chain (f Double) b)
     67 metropolis radial tunable = do
     68   Chain {..} <- get
     69   proposal <- lift (propose radial chainPosition)
     70   let proposalScore = lTarget chainTarget proposal
     71       acceptProb    = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
     72 
     73   accept <- lift (MWC.bernoulli acceptProb)
     74   when accept $ do
     75     let tuned = tunable <*> Just proposal
     76     put (Chain chainTarget proposalScore proposal tuned)
     77 
     78 -- Drive a Markov chain via the Metropolis transition operator.
     79 drive
     80   :: (Traversable f, PrimMonad m)
     81   => Double
     82   -> Maybe (f Double -> b)
     83   -> Chain (f Double) b
     84   -> Gen (PrimState m)
     85   -> Producer (Chain (f Double) b) m c
     86 drive radial tunable = loop where
     87   loop state prng = do
     88     let rvar = execStateT (metropolis radial tunable) state
     89     next <- lift (MWC.sample rvar prng)
     90     yield next
     91     loop next prng
     92 
     93 -- | Return a list of @Chain@ values potentially with tunable values computed
     94 --   from each position.
     95 chain' ::
     96      (PrimMonad m, Traversable f)
     97   => Int
     98   -> Double
     99   -> f Double
    100   -> (f Double -> Double)
    101   -> Maybe (f Double -> b)
    102   -> Gen (PrimState m)
    103   -> m [Chain (f Double) b]
    104 chain' n radial position target tunable gen =
    105     runEffect $ drive radial tunable origin gen >-> collect n
    106   where
    107     ctarget = Target target Nothing
    108     origin  = Chain
    109       { chainScore    = lTarget ctarget position
    110       , chainTunables = tunable <*> Just position
    111       , chainTarget   = ctarget
    112       , chainPosition = position
    113       }
    114     collect :: Monad m => Int -> Consumer a m [a]
    115     collect size = lowerCodensity $ replicateM size (lift Pipes.await)
    116 
    117 -- | Trace 'n' iterations of a Markov chain and collect the results in a list.
    118 --
    119 -- >>> let rosenbrock [x0, x1] = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
    120 -- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock
    121 -- >>> mapM_ print results
    122 -- 0.0,0.0
    123 -- 1.4754117657794871e-2,0.5033208261760778
    124 -- 3.8379699517007895e-3,0.24627131099479127
    125 chain
    126   :: (PrimMonad m, Traversable f)
    127   => Int                           -- ^ Number of iterations
    128   -> Double                        -- ^ Step standard deviation
    129   -> f Double                      -- ^ Starting position
    130   -> (f Double -> Double)          -- ^ Log-density (to additive constant)
    131   -> Gen (PrimState m)             -- ^ PRNG
    132   -> m [Chain (f Double) b]
    133 chain n radial position target = chain' n radial position target Nothing
    134 
    135 -- | Trace 'n' iterations of a Markov chain and stream them to stdout.
    136 --
    137 -- >>> let rosenbrock [x0, x1] = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
    138 -- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
    139 -- 0.5000462419822702,0.5693944056267897
    140 -- 0.5000462419822702,0.5693944056267897
    141 -- -0.7525995304580824,1.2240725505283248
    142 mcmc
    143   :: (MonadIO m, PrimMonad m, Traversable f, Show (f Double))
    144   => Int
    145   -> Double
    146   -> f Double
    147   -> (f Double -> Double)
    148   -> Gen (PrimState m)
    149   -> m ()
    150 mcmc n radial chainPosition target gen = runEffect $
    151         drive radial Nothing Chain {..} gen
    152     >-> Pipes.take n
    153     >-> Pipes.mapM_ (liftIO . print)
    154   where
    155     chainScore    = lTarget chainTarget chainPosition
    156     chainTunables = Nothing
    157     chainTarget   = Target target Nothing
    158 
    159 -- Use a provided default value when the argument is NaN.
    160 whenNaN :: RealFloat a => a -> a -> a
    161 whenNaN val x
    162   | isNaN x   = val
    163   | otherwise = x