Hamiltonian.hs (7792B)
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module: Numeric.MCMC.Hamiltonian
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- This implementation performs Hamiltonian Monte Carlo using an identity mass
-- matrix.
--
-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
-- while the 'hamiltonian' transition can be used for more flexible purposes,
-- such as working with samples in memory.
--
-- See <http://arxiv.org/pdf/1206.1901.pdf Neal, 2012> for the definitive
-- reference of the algorithm.

module Numeric.MCMC.Hamiltonian (
    mcmc
  , chain
  , hamiltonian

  -- * Re-exported
  , Target(..)
  , MWC.create
  , MWC.createSystemRandom
  , MWC.withSystemRandom
  , MWC.asGenIO
  ) where

import Control.Lens hiding (index)
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimState, PrimMonad)
import Control.Monad.Trans.State.Strict hiding (state)
import qualified Data.Foldable as Foldable (sum)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Data.Traversable (for)
import Pipes hiding (for, next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen)
import qualified System.Random.MWC.Probability as MWC

-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--
-- >>> withSystemRandom .
--     asGenIO $ mcmc 10000 0.05 20 [0, 0] target
--
-- The arguments are, in order: the number of states to emit, the integrator
-- step size, the number of leapfrog steps per transition, the starting
-- position, the target, and the generator.  Each emitted chain state is
-- written with 'print'.
mcmc
  :: ( MonadIO m, PrimMonad m
     , Num (IxValue (t Double)), Show (t Double), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , IxValue (t Double) ~ Double)
  => Int
  -> Double
  -> Int
  -> t Double
  -> Target (t Double)
  -> Gen (PrimState m)
  -> m ()
mcmc n step leaps chainPosition chainTarget gen =
    runEffect (states >-> Pipes.take n >-> Pipes.mapM_ (liftIO . print))
  where
    -- The unbounded stream of chain states; 'Pipes.take' bounds it to 'n'.
    states = drive step leaps origin gen

    -- Initial chain state: scored at the starting position, no tunables.
    origin = Chain
      chainTarget
      (lTarget chainTarget chainPosition)
      chainPosition
      Nothing

-- | Run a Markov chain for 'n' iterations and return the visited states as a
--   list, in the order they were produced.
--
-- >>> results <- withSystemRandom . asGenIO $ chain 1000 0.05 20 [0, 0] target
chain
  :: (PrimMonad m, Traversable f
     , FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
     , IxValue (f Double) ~ Double)
  => Int
  -> Double
  -> Int
  -> f Double
  -> Target (f Double)
  -> Gen (PrimState m)
  -> m [Chain (f Double) b]
chain n step leaps position target gen =
    runEffect (drive step leaps origin gen >-> collect n)
  where
    -- Initial chain state: scored at the starting position, no tunables.
    origin = Chain
      { chainTarget   = target
      , chainScore    = lTarget target position
      , chainPosition = position
      , chainTunables = Nothing
      }

    -- Await exactly 'size' values from upstream and return them as a list.
    collect :: Monad m => Int -> Consumer a m [a]
    collect size = lowerCodensity (replicateM size (lift Pipes.await))

-- Produce an unending stream of chain states: starting from the supplied
-- state, repeatedly run the 'hamiltonian' transition and yield every
-- resulting state downstream.
drive
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
     , PrimMonad m, IxValue (t Double) ~ Double)
  => Double
  -> Int
  -> Chain (t Double) b
  -> Gen (PrimState m)
  -> Producer (Chain (t Double) b) m c
drive step leaps = go
  where
    transition = hamiltonian step leaps

    go current prng = do
      updated <- lift (MWC.sample (execStateT transition current) prng)
      yield updated
      go updated prng

-- | A Hamiltonian transition operator.
hamiltonian
  :: (Num (IxValue (t Double)), Traversable t
     , FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m
     , IxValue (t Double) ~ Double)
  => Double -> Int -> Transition m (Chain (t Double) b)
hamiltonian e l = do
  Chain target _ position tunables <- get
  -- Draw order matters for reproducibility: momentum first, then the
  -- uniform variate used for the accept/reject decision.
  momentum  <- lift (for position (const MWC.standardNormal))
  threshold <- lift (MWC.uniform :: PrimMonad m => Prob m Double)
  let (proposal, momentum') =
        leapfrogIntegrator target e l (position, momentum)
      accepted =
        nextState target (position, proposal) (momentum, momentum') threshold
  -- The score is recomputed at whichever position was selected.
  put (Chain target (lTarget target accepted) accepted tunables)

-- Select the position the chain moves to: the proposal (second component of
-- the position pair) when the uniform variate 'z' falls below the acceptance
-- probability, and the current position (first component) otherwise.
nextState
  :: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (s Double)
     , Ixed (t Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target b
  -> (b, b)
  -> (s Double, t Double)
  -> Double
  -> b
nextState target position momentum z =
    if z < acceptProb target position momentum
      then proposed
      else current
  where
    (current, proposed) = position

-- Acceptance probability of a proposed move: the exponential of the
-- difference between the augmented target at the proposed (position,
-- momentum) pair and at the current pair, with that difference capped at 0
-- so the result never exceeds 1.
acceptProb
  :: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t
     , FunctorWithIndex (Index (s Double)) s, Ixed (t Double)
     , Ixed (s Double), IxValue (t Double) ~ Double
     , IxValue (s Double) ~ Double)
  => Target a
  -> (a, a)
  -> (s Double, t Double)
  -> Double
acceptProb target (q0, q1) (r0, r1) = exp (min 0 difference)
  where
    atProposal = auxilliaryTarget target (q1, r1)
    atCurrent  = auxilliaryTarget target (q0, r0)
    difference = atProposal - atCurrent

-- A momentum-augmented target.
--
-- Evaluates 'lTarget' at the position and subtracts half the inner product
-- of the momentum with itself (i.e. an identity mass matrix is used).
auxilliaryTarget
  :: (Foldable t, FunctorWithIndex (Index (t Double)) t
     , Ixed (t Double), IxValue (t Double) ~ Double)
  => Target a
  -> (a, t Double)
  -> Double
auxilliaryTarget target (t, r) = f t - 0.5 * innerProduct r r where
  f = lTarget target

-- Inner product of two containers: multiply elementwise via 'gzipWith' and
-- sum the products.  The first container determines which indices are
-- visited.
innerProduct
  :: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s)
  => t (IxValue s) -> s -> IxValue s
innerProduct xs ys = Foldable.sum $ gzipWith (*) xs ys

-- A container-generic zipwith.
--
-- Maps over the first container with its indices and, for each index, looks
-- up the matching element of the second container.  The result has the
-- shape of the first container.  Calls 'error' if an index of the first
-- container is not present in the second.
gzipWith
  :: (FunctorWithIndex (Index s) f, Ixed s)
  => (a -> IxValue s -> b) -> f a -> s -> f b
gzipWith f xs ys = imap (\j x -> f x (fromMaybe err (ys ^? ix j))) xs where
  err = error "gzipWith: invalid index"

-- The leapfrog or Stormer-Verlet integrator.
--
-- Applies 'leapfrog' with step size 'e' a total of 'l' times, starting from
-- the given (position, momentum) pair, and returns the final pair.  A leap
-- count of zero or less returns the starting pair unchanged; without the
-- '<=' guard a negative count would never reach the base case.
leapfrogIntegrator
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (f Double), Ixed (t Double)
     , IxValue (f Double) ~ Double
     , IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> Int
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where
  go q r n
    | n <= 0    = (q, r)
    | otherwise = go q1 r1 (pred n)
    where
      (q1, r1) = leapfrog target e (q, r)

-- A single leapfrog step.
--
-- Half-step update of the momentum, full-step update of the position using
-- that momentum, then a second half-step update of the momentum at the new
-- position.
leapfrog
  :: (Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t
     , FunctorWithIndex (Index (t Double)) f
     , Ixed (t Double), Ixed (f Double)
     , IxValue (f Double) ~ Double, IxValue (t Double) ~ Double)
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> (f Double, t (IxValue (f Double)))
leapfrog target e (q, r) =
  let rHalf = adjustMomentum target e (q, r)
      qNext = adjustPosition e (rHalf, q)
      rNext = adjustMomentum target e (qNext, rHalf)
  in  (qNext, rNext)

-- Half-step momentum update: add half the step size times the target's
-- gradient, evaluated at the given position, to the momentum.  Calls 'error'
-- when the target carries no gradient.
adjustMomentum
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Target (f Double)
  -> Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustMomentum target e (q, r) = r .+ (halfStep .* gradient q)
  where
    halfStep = 0.5 * e
    gradient = case glTarget target of
      Just g  -> g
      Nothing -> error "adjustMomentum: no gradient provided"

-- Full-step position update: add the step size times the momentum to the
-- position.  Note the argument pair is (momentum, position).
adjustPosition
  :: (Functor f, Num (IxValue (f Double))
     , FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
  => Double
  -> (f Double, t (IxValue (f Double)))
  -> t (IxValue (f Double))
adjustPosition e (r, q) = q .+ displacement
  where
    displacement = e .* r

-- Scale every element of a container by a scalar.
(.*) :: (Num a, Functor f) => a -> f a -> f a
z .* xs = scale <$> xs
  where
    scale x = x * z

-- Elementwise sum of two containers; the result has the shape of the left
-- operand.
(.+)
  :: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t)
  => f (IxValue t)
  -> t
  -> f (IxValue t)
xs .+ ys = gzipWith (+) xs ys