speedy-slice

Speedy slice sampling.
git clone git://git.jtobin.io/speedy-slice.git
Log | Files | Refs | README | LICENSE

Slice.hs (6246B)


      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE RecordWildCards #-}
      4 {-# LANGUAGE FlexibleContexts #-}
      5 {-# LANGUAGE NoMonomorphismRestriction #-}
      6 {-# LANGUAGE TypeFamilies #-}
      7 
      8 -- |
      9 -- Module: Numeric.MCMC.Slice
     10 -- Copyright: (c) 2015 Jared Tobin
     11 -- License: MIT
     12 --
     13 -- Maintainer: Jared Tobin <jared@jtobin.ca>
     14 -- Stability: unstable
     15 -- Portability: ghc
     16 --
     17 -- This implementation performs slice sampling by first finding a bracket about
     18 -- a mode (using a simple doubling heuristic), and then doing rejection
     19 -- sampling along it.  The result is a reliable and computationally inexpensive
     20 -- sampling routine.
     21 --
     22 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
     23 -- while the `slice` transition can be used for more flexible purposes, such as
     24 -- working with samples in memory.
     25 --
     26 -- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the
     27 -- definitive reference of the algorithm.
     28 
     29 module Numeric.MCMC.Slice (
     30     mcmc
     31   , chain
     32   , slice
     33 
     34   -- * Re-exported
     35   , MWC.create
     36   , MWC.createSystemRandom
     37   , MWC.withSystemRandom
     38   , MWC.asGenIO
     39   ) where
     40 
     41 import Control.Monad (replicateM)
     42 import Control.Monad.Codensity (lowerCodensity)
     43 import Control.Monad.Trans.State.Strict (put, get, execStateT)
     44 import Control.Monad.Primitive (PrimMonad, PrimState)
     45 import Control.Lens hiding (index)
     46 import Data.Maybe (fromMaybe)
     47 import Data.Sampling.Types
     48 import Pipes hiding (next)
     49 import qualified Pipes.Prelude as Pipes
     50 import System.Random.MWC.Probability (Prob, Gen, Variate)
     51 import qualified System.Random.MWC.Probability as MWC
     52 
     53 -- | Trace 'n' iterations of a Markov chain and stream them to stdout.
     54 --
     55 -- >>> let rosenbrock [x0, x1] = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
     56 -- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
     57 -- -3.854097694213343e-2,0.16688601288358407
     58 -- -9.310661272172682e-2,0.2562387977415508
     59 -- -0.48500122500661846,0.46245400501919076
     60 mcmc
     61   :: (MonadIO m, PrimMonad m,
     62      Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
     63      Num (IxValue (t a)), Variate (IxValue (t a)))
     64   => Int
     65   -> IxValue (t a)
     66   -> t a
     67   -> (t a -> Double)
     68   -> Gen (PrimState m)
     69   -> m ()
     70 mcmc n radial chainPosition target gen = runEffect $
     71         drive radial Chain {..} gen
     72     >-> Pipes.take n
     73     >-> Pipes.mapM_ (liftIO . print)
     74   where
     75     chainScore    = lTarget chainTarget chainPosition
     76     chainTunables = Nothing
     77     chainTarget   = Target target Nothing
     78 
     79 -- | Trace 'n' iterations of a Markov chain and collect them in a list.
     80 --
     81 -- >>> results <- withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
     82 chain
     83   :: (PrimMonad m, FoldableWithIndex (Index (f a)) f, Ixed (f a)
     84      , Variate (IxValue (f a)), Num (IxValue (f a)))
     85   => Int
     86   -> IxValue (f a)
     87   -> f a
     88   -> (f a -> Double)
     89   -> Gen (PrimState m)
     90   -> m [Chain (f a) b]
     91 chain n radial position target gen = runEffect $
     92         drive radial origin gen
     93     >-> collect n
     94   where
     95     ctarget = Target target Nothing
     96 
     97     origin = Chain {
     98         chainScore    = lTarget ctarget position
     99       , chainTunables = Nothing
    100       , chainTarget   = ctarget
    101       , chainPosition = position
    102       }
    103 
    104     collect :: Monad m => Int -> Consumer a m [a]
    105     collect size = lowerCodensity $
    106       replicateM size (lift Pipes.await)
    107 
    108 -- A Markov chain driven by the slice transition operator.
    109 drive
    110   :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
    111      Num (IxValue (t a)), Variate (IxValue (t a)))
    112   => IxValue (t a)
    113   -> Chain (t a) b
    114   -> Gen (PrimState m)
    115   -> Producer (Chain (t a) b) m c
    116 drive radial = loop where
    117   loop state prng = do
    118     next <- lift (MWC.sample (execStateT (slice radial) state) prng)
    119     yield next
    120     loop next prng
    121 
    122 -- | A slice sampling transition operator.
    123 slice
    124   :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
    125       Num (IxValue (t a)), Variate (IxValue (t a)))
    126   => IxValue (t a)
    127   -> Transition m (Chain (t a) b)
    128 slice step = do
    129   Chain _ _ position _ <- get
    130   ifor_ position $ \index _ -> do
    131     Chain {..} <- get
    132     let bounds = (0, exp (lTarget chainTarget chainPosition))
    133     height    <- lift (fmap log (MWC.uniformR bounds))
    134 
    135     let bracket =
    136           findBracket (lTarget chainTarget) index step height chainPosition
    137 
    138     perturbed <- lift $
    139       rejection (lTarget chainTarget) index bracket height chainPosition
    140 
    141     let perturbedScore = lTarget chainTarget perturbed
    142     put (Chain chainTarget perturbedScore perturbed chainTunables)
    143 
    144 -- Find a bracket by expanding its bounds through powers of 2.
    145 findBracket
    146   :: (Ord a, Ixed s, Num (IxValue s))
    147   => (s -> a)
    148   -> Index s
    149   -> IxValue s
    150   -> a
    151   -> s
    152   -> (IxValue s, IxValue s)
    153 findBracket target index step height xs = go step xs xs where
    154   err = error "findBracket: invalid index -- please report this as a bug!"
    155   go !e !bl !br
    156     | target bl < height && target br < height =
    157         let l = fromMaybe err (bl ^? ix index)
    158             r = fromMaybe err (br ^? ix index)
    159         in  (l, r)
    160     | target bl < height && target br >= height =
    161         let br0 = expandBracketRight index step br
    162         in  go (2 * e) bl br0
    163     | target bl >= height && target br < height =
    164         let bl0 = expandBracketLeft index step bl
    165         in  go (2 * e) bl0 br
    166     | otherwise =
    167         let bl0 = expandBracketLeft index step bl
    168             br0 = expandBracketRight index step br
    169         in  go (2 * e) bl0 br0
    170 
    171 expandBracketLeft
    172   :: (Ixed s, Num (IxValue s))
    173   => Index s
    174   -> IxValue s
    175   -> s
    176   -> s
    177 expandBracketLeft = expandBracketBy (-)
    178 
    179 expandBracketRight
    180   :: (Ixed s, Num (IxValue s))
    181   => Index s
    182   -> IxValue s
    183   -> s
    184   -> s
    185 expandBracketRight = expandBracketBy (+)
    186 
    187 expandBracketBy
    188   :: Ixed s
    189   => (IxValue s -> t -> IxValue s)
    190   -> Index s
    191   -> t
    192   -> s
    193   -> s
    194 expandBracketBy f index step xs = xs & ix index %~ (`f` step )
    195 
    196 -- Perform rejection sampling within the supplied bracket.
    197 rejection
    198   :: (Ord a, PrimMonad m, Ixed b, Variate (IxValue b))
    199   => (b -> a)
    200   -> Index b
    201   -> (IxValue b, IxValue b)
    202   -> a
    203   -> b
    204   -> Prob m b
    205 rejection target dimension bracket height = go where
    206   go zs = do
    207     u <- MWC.uniformR bracket
    208     let  updated = zs & ix dimension .~ u
    209     if   target updated < height
    210     then go updated
    211     else return updated
    212