commit e96bccefe000f2fd58ce8ee5fa14b83b9efb9ecc
parent 7807cf266fc3ea20248a1aee62b49bee3acfb64d
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon, 14 Oct 2013 10:41:02 +1300
Add dual-averaging code to module.
Diffstat:
1 file changed, 179 insertions(+), 3 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -7,7 +7,7 @@ import Control.Monad
 import Control.Monad.Loops
 import Control.Monad.Primitive
 import System.Random.MWC
-import System.Random.MWC.Distributions
+import System.Random.MWC.Distributions hiding (gamma)
 import Statistics.Distribution.Normal
 
 type Parameters = [Double] 
@@ -22,13 +22,20 @@ newtype BuildTree = BuildTree {
 instance Show BuildTree where
   show (BuildTree (tm, rm, tp, rp, t', n, s)) = 
        "\n" ++ "tm: " ++ show tm 
-    -- ++ "\n" ++ "rm: " ++ show rm
     ++ "\n" ++ "tp: " ++ show tp
-    -- ++ "\n" ++ "rp: " ++ show rp
     ++ "\n" ++ "t': " ++ show t'
     ++ "\n" ++ "n : " ++ show n
     ++ "\n" ++ "s : " ++ show s
 
+data DualAveragingParameters = DualAveragingParameters {
+    mAdapt :: Int
+  , delta  :: Double
+  , mu     :: Double
+  , gamma  :: Double
+  , tau0   :: Double
+  , kappa  :: Double
+  } deriving Show
+
 -- | The NUTS sampler.
 nuts 
   :: PrimMonad m
@@ -47,6 +54,99 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
                then Nothing
                else Just (t0, (pred n, t1))
 
+-- | The NUTS sampler with dual averaging.
+nutsDualAveraging
+  :: PrimMonad m
+  => Density
+  -> Gradient
+  -> Int
+  -> Int
+  -> Parameters
+  -> Gen (PrimState m)
+  -> m [Parameters]
+nutsDualAveraging lTarget glTarget n nAdapt t g = do
+    e0 <- findReasonableEpsilon lTarget glTarget t g
+    let daParams = DualAveragingParameters {
+            mu     = log (10 * e0)
+          , delta  = 0.5
+          , mAdapt = nAdapt
+          , gamma  = 0.05
+          , tau0   = 10
+          , kappa  = 0.75
+          }
+
+    unfoldrM (kernel daParams) (0, e0, 0, 0, t)
+  where
+    kernel params (m, e, eAvg, h, t0) = do
+      (eNext, eAvgNext, hNext, tNext) <- 
+        nutsKernelDualAvg lTarget glTarget e eAvg h m params t0 g
+      return $ if   m >= n
+               then Nothing
+               else Just (t0, (succ m, eNext, eAvgNext, hNext, tNext))
+
+-- | A single iteration of dual-averaging NUTS.
+nutsKernelDualAvg 
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Double
+  -> Double
+  -> Double
+  -> Int
+  -> DualAveragingParameters
+  -> [Double]
+  -> Gen (PrimState m)
+  -> m (Double, Double, Double, [Double])
+nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
+  r0 <- replicateM (length t) (normal 0 1 g)
+  z0 <- exponential 1 g
+  let logu = auxilliaryTarget lTarget t r0 - z0
+
+  let go (tn, tp, rn, rp, j, tm, n, s, a, na) g
+        | s == 1 = do
+            vj <- symmetricCategorical [-1, 1] g
+            z  <- uniform g
+
+            (tnn, rnn, tpp, rpp, t1, n1, s1, a1, na1) <-
+              if   vj == -1
+              then do
+                (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- 
+                  buildTreeDualAvg lTarget glTarget g tn rn logu vj j e t r0
+                return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
+              else do
+                (_, _, tpp', rpp', t1', n1', s1', a1', na1') <- 
+                  buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0
+                return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
+
+            let t2 | s1 == 1 
+                  && (fi n1 / fi n :: Double) > z = t1
+                   | otherwise                    = tm
+
+                n2 = n + n1
+                s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                        * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) 
+                j1 = succ j
+
+            go (tnn, rnn, tpp, rpp, j1, t2, n2, s2, a1, na1) g
+
+        | otherwise = return (tm, a)
+
+  (nextPosition, prob) <- go (t, t, r0, r0, 0, t, 1, 1, 0, 0) g
+  
+  let (hNext, eNext, eAvgNext) =
+          if   m <= mAdapt daParams
+          then (hm, exp logEm, exp logEbarM)
+          else (h, eAvg, eAvg)
+        where
+          hm = (1 - 1 / (fromIntegral m + tau0 daParams)) * h 
+             + (1 / (fromIntegral m + tau0 daParams)) * (delta daParams - prob)
+
+          logEm    = mu daParams - (sqrt (fromIntegral m) / gamma daParams) * hm
+          logEbarM = fromIntegral m ** (- (kappa daParams)) * logEm 
+                   + (1 - fromIntegral m ** (- (kappa daParams))) * (log eAvg)
+
+  return (e, eAvg, h, nextPosition)
+
 -- | A single iteration of NUTS.
 nutsKernel 
   :: PrimMonad m 
@@ -140,6 +240,82 @@ buildTree lTarget glTarget g t r logu v j e = do
     return (tnn, rnn, tpp, rpp, t2, n2, s2)
   else return (tn, rn, tp, rp, t0, n0, s0)
 
+-- | Build the tree of candidate states under dual averaging.
+buildTreeDualAvg
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Gen (PrimState m)
+  -> Parameters
+  -> Parameters
+  -> Double
+  -> Double
+  -> Int
+  -> Double
+  -> Parameters
+  -> Parameters
+  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+buildTreeDualAvg lTarget glTarget g = go
+  where 
+    go t r logu v 0 e t0 r0 = return $
+      let (t1, r1)   = leapfrog glTarget (t, r) (v * e)
+          lAuxTarget = log $ auxilliaryTarget lTarget t1 r1
+          n          = indicate (logu <= lAuxTarget)
+          s          = indicate (logu - 1000 <  lAuxTarget)
+          m          = min 1 (acceptanceRatio lTarget t1 r1 t0 r0)
+      in  (t1, r1, t1, r1, t1, n, s, m, 1)
+
+    go t r u v j e t0 r0 = do
+      z <- uniform g
+      (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+      if   s1 == 1
+      then do
+        (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+          if   v == -1
+          then do 
+            (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- 
+              go tn rn u v (pred j) e t0 r0
+            return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
+          else do
+            (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
+              go tp rp u v (pred j) e t0 r0
+            return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
+
+        let p   = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1
+            n3  = n1 + n2
+            a3  = a1 + a2
+            na3 = na1 + na2
+
+            s3  = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                     * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+            t3  | p > (z :: Double) = t2
+                | otherwise         = t1
+
+        return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+      else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+-- | Heuristic for initializing step size.
+findReasonableEpsilon 
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Parameters
+  -> Gen (PrimState m) 
+  -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+  r0 <- replicateM (length t0) (normal 0 1 g)
+  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+      a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = 
+                   let (tn, rn) = leapfrog glTarget (t, r) e
+                   in  go (2 ^ a * e) tn rn 
+               | otherwise = e
+
+  return $ go 1.0 t1 r1
+
 -- | Simulate a single step of Hamiltonian dynamics.
 leapfrog :: Gradient -> Particle -> Double -> Particle
 leapfrog glTarget (t, r) e = (tf, rf)