hasty-hamiltonian

Speedy gradient-based traversal through parameter space.
git clone git://git.jtobin.io/hasty-hamiltonian.git
Log | Files | Refs | README | LICENSE

Hamiltonian.hs (7792B)


      1 {-# OPTIONS_GHC -Wall #-}
      2 {-# LANGUAGE RecordWildCards #-}
      3 {-# LANGUAGE FlexibleContexts #-}
      4 {-# LANGUAGE TypeFamilies #-}
      5 
      6 -- |
      7 -- Module: Numeric.MCMC.Hamiltonian
      8 -- Copyright: (c) 2015 Jared Tobin
      9 -- License: MIT
     10 --
     11 -- Maintainer: Jared Tobin <jared@jtobin.ca>
     12 -- Stability: unstable
     13 -- Portability: ghc
     14 --
     15 -- This implementation performs Hamiltonian Monte Carlo using an identity mass
     16 -- matrix.
     17 --
     18 -- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
     19 -- while the 'hamiltonian' transition can be used for more flexible
     20 -- purposes, such as working with samples in memory.
     21 --
     22 -- See <http://arxiv.org/pdf/1206.1901.pdf Neal, 2012> for the definitive
     23 -- reference on the algorithm.
     24 
     25 module Numeric.MCMC.Hamiltonian (
     26     mcmc
     27   , chain
     28   , hamiltonian
     29 
     30   -- * Re-exported
     31   , Target(..)
     32   , MWC.create
     33   , MWC.createSystemRandom
     34   , MWC.withSystemRandom
     35   , MWC.asGenIO
     36   ) where
     37 
     38 import Control.Lens hiding (index)
     39 import Control.Monad (replicateM)
     40 import Control.Monad.Codensity (lowerCodensity)
     41 import Control.Monad.Primitive (PrimState, PrimMonad)
     42 import Control.Monad.Trans.State.Strict hiding (state)
     43 import qualified Data.Foldable as Foldable (sum)
     44 import Data.Maybe (fromMaybe)
     45 import Data.Sampling.Types
     46 import Data.Traversable (for)
     47 import Pipes hiding (for, next)
     48 import qualified Pipes.Prelude as Pipes
     49 import System.Random.MWC.Probability (Prob, Gen)
     50 import qualified System.Random.MWC.Probability as MWC
     51 
     52 -- | Trace 'n' iterations of a Markov chain and stream them to stdout.
     53 --
     54 -- >>> withSystemRandom . asGenIO $ mcmc 10000 0.05 20 [0, 0] target
     55 mcmc
     56   :: ( MonadIO m, PrimMonad m
     57      , Num (IxValue (t Double)), Show (t Double), Traversable t
     58      , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     59      , IxValue (t Double) ~ Double)
     60   => Int
     61   -> Double
     62   -> Int
     63   -> t Double
     64   -> Target (t Double)
     65   -> Gen (PrimState m)
     66   -> m ()
     67 mcmc n step leaps chainPosition chainTarget gen = runEffect $
     68         drive step leaps Chain {..} gen
     69     >-> Pipes.take n
     70     >-> Pipes.mapM_ (liftIO . print)
     71   where
     72     chainScore    = lTarget chainTarget chainPosition
     73     chainTunables = Nothing
     74 
     75 -- | Trace 'n' iterations of a Markov chain and collect the results in a list.
     76 --
     77 -- >>> results <- withSystemRandom . asGenIO $ chain 1000 0.05 20 [0, 0] target
     78 chain
     79   :: (PrimMonad m, Traversable f
     80      , FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
     81      , IxValue (f Double) ~ Double)
     82   => Int
     83   -> Double
     84   -> Int
     85   -> f Double
     86   -> Target (f Double)
     87   -> Gen (PrimState m)
     88   -> m [Chain (f Double) b]
     89 chain n step leaps position target gen = runEffect $
     90         drive step leaps origin gen
     91     >-> collect n
     92   where
     93     origin = Chain {
     94         chainScore    = lTarget target position
     95       , chainTunables = Nothing
     96       , chainTarget   = target
     97       , chainPosition = position
     98       }
     99 
    100     collect :: Monad m => Int -> Consumer a m [a]
    101     collect size = lowerCodensity $
    102       replicateM size (lift Pipes.await)
    103 
    104 -- Drive a Markov chain.
    105 drive
    106   :: (Num (IxValue (t Double)), Traversable t
    107      , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
    108      , PrimMonad m, IxValue (t Double) ~ Double)
    109   => Double
    110   -> Int
    111   -> Chain (t Double) b
    112   -> Gen (PrimState m)
    113   -> Producer (Chain (t Double) b) m c
    114 drive step leaps = loop where
    115   loop state prng = do
    116     next <- lift (MWC.sample (execStateT (hamiltonian step leaps) state) prng)
    117     yield next
    118     loop next prng
    119 
    120 -- | A Hamiltonian transition operator.
    121 hamiltonian
    122   :: (Num (IxValue (t Double)), Traversable t
    123      , FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m
    124      , IxValue (t Double) ~ Double)
    125   => Double -> Int -> Transition m (Chain (t Double) b)
    126 hamiltonian e l = do
    127   Chain {..} <- get
    128   r0 <- lift (for chainPosition (const MWC.standardNormal))
    129   zc <- lift (MWC.uniform :: PrimMonad m => Prob m Double)
    130   let (q, r) = leapfrogIntegrator chainTarget e l (chainPosition, r0)
    131       perturbed      = nextState chainTarget (chainPosition, q) (r0, r) zc
    132       perturbedScore = lTarget chainTarget perturbed
    133   put (Chain chainTarget perturbedScore perturbed chainTunables)
    134 
    135 -- Calculate the next state of the chain.
    136 nextState
    137   :: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t
    138      , FunctorWithIndex (Index (s Double)) s, Ixed (s Double)
    139      , Ixed (t Double), IxValue (t Double) ~ Double
    140      , IxValue (s Double) ~ Double)
    141   => Target b
    142   -> (b, b)
    143   -> (s Double, t Double)
    144   -> Double
    145   -> b
    146 nextState target position momentum z
    147     | z < pAccept = snd position
    148     | otherwise   = fst position
    149   where
    150     pAccept = acceptProb target position momentum
    151 
    152 -- Calculate the acceptance probability of a proposed moved.
    153 acceptProb
    154   :: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t
    155      , FunctorWithIndex (Index (s Double)) s, Ixed (t Double)
    156      , Ixed (s Double), IxValue (t Double) ~ Double
    157      , IxValue (s Double) ~ Double)
    158   => Target a
    159   -> (a, a)
    160   -> (s Double, t Double)
    161   -> Double
    162 acceptProb target (q0, q1) (r0, r1) = exp . min 0 $
    163   auxilliaryTarget target (q1, r1) - auxilliaryTarget target (q0, r0)
    164 
    165 -- A momentum-augmented target.
    166 auxilliaryTarget
    167   :: (Foldable t, FunctorWithIndex (Index (t Double)) t
    168      , Ixed (t Double), IxValue (t Double) ~ Double)
    169   => Target a
    170   -> (a, t Double)
    171   -> Double
    172 auxilliaryTarget target (t, r) = f t - 0.5 * innerProduct r r where
    173   f = lTarget target
    174 
    175 innerProduct
    176   :: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s)
    177   => t (IxValue s) -> s -> IxValue s
    178 innerProduct xs ys = Foldable.sum $ gzipWith (*) xs ys
    179 
    180 -- A container-generic zipwith.
    181 gzipWith
    182   :: (FunctorWithIndex (Index s) f, Ixed s)
    183   => (a -> IxValue s -> b) -> f a -> s -> f b
    184 gzipWith f xs ys = imap (\j x -> f x (fromMaybe err (ys ^? ix j))) xs where
    185   err = error "gzipWith: invalid index"
    186 
    187 -- The leapfrog or Stormer-Verlet integrator.
    188 leapfrogIntegrator
    189   :: (Num (IxValue (f Double))
    190      , FunctorWithIndex (Index (f Double)) t
    191      , FunctorWithIndex (Index (t Double)) f
    192      , Ixed (f Double), Ixed (t Double)
    193      , IxValue (f Double) ~ Double
    194      , IxValue (t Double) ~ Double)
    195   => Target (f Double)
    196   -> Double
    197   -> Int
    198   -> (f Double, t (IxValue (f Double)))
    199   -> (f Double, t (IxValue (f Double)))
    200 leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where
    201   go q r 0 = (q, r)
    202   go q r n = go q1 r1 (pred n) where
    203     (q1, r1) = leapfrog target e (q, r)
    204 
    205 -- A single leapfrog step.
    206 leapfrog
    207   :: (Num (IxValue (f Double))
    208      , FunctorWithIndex (Index (f Double)) t
    209      , FunctorWithIndex (Index (t Double)) f
    210      , Ixed (t Double), Ixed (f Double)
    211      , IxValue (f Double) ~ Double, IxValue (t Double) ~ Double)
    212   => Target (f Double)
    213   -> Double
    214   -> (f Double, t (IxValue (f Double)))
    215   -> (f Double, t (IxValue (f Double)))
    216 leapfrog target e (q, r) = (qf, rf) where
    217   rm = adjustMomentum target e (q, r)
    218   qf = adjustPosition e (rm, q)
    219   rf = adjustMomentum target e (qf, rm)
    220 
    221 adjustMomentum
    222   :: (Functor f, Num (IxValue (f Double))
    223      , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
    224   => Target (f Double)
    225   -> Double
    226   -> (f Double, t (IxValue (f Double)))
    227   -> t (IxValue (f Double))
    228 adjustMomentum target e (q, r) = r .+ ((0.5 * e) .* g q) where
    229   g   = fromMaybe err (glTarget target)
    230   err = error "adjustMomentum: no gradient provided"
    231 
    232 adjustPosition
    233   :: (Functor f, Num (IxValue (f Double))
    234      , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
    235   => Double
    236   -> (f Double, t (IxValue (f Double)))
    237   -> t (IxValue (f Double))
    238 adjustPosition e (r, q) = q .+ (e .* r)
    239 
    240 -- Scalar-vector product.
    241 (.*) :: (Num a, Functor f) => a -> f a -> f a
    242 z .* xs = fmap (* z) xs
    243 
    244 -- Vector addition.
    245 (.+)
    246   :: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t)
    247   => f (IxValue t)
    248   -> t
    249   -> f (IxValue t)
    250 (.+) = gzipWith (+)
    251