hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

NUTS.hs (12092B)


      1 -- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
      2 --   Lengths in Hamiltonian Monte Carlo.
      3 
      4 module Numeric.MCMC.NUTS where
      5 
      6 import Control.Monad
      7 import Control.Monad.Loops
      8 import Control.Monad.Primitive
      9 import System.Random.MWC
     10 import System.Random.MWC.Distributions hiding (gamma)
     11 import Statistics.Distribution.Normal
     12 
     13 type Parameters = [Double] 
     14 type Density    = Parameters -> Double
     15 type Gradient   = Parameters -> Parameters
     16 type Particle   = (Parameters, Parameters)
     17 type StateSpecs = ([Double], [Double], [Double], [Double], [Double], Int, Int)
     18 
     19 newtype StateTree = StateTree { 
     20     getStateTree :: StateSpecs
     21   }
     22 
     23 instance Show StateTree where
     24   show (StateTree (tm, rm, tp, rp, t', n, s)) = 
     25        "\n" ++ "tm: " ++ show tm 
     26     ++ "\n" ++ "tp: " ++ show tp
     27     ++ "\n" ++ "t': " ++ show t'
     28     ++ "\n" ++ "n : " ++ show n
     29     ++ "\n" ++ "s : " ++ show s
     30 
     31 data DualAveragingParameters = DualAveragingParameters {
     32     mAdapt :: Int
     33   , delta  :: Double
     34   , mu     :: Double
     35   , gamma  :: Double
     36   , tau0   :: Double
     37   , kappa  :: Double
     38   } deriving Show
     39 
     40 -- | The NUTS sampler.
     41 nuts 
     42   :: PrimMonad m
     43   => Density
     44   -> Gradient
     45   -> Int
     46   -> Double
     47   -> Parameters
     48   -> Gen (PrimState m)
     49   -> m [Parameters]
     50 nuts lTarget glTarget n e t g = go t 0 []
     51   where go position j acc
     52           | j >= n    = return acc
     53           | otherwise = do
     54               nextPosition <- nutsKernel lTarget glTarget e position g
     55               go nextPosition (succ j) (nextPosition : acc)
     56 
     57 -- | The NUTS sampler with dual averaging.
     58 nutsDualAveraging
     59   :: PrimMonad m
     60   => Density
     61   -> Gradient
     62   -> Int
     63   -> Int
     64   -> Parameters
     65   -> Gen (PrimState m)
     66   -> m [Parameters]
     67 nutsDualAveraging lTarget glTarget n nAdapt t g = do
     68     e0 <- findReasonableEpsilon lTarget glTarget t g
     69     let daParams = basicDualAveragingParameters e0 nAdapt
     70     chain <- unfoldrM (kernel daParams) (1, e0, 1, 0, t)
     71     return $ drop nAdapt chain
     72   where
     73     kernel params (m, e, eAvg, h, t0) = do
     74       (eNext, eAvgNext, hNext, tNext) <- 
     75         nutsKernelDualAvg lTarget glTarget e eAvg h m params t0 g
     76       return $ if   m > n + nAdapt
     77                then Nothing
     78                else Just (t0, (succ m, eNext, eAvgNext, hNext, tNext))
     79 
     80 -- | Default DA parameters, given a base step size and burn in period.
     81 basicDualAveragingParameters :: Double -> Int -> DualAveragingParameters
     82 basicDualAveragingParameters step burnInPeriod = DualAveragingParameters {
     83     mu     = log (10 * step)
     84   , delta  = 0.5
     85   , mAdapt = burnInPeriod
     86   , gamma  = 0.05
     87   , tau0   = 10
     88   , kappa  = 0.75
     89   }
     90 
     91 -- | A single iteration of dual-averaging NUTS.
     92 nutsKernelDualAvg 
     93   :: PrimMonad m 
     94   => Density
     95   -> Gradient
     96   -> Double
     97   -> Double
     98   -> Double
     99   -> Int
    100   -> DualAveragingParameters
    101   -> [Double]
    102   -> Gen (PrimState m)
    103   -> m (Double, Double, Double, Parameters)
    104 nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
    105   r0 <- replicateM (length t) (normal 0 1 g)
    106   z0 <- exponential 1 g
    107   let logu = log (auxilliaryTarget lTarget t r0) - z0
    108 
    109   let go (tn, tp, rn, rp, tm, j, n, s, a, na) g
    110         | s == 1 = do
    111             vj <- symmetricCategorical [-1, 1] g
    112             z  <- uniform g
    113 
    114             (tnn, rnn, tpp, rpp, t1, n1, s1, a1, na1) <-
    115               if   vj == -1
    116               then do
    117                 (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- 
    118                   buildTreeDualAvg lTarget glTarget g tn rn logu vj j e t r0
    119                 return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
    120               else do
    121                 (_, _, tpp', rpp', t1', n1', s1', a1', na1') <- 
    122                   buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0
    123                 return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
    124 
    125             let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z 
    126 
    127                 n2 = n + n1
    128                 s2 = s1 * stopCriterion tnn tpp rnn rpp
    129                 j1 = succ j
    130                 t2 | accept    = t1
    131                    | otherwise = tm
    132 
    133             go (tnn, tpp, rnn, rpp, t2, j1, n2, s2, a1, na1) g
    134 
    135         | otherwise = return (tm, a, na)
    136 
    137   (nextPosition, alpha, nalpha) <- go (t, t, r0, r0, t, 0, 1, 1, 0, 0) g
    138   
    139   let (hNext, eNext, eAvgNext) =
    140           if   m <= mAdapt daParams
    141           then (hm, exp logEm, exp logEbarM)
    142           else (h, eAvg, eAvg)
    143         where
    144           eta = 1 / (fromIntegral m + tau0 daParams)
    145           hm  = (1 - eta) * h 
    146               + eta * (delta daParams - alpha / fromIntegral nalpha)
    147 
    148           zeta = fromIntegral m ** (- (kappa daParams))
    149 
    150           logEm    = mu daParams - sqrt (fromIntegral m) / gamma daParams * hm
    151           logEbarM = (1 - zeta) * log eAvg + zeta * logEm
    152 
    153   return (eNext, eAvgNext, hNext, nextPosition)
    154 
    155 -- | A single iteration of NUTS.
    156 nutsKernel 
    157   :: PrimMonad m 
    158   => Density
    159   -> Gradient
    160   -> Double
    161   -> Parameters
    162   -> Gen (PrimState m)
    163   -> m Parameters
    164 nutsKernel lTarget glTarget e t g = do
    165   r0   <- replicateM (length t) (normal 0 1 g)
    166   z0   <- exponential 1 g
    167   let logu = log (auxilliaryTarget lTarget t r0) - z0
    168 
    169   let go (tn, tp, rn, rp, tm, j, n, s) g
    170         | s == 1 = do
    171             vj <- symmetricCategorical [-1, 1] g
    172             z  <- uniform g
    173 
    174             (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
    175               if   vj == -1
    176               then do
    177                 (tnn', rnn', _, _, t1', n1', s1') <- 
    178                   buildTree lTarget glTarget g tn rn logu vj j e
    179                 return (tnn', rnn', tp, rp, t1', n1', s1')
    180               else do
    181                 (_, _, tpp', rpp', t1', n1', s1') <- 
    182                   buildTree lTarget glTarget g tp rp logu vj j e
    183                 return (tn, rn, tpp', rpp', t1', n1', s1')
    184 
    185             let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z
    186 
    187                 n2 = n + n1
    188                 s2 = s1 * stopCriterion tnn tpp rnn rpp
    189                 j1 = succ j
    190                 t2 | accept    = t1
    191                    | otherwise = tm
    192 
    193             go (tnn, tpp, rnn, rpp, t2, j1, n2, s2) g
    194 
    195         | otherwise = return tm
    196 
    197   go (t, t, r0, r0, t, 0, 1, 1) g
    198 
    199 -- | Build the 'tree' of candidate states.
    200 buildTree 
    201   :: PrimMonad m 
    202   => Density
    203   -> Gradient
    204   -> Gen (PrimState m)
    205   -> Parameters
    206   -> Parameters
    207   -> Double
    208   -> Double
    209   -> Int
    210   -> Double
    211   -> m StateSpecs
    212 buildTree lTarget glTarget g t r logu v 0 e = do
    213   let (t0, r0) = leapfrog glTarget (t, r) (v * e)
    214       joint    = log $ auxilliaryTarget lTarget t0 r0
    215       n        = indicate (logu < joint)
    216       s        = indicate (logu - 1000 < joint)
    217   return (t0, r0, t0, r0, t0, n, s)
    218 
    219 buildTree lTarget glTarget g t r logu v j e = do
    220   z <- uniform g
    221   (tn, rn, tp, rp, t0, n0, s0) <- 
    222     buildTree lTarget glTarget g t r logu v (pred j) e
    223 
    224   if   s0 == 1
    225   then do
    226     (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
    227       if   v == -1
    228       then do
    229         (tnn', rnn', _, _, t1', n1', s1') <- 
    230           buildTree lTarget glTarget g tn rn logu v (pred j) e
    231         return (tnn', rnn', tp, rp, t1', n1', s1')
    232       else do
    233         (_, _, tpp', rpp', t1', n1', s1') <- 
    234           buildTree lTarget glTarget g tp rp logu v (pred j) e
    235         return (tn, rn, tpp', rpp', t1', n1', s1')
    236 
    237     let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double)
    238         n2     = n0 + n1
    239         s2     = s0 * s1 * stopCriterion tnn tpp rnn rpp
    240         t2     | accept    = t1
    241                | otherwise = t0 
    242 
    243     return (tnn, rnn, tpp, rpp, t2, n2, s2)
    244   else return (tn, rn, tp, rp, t0, n0, s0)
    245 
    246 -- | Determine whether or not to stop doubling the tree of candidate states.
    247 stopCriterion :: (Integral a, Num b, Ord b) => [b] -> [b] -> [b] -> [b] -> a
    248 stopCriterion tn tp rn rp = 
    249       indicate (positionDifference `innerProduct` rn >= 0)
    250     * indicate (positionDifference `innerProduct` rp >= 0)
    251   where
    252     positionDifference = tp .- tn
    253 
    254 -- | Build the tree of candidate states under dual averaging.
    255 buildTreeDualAvg
    256   :: PrimMonad m 
    257   => Density
    258   -> Gradient
    259   -> Gen (PrimState m)
    260   -> Parameters
    261   -> Parameters
    262   -> Double
    263   -> Double
    264   -> Int
    265   -> Double
    266   -> Parameters
    267   -> Parameters
    268   -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
    269 buildTreeDualAvg lTarget glTarget g t r logu v 0 e t0 r0 = do
    270   let (t1, r1) = leapfrog glTarget (t, r) (v * e)
    271       joint    = log $ auxilliaryTarget lTarget t1 r1
    272       n        = indicate (logu < joint)
    273       s        = indicate (logu - 1000 <  joint)
    274       a        = min 1 (acceptanceRatio lTarget t0 t1 r0 r1)
    275   return (t1, r1, t1, r1, t1, n, s, a, 1)
    276       
    277 buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do
    278   z <- uniform g
    279   (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- 
    280     buildTreeDualAvg lTarget glTarget g t r logu v (pred j) e t0 r0
    281 
    282   if   s1 == 1
    283   then do
    284     (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
    285       if   v == -1
    286       then do 
    287         (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- 
    288           buildTreeDualAvg lTarget glTarget g tn rn logu v (pred j) e t0 r0
    289         return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
    290       else do
    291         (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
    292           buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0
    293         return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
    294 
    295     let p      = fi n2 / max (fi (n1 + n2)) 1
    296         accept = p > (z :: Double)
    297         n3     = n1 + n2
    298         a3     = a1 + a2
    299         na3    = na1 + na2
    300         s3     = s1 * s2 * stopCriterion tnn tpp rnn rpp
    301 
    302         t3  | accept    = t2
    303             | otherwise = t1
    304 
    305     return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
    306   else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
    307 
    308 -- | Heuristic for initializing step size.
    309 findReasonableEpsilon 
    310   :: PrimMonad m 
    311   => Density
    312   -> Gradient
    313   -> Parameters
    314   -> Gen (PrimState m) 
    315   -> m Double
    316 findReasonableEpsilon lTarget glTarget t0 g = do
    317   r0 <- replicateM (length t0) (normal 0 1 g)
    318   let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
    319       a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
    320 
    321       go j e t r 
    322         | j <= 0 = e -- no need to shrink this excessively
    323         | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) = 
    324             let (tn, rn) = leapfrog glTarget (t, r) e
    325             in  go (pred j) (2 ^^ a * e) tn rn 
    326         | otherwise = e
    327 
    328   return $ go 10 1.0 t1 r1
    329 
    330 -- | Simulate a single step of Hamiltonian dynamics.
    331 leapfrog :: Gradient -> Particle -> Double -> Particle
    332 leapfrog glTarget (t, r) e = (tf, rf)
    333   where 
    334     rm = adjustMomentum glTarget e t r
    335     tf = adjustPosition e rm t
    336     rf = adjustMomentum glTarget e tf rm
    337 
    338 -- | Adjust momentum.
    339 adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
    340 adjustMomentum glTarget e t r = r .+ ((e / 2) .* glTarget t)
    341 
    342 -- | Adjust position.
    343 adjustPosition :: Num c => c -> [c] -> [c] -> [c]
    344 adjustPosition e r t = t .+ (e .* r)
    345 
    346 -- | The MH acceptance ratio for a given proposal.
    347 acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
    348 acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
    349                                     / auxilliaryTarget lTarget t0 r0
    350 
    351 -- | The negative potential. 
    352 auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
    353 auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
    354 
    355 -- | Simple inner product.
    356 innerProduct :: Num a => [a] -> [a] -> a
    357 innerProduct xs ys = sum $ zipWith (*) xs ys
    358 
    359 -- | Vectorized multiplication.
    360 (.*) :: Num b => b -> [b] -> [b]
    361 z .* xs = map (* z) xs
    362 
    363 -- | Vectorized subtraction.
    364 (.-) :: Num a => [a] -> [a] -> [a]
    365 xs .- ys = zipWith (-) xs ys
    366 
    367 -- | Vectorized addition.
    368 (.+) :: Num a => [a] -> [a] -> [a]
    369 xs .+ ys = zipWith (+) xs ys
    370 
    371 -- | Indicator function.
    372 indicate :: Integral a => Bool -> a
    373 indicate True  = 1
    374 indicate False = 0
    375 
    376 -- | A symmetric categorical (discrete uniform) distribution.
    377 symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
    378 symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
    379 symmetricCategorical zs g = do
    380   j <- uniformR (0, length zs - 1) g
    381   return $ zs !! j
    382 
    383 -- | Alias for fromIntegral.
    384 fi :: (Integral a, Num b) => a -> b
    385 fi = fromIntegral
    386