commit 26038098d57eb64c8afcac8fc2dacc1e3f3927d0
parent cc4df614316dd34b94823b42dfd86913d5b5a4af
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon,  9 Sep 2013 12:19:01 +1200
Add reference, split between NUTS and dual-averaging NUTS.
Diffstat:
| A | HoffmanGelman2011_NUTS.pdf | | | 0 |  | 
| M | NUTS.hs | | | 201 | ++++++++++++++++++++++++++++++++++++++++++------------------------------------- | 
| A | daNUTS.hs | | | 199 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | 
3 files changed, 305 insertions(+), 95 deletions(-)
diff --git a/HoffmanGelman2011_NUTS.pdf b/HoffmanGelman2011_NUTS.pdf
Binary files differ.
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,46 +1,69 @@
 -- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
 --   Lengths in Hamiltonian Monte Carlo.
 
+module NUTS where
+
 import Control.Monad
 import Control.Monad.Primitive
-import System.Random.MWC
+import System.Random.MWC -- FIXME change to Prob monad
 import System.Random.MWC.Distributions
 import Statistics.Distribution.Normal
 
-type Parameters = [Double]
+-- FIXME change to probably api
+type Parameters = [Double] 
 type Density    = Parameters -> Double
 type Gradient   = Parameters -> Parameters
 type Particle   = (Parameters, Parameters)
 
-leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
-leapfrogIntegrator n glTarget particle e = go particle n
-  where go state ndisc 
-          | ndisc <= 0 = state
-          | otherwise  = go (leapfrog glTarget state e) (pred n)
+-- FIXME must be streaming
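+--
+-- Run the requested number of NUTS transitions from the initial position,
+-- using a step size chosen by findReasonableEpsilon, and return only the
+-- final position.
+--
+-- An illustrative invocation (the target here is just an example, a standard
+-- normal in two dimensions):
+--
+--   lTarget  = negate . (* 0.5) . sum . map (^ 2)
+--   glTarget = map negate
+--
+--   withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 [0, 0]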
+nuts :: PrimMonad m
+     => Density
+     -> Gradient
+     -> Int
+     -> Parameters
+     -> Gen (PrimState m)
+     -> m Parameters
+nuts lTarget glTarget m t g = do
+  e <- findReasonableEpsilon lTarget glTarget t g
+  let go 0 t0 = return t0
+      go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
+
+  go m t
+
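+-- A single transition of the efficient No-U-Turn Sampler (Algorithm 3 in the
+-- reference).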
+nutsKernel :: PrimMonad m 
+           => Density
+           -> Gradient
+           -> Double
+           -> Parameters
+           -> Gen (PrimState m)
+           -> m Parameters
+nutsKernel lTarget glTarget e t g = do
+  r0 <- replicateM (length t) (normal 0 1 g)
+  u  <- uniformR (0, auxilliaryTarget lTarget t r0) g
 
-leapfrog :: Gradient -> Particle -> Double -> Particle
-leapfrog glTarget (t, r) e = (tf, rf)
-  where rm = zipWith (+) r  ((e / 2) .* glTarget t)
-        tf = zipWith (+) t  (e .* rm)
-        rf = zipWith (+) rm ((e / 2) .* glTarget tf)
+  let go (tn, tp, rn, rp, j, tm, n, s) g
+        | s == 1 = do
+            vj <- symmetricCategorical [-1, 1] g
+            z  <- uniform g
 
-findReasonableEpsilon :: PrimMonad m 
-                      => Density
-                      -> Gradient
-                      -> Parameters
-                      -> Gen (PrimState m) 
-                      -> m Double
-findReasonableEpsilon lTarget glTarget t0 g = do
-  r0 <- replicateM (length t0) (normal 0 1 g)
-  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
-      a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+            -- only the end of the trajectory in the chosen direction moves
+            (tnn, rnn, tpp, rpp, t1, n1, s1) <-
+              if   vj == -1
+              then do
+                (tnn', rnn', _, _, t1', n1', s1') <-
+                  buildTree lTarget glTarget g tn rn u vj j e
+                return (tnn', rnn', tp, rp, t1', n1', s1')
+              else do
+                (_, _, tpp', rpp', t1', n1', s1') <-
+                  buildTree lTarget glTarget g tp rp u vj j e
+                return (tn, rn, tpp', rpp', t1', n1', s1')
 
-      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = 
-                   let (tn, rn) = leapfrog glTarget (t, r) e
-                   in  go (2 ^ a * e) tn rn 
-               | otherwise = e
+            let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = t1
+                   | otherwise = tm
 
-  return $ go 1.0 t1 r1
+                n2 = n + n1
+                s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                        * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) 
+                j1 = succ j
+
+            go (tnn, tpp, rnn, rpp, j1, t2, n2, s2) g
+
+        | otherwise = return tm
+
+  go (t, t, r0, r0, 0, t, 1, 1) g
 
 buildTree 
   :: PrimMonad m 
@@ -53,84 +76,80 @@ buildTree
   -> Double
   -> Int
   -> Double
-  -> Parameters
-  -> Parameters
-  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
 buildTree lTarget glTarget g = go
-  where 
-    go t r u v 0 e _ r0 = return $
-      let (t1, r1) = leapfrog glTarget (t, r) (v * e)
-          n        = indicate (u <= auxilliaryTarget lTarget t1 r1)
-          s        = indicate (u <  exp 1000 * auxilliaryTarget lTarget t1 r1)
-          m        = min 1 (acceptanceRatio lTarget t1 r1 r0 r0)
-      in  (t1, r1, t1, r1, t1, n, s, m, 1)
-
-    go t r u v j e t0 r0 = do
+  where
+    -- base case: take a single leapfrog step in the direction v
+    go t r u v 0 e = return $
+      let (t0, r0) = leapfrog glTarget (t, r) (v * e)
+          auxTgt   = auxilliaryTarget lTarget t0 r0
+          n        = indicate (u <= auxTgt)
+          s        = indicate (log auxTgt > log u - 1000)
+      in  (t0, r0, t0, r0, t0, n, s)
+
+    go t r u v j e = do
       z <- uniform g
-      (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+      (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e
 
-      if   s1 == 1
+      if s0 == 1
       then do
-        (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+        (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
           if   v == -1
-          then go tn rn u v (pred j) e t0 r0
-          else go tp rp u v (pred j) e t0 r0
-
-        let p   = fromIntegral n2 / fromIntegral (n1 + n2)
-            n3  = n1 + n2
-            t3  | p > (z :: Double) = t2
-                | otherwise         = t1
-            a3  = a1 + a2
-            na3 = na1 + na2
-            s3  = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
-                     * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
-
-        return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
-      else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
-
-innerNutsKernel 
-    :: PrimMonad m 
-    => Density
-    -> Gradient
-    -> Parameters
-    -> Double
-    -> Gen (PrimState m)
-    -> m (Parameters, Double, Int)
-innerNutsKernel lTarget glTarget t e g = do
-  r0 <- replicateM (length t) (normal 0 1 g)
-  u  <- uniformR (0, auxilliaryTarget lTarget t r0) g
+          then do
+            (tnn', rnn', _, _, t1', n1', s1') <- go tn rn u v (pred j) e
+            return (tnn', rnn', tp, rp, t1', n1', s1')
+          else do
+            (_, _, tpp', rpp', t1', n1', s1') <- go tp rp u v (pred j) e
+            return (tn, rn, tpp', rpp', t1', n1', s1')
 
-  let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
-        | s == 1 = do
-            vj <- symmetricCategorical [-1, 1] g
-            z  <- uniform g
+        let p  = fromIntegral n1 / fromIntegral (n0 + n1)
+            n2 = n0 + n1
+            t2 | p > (z :: Double) = t1
+               | otherwise         = t0 
+            s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                    * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
 
-            (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
-              if   vj == -1
-              then buildTree lTarget glTarget g tn rn u vj j e t r0
-              else buildTree lTarget glTarget g tp rp u vj j e t r0
+        return (tnn, rnn, tpp, rpp, t2, n2, s2)
+      else return (tn, rn, tp, rp, t0, n0, s0)
 
-            let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
-                   | otherwise = t
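+-- Heuristic for choosing an initial step size (FindReasonableEpsilon in the
+-- reference): double or halve e until the acceptance ratio of a single
+-- leapfrog step crosses one half.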
+findReasonableEpsilon :: PrimMonad m 
+                      => Density
+                      -> Gradient
+                      -> Parameters
+                      -> Gen (PrimState m) 
+                      -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+  r0 <- replicateM (length t0) (normal 0 1 g)
+  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+      a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
 
-                n2 = n + n1
-                s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
-                        * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) 
-                j1 = succ j
+      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
+                   let (tn, rn) = leapfrog glTarget (t, r) e
+                   in  go (2 ^^ a * e) tn rn
+               | otherwise = e
+
+  return $ go 1.0 t1 r1
 
-            go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) a na g
+leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
+leapfrogIntegrator n glTarget particle e = go particle n
+  where go state ndisc 
+          | ndisc <= 0 = state
+          | otherwise  = go (leapfrog glTarget state e) (pred ndisc)
 
-        | otherwise = return (tm, aOrig, naOrig)
+leapfrog :: Gradient -> Particle -> Double -> Particle
+leapfrog glTarget (t, r) e = (tf, rf)
+  where rm = adjustMomentum glTarget e t r
+        tf = adjustPosition e rm t
+        rf = adjustMomentum glTarget e tf rm
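+        -- e.g. with glTarget = map negate (a standard normal target),
+        -- leapfrog glTarget ([0], [1]) 0.1 is approximately ([0.1], [0.995])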
 
-  go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
+adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t)
 
-auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
-auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+adjustPosition :: Num c => c -> [c] -> [c] -> [c]
+adjustPosition e r t = zipWith (+) t (e .* r)
 
 acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
 acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
                                     / auxilliaryTarget lTarget t0 r0
 
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
 innerProduct :: Num a => [a] -> [a] -> a
 innerProduct xs ys = sum $ zipWith (*) xs ys
 
@@ -153,11 +172,3 @@ symmetricCategorical zs g = do
 fi :: (Integral a, Num b) => a -> b
 fi = fromIntegral
 
--- Testing
-
-f :: Density
-f _ = log $ 1 / 10
-
-g :: Gradient
-g xs = replicate (length xs) 1
-
diff --git a/daNUTS.hs b/daNUTS.hs
@@ -0,0 +1,199 @@
+-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
+--   Lengths in Hamiltonian Monte Carlo.
+
+import Control.Monad
+import Control.Monad.Primitive
+import System.Random.MWC
+import System.Random.MWC.Distributions
+import Statistics.Distribution.Normal
+
+type Parameters = [Double]
+type Density    = Parameters -> Double
+type Gradient   = Parameters -> Parameters
+type Particle   = (Parameters, Parameters)
+
+leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
+leapfrogIntegrator n glTarget particle e = go particle n
+  where go state ndisc 
+          | ndisc <= 0 = state
+          | otherwise  = go (leapfrog glTarget state e) (pred ndisc)
+
+leapfrog :: Gradient -> Particle -> Double -> Particle
+leapfrog glTarget (t, r) e = (tf, rf)
+  where rm = zipWith (+) r  ((e / 2) .* glTarget t)
+        tf = zipWith (+) t  (e .* rm)
+        rf = zipWith (+) rm ((e / 2) .* glTarget tf)
+
+findReasonableEpsilon :: PrimMonad m 
+                      => Density
+                      -> Gradient
+                      -> Parameters
+                      -> Gen (PrimState m) 
+                      -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+  r0 <- replicateM (length t0) (normal 0 1 g)
+  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+      a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
+                   let (tn, rn) = leapfrog glTarget (t, r) e
+                   in  go (2 ^^ a * e) tn rn
+               | otherwise = e
+
+  return $ go 1.0 t1 r1
+
+-- this is the dual averaging buildTree
+buildTree 
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Gen (PrimState m)
+  -> Parameters
+  -> Parameters
+  -> Double
+  -> Double
+  -> Int
+  -> Double
+  -> Parameters
+  -> Parameters
+  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+buildTree lTarget glTarget g = go
+  where 
+    go t r u v 0 e t0 r0 = return $
+      let (t1, r1) = leapfrog glTarget (t, r) (v * e)
+          auxTgt   = auxilliaryTarget lTarget t1 r1
+          n        = indicate (u <= auxTgt)
+          s        = indicate (log auxTgt > log u - 1000)
+          m        = min 1 (acceptanceRatio lTarget t0 t1 r0 r1)
+      in  (t1, r1, t1, r1, t1, n, s, m, 1)
+
+    go t r u v j e t0 r0 = do
+      z <- uniform g
+      (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+      if   s1 == 1
+      then do
+        (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+          if   v == -1
+          then do
+            (tnn', rnn', _, _, t2', n2', s2', a2', na2') <-
+              go tn rn u v (pred j) e t0 r0
+            return (tnn', rnn', tp, rp, t2', n2', s2', a2', na2')
+          else do
+            (_, _, tpp', rpp', t2', n2', s2', a2', na2') <-
+              go tp rp u v (pred j) e t0 r0
+            return (tn, rn, tpp', rpp', t2', n2', s2', a2', na2')
+
+        let p   = fromIntegral n2 / fromIntegral (n1 + n2)
+            n3  = n1 + n2
+            t3  | p > (z :: Double) = t2
+                | otherwise         = t1
+            a3  = a1 + a2
+            na3 = na1 + na2
+            s3  = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                     * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+        return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+      else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+relaxingNuts = undefined
+
+-- better idea: wrap this dual averaging scheme around the actual nuts
+--              kernel itself.  in fact you'd like to just be able to loosely
+--              add dual-averaging to any procedure.
+--
+-- adaptingNutsKernel lTarget glTarget t m g = do
+--   e0 <- findReasonableEpsilon lTarget glTarget t g
+--
+--   let mu      = log (10 * e0)
+--       epsBar0 = 0
+--       h0Bar   = 0
+--       gamma   = 0.05
+--       delta   = 0.45 -- target mean acceptance probability
+--       tau0    = 10
+--       kappa   = 0.75
+--
+--       go hBar eNext logEpsBar tToReturn n
+--         | n <= 0    = return (tToReturn, exp logEpsBar)
+--
+--         | otherwise = do
+--             (t0, a, na) <- innerNutsKernel lTarget glTarget tToReturn eNext g
+--             let hBarNext = (1 - 1 / (m - n + tau0)) * hBar
+--                          + (1 / (m - n + tau0)) * (delta - a / fromIntegral na)
+--
+--                 logEpsNext = mu - (sqrt (m - n) / gamma) * hBarNext
+--
+--                 logEpsBarNext = (m - n) ** (-kappa) * logEpsNext
+--                               + (1 - (m - n) ** (-kappa)) * logEpsBar
+--
+--             go hBarNext (exp logEpsNext) logEpsBarNext t0 (pred n)
+--
+--   go h0Bar e0 epsBar0 t m
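+
+-- A rough, illustrative sketch of the wrapper described above (the name
+-- dualAverage is just for this sketch, not from the reference; constants as
+-- in the sketch): fold the dual-averaging update over any transition kernel
+-- that takes a step size and reports (position, alpha, nAlpha), as
+-- innerNutsKernel does.
+dualAverage
+  :: Monad m
+  => (Double -> Parameters -> m (Parameters, Double, Int)) -- kernel at a given step size
+  -> Double                                                -- initial step size
+  -> Int                                                    -- number of adaptation steps
+  -> Parameters                                             -- initial position
+  -> m (Parameters, Double)                                 -- final position, adapted step size
+dualAverage kernel e0 nAdapt t0 = go 1 0 (log e0) 0 t0
+  where
+    mu    = log (10 * e0)
+    gamma = 0.05
+    delta = 0.45 -- target mean acceptance probability
+    tau0  = 10
+    kappa = 0.75
+
+    go m hBar logEps logEpsBar t
+      | m > nAdapt = return (t, exp logEpsBar)
+      | otherwise  = do
+          (t1, a, na) <- kernel (exp logEps) t
+          let w          = 1 / (fromIntegral m + tau0)
+              hBar1      = (1 - w) * hBar + w * (delta - a / fromIntegral na)
+              logEps1    = mu - (sqrt (fromIntegral m) / gamma) * hBar1
+              eta        = fromIntegral m ** negate kappa
+              logEpsBar1 = eta * logEps1 + (1 - eta) * logEpsBar
+          go (succ m) hBar1 logEps1 logEpsBar1 t1
+
+-- e.g. dualAverage (\e t -> innerNutsKernel lTarget glTarget t e g) e0 1000 t0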
+
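+-- A single NUTS transition that additionally reports the acceptance
+-- statistics (alpha, nAlpha) consumed by the dual-averaging update.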
+innerNutsKernel 
+    :: PrimMonad m 
+    => Density
+    -> Gradient
+    -> Parameters
+    -> Double
+    -> Gen (PrimState m)
+    -> m (Parameters, Double, Int)
+innerNutsKernel lTarget glTarget t e g = do
+  r0 <- replicateM (length t) (normal 0 1 g)
+  u  <- uniformR (0, auxilliaryTarget lTarget t r0) g
+
+  let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
+        | s == 1 = do
+            vj <- symmetricCategorical [-1, 1] g
+            z  <- uniform g
+
+            -- only the end of the trajectory in the chosen direction moves
+            (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
+              if   vj == -1
+              then do
+                (tnn', rnn', _, _, t1', n1', s1', a', na') <-
+                  buildTree lTarget glTarget g tn rn u vj j e t r0
+                return (tnn', rnn', tp, rp, t1', n1', s1', a', na')
+              else do
+                (_, _, tpp', rpp', t1', n1', s1', a', na') <-
+                  buildTree lTarget glTarget g tp rp u vj j e t r0
+                return (tn, rn, tpp', rpp', t1', n1', s1', a', na')
+
+            let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = t1
+                   | otherwise = tm
+
+                n2 = n + n1
+                s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                        * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) 
+                j1 = succ j
+
+            go (tnn, tpp, rnn, rpp, j1, t2, n2, s2) a na g
+
+        | otherwise = return (tm, aOrig, naOrig)
+
+  go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
+acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
+acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
+                                    / auxilliaryTarget lTarget t0 r0
+
+innerProduct :: Num a => [a] -> [a] -> a
+innerProduct xs ys = sum $ zipWith (*) xs ys
+
+(.*) :: Num b => b -> [b] -> [b]
+z .* xs = map (* z) xs
+
+(.-) :: Num a => [a] -> [a] -> [a]
+xs .- ys = zipWith (-) xs ys
+
+indicate :: Integral a => Bool -> a
+indicate True  = 1
+indicate False = 0
+
+symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
+symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
+symmetricCategorical zs g = do
+  z <- uniform g
+  return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+
+-- Testing
+
+f :: Density
+f _ = log $ 1 / 10
+
+g :: Gradient
+g xs = replicate (length xs) 0
+