commit ad7925d2aa33a7ed0f6822370ddc12a746e40f03
parent b7b3b030cea290e3a5d1876eab9cc31d75158410
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun, 13 Oct 2013 21:47:15 +1300
Switch buildTree to a fully recursive version.
Diffstat:
2 files changed, 68 insertions(+), 73 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -6,13 +6,10 @@ module Numeric.MCMC.NUTS where
 import Control.Monad
 import Control.Monad.Loops
 import Control.Monad.Primitive
-import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
+import System.Random.MWC
 import System.Random.MWC.Distributions
 import Statistics.Distribution.Normal
 
-import Debug.Trace
-
--- FIXME change to probably api
 type Parameters = [Double] 
 type Density    = Parameters -> Double
 type Gradient   = Parameters -> Parameters
@@ -25,13 +22,14 @@ newtype BuildTree = BuildTree {
 instance Show BuildTree where
   show (BuildTree (tm, rm, tp, rp, t', n, s)) = 
        "\n" ++ "tm: " ++ show tm 
-    ++ "\n" ++ "rm: " ++ show rm
+    -- ++ "\n" ++ "rm: " ++ show rm
     ++ "\n" ++ "tp: " ++ show tp
-    ++ "\n" ++ "rp: " ++ show rp
+    -- ++ "\n" ++ "rp: " ++ show rp
     ++ "\n" ++ "t': " ++ show t'
     ++ "\n" ++ "n : " ++ show n
     ++ "\n" ++ "s : " ++ show s
 
+-- | The NUTS sampler.
 nuts 
   :: PrimMonad m
   => Density
@@ -49,6 +47,7 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
                then Nothing
                else Just (t0, (pred n, t1))
 
+-- | A single iteration of NUTS.
 nutsKernel 
   :: PrimMonad m 
   => Density
@@ -78,8 +77,9 @@ nutsKernel lTarget glTarget e t g = do
                   buildTree lTarget glTarget g tp rp logu vj j e
                 return (tn, rn, tpp', rpp', t1', n1', s1')
 
-            let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1
-                   | otherwise = t
+            let t2 | s1 == 1 
+                  && (fi n1 / fi n :: Double) > z = t1
+                   | otherwise                    = tm
 
                 n2 = n + n1
                 s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
@@ -92,6 +92,7 @@ nutsKernel lTarget glTarget e t g = do
 
   go (t, t, r0, r0, 0, t, 1, 1) g
 
+-- | Build the 'tree' of candidate states.
 buildTree 
   :: PrimMonad m 
   => Density
@@ -104,103 +105,98 @@ buildTree
   -> Int
   -> Double
   -> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
-buildTree lTarget glTarget g = go
-  where
-    go t r logu v 0 e  = return $
-      let (t0, r0)  = leapfrog glTarget (t, r) (v * e)
-          auxTarget = auxilliaryTarget lTarget t0 r0
-          n         = indicate (logu < auxTarget)
-          s         = indicate (logu - 1000 < auxTarget)
-      in  (t0, r0, t0, r0, t0, n, s)
-
-    go t r logu v j e = do
-      z <- uniform g
-      (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e
-
-      if   s0 == 1
+buildTree lTarget glTarget g t r logu v 0 e = do
+  let (t0, r0)   = leapfrog glTarget (t, r) (v * e)
+      lAuxTarget = log $ auxilliaryTarget lTarget t0 r0
+      n          = indicate (logu < lAuxTarget)
+      s          = indicate (logu - 1000 < lAuxTarget)
+  return (t0, r0, t0, r0, t0, n, s)
+
+buildTree lTarget glTarget g t r logu v j e = do
+  z <- uniform g
+  (tn, rn, tp, rp, t0, n0, s0) <- 
+    buildTree lTarget glTarget g t r logu v (pred j) e
+
+  if   s0 == 1
+  then do
+    (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
+      if   v == -1
       then do
-        (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
-          if   v == -1
-          then do
-            (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e
-            return (tnn', rnn', tp, rp, t1', n1', s1')
-          else do
-            (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e
-            return (tn, rn, tpp', rpp', t1', n1', s1')
-
-        let p  = fromIntegral n1 / fromIntegral (n0 + n1)
-            n2 = n0 + n1
-            t2 | p > (z :: Double) = t1
-               | otherwise         = t0 
-            s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+        (tnn', rnn', _, _, t1', n1', s1') <- 
+          buildTree lTarget glTarget g tn rn logu v (pred j) e
+        return (tnn', rnn', tp, rp, t1', n1', s1')
+      else do
+        (_, _, tpp', rpp', t1', n1', s1') <- 
+          buildTree lTarget glTarget g tp rp logu v (pred j) e
+        return (tn, rn, tpp', rpp', t1', n1', s1')
+
+    let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double)
+        n2     = n0 + n1
+        t2     | accept    = t1
+               | otherwise = t0 
+        s2     = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
                     * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
 
-        return (tnn, rnn, tpp, rpp, t2, n2, s2)
-      else return (tn, rn, tp, rp, t0, n0, s0)
-
-findReasonableEpsilon
-  :: PrimMonad m 
-  => Density
-  -> Gradient
-  -> Parameters
-  -> Gen (PrimState m) 
-  -> m Double
-findReasonableEpsilon lTarget glTarget t0 g = do
-  r0 <- replicateM (length t0) (normal 0 1 g)
-  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
-      a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
-
-      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) = 
-                   let (tn, rn) = leapfrog glTarget (t, r) e
-                   in  go (2 ^^ a * e) tn rn 
-               | otherwise = e
-
-  return $ go 1.0 t1 r1
+    return (tnn, rnn, tpp, rpp, t2, n2, s2)
+  else return (tn, rn, tp, rp, t0, n0, s0)
 
+-- | Simulate Hamiltonian dynamics for n steps.
 leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
 leapfrogIntegrator n glTarget particle e = go particle n
   where go state ndisc 
           | ndisc <= 0 = state
           | otherwise  = go (leapfrog glTarget state e) (pred n)
 
+-- | Simulate a single step of Hamiltonian dynamics.
 leapfrog :: Gradient -> Particle -> Double -> Particle
 leapfrog glTarget (t, r) e = (tf, rf)
-  where rm = adjustMomentum glTarget e t r
-        tf = adjustPosition e rm t
-        rf = adjustMomentum glTarget e tf rm
+  where 
+    rm = adjustMomentum glTarget e t r
+    tf = adjustPosition e rm t
+    rf = adjustMomentum glTarget e tf rm
 
+-- | Adjust momentum.
 adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
 adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t)
 
+-- | Adjust position.
 adjustPosition :: Num c => c -> [c] -> [c] -> [c]
 adjustPosition e r t = zipWith (+) t (e .* r)
 
+-- | The MH acceptance ratio for a given proposal.
 acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
 acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
                                     / auxilliaryTarget lTarget t0 r0
 
+-- | The negative potential. 
 auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
 auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
 
+-- | Simple inner product.
 innerProduct :: Num a => [a] -> [a] -> a
 innerProduct xs ys = sum $ zipWith (*) xs ys
 
+-- | Vectorized multiplication.
 (.*) :: Num b => b -> [b] -> [b]
 z .* xs = map (* z) xs
 
+-- | Vectorized subtraction.
 (.-) :: Num a => [a] -> [a] -> [a]
 xs .- ys = zipWith (-) xs ys
 
+-- | Indicator function.
 indicate :: Integral a => Bool -> a
 indicate True  = 1
 indicate False = 0
 
+-- | A symmetric categorical (discrete uniform) distribution.
 symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
 symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
 symmetricCategorical zs g = do
   z <- uniform g
   return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
 
+-- | Alias for fromIntegral.
 fi :: (Integral a, Num b) => a -> b
 fi = fromIntegral
 
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -11,25 +11,24 @@ lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
 glTarget :: [Double] -> [Double]
 glTarget = grad lTarget
 
-t0 :: [Double]
-t0 = [0.0, 0.0]
-
-r0 :: [Double]
-r0 = [0.0, 0.0]
-
-logu = -0.12840 -- from octave
-u    = exp logu
-v    = -1 :: Double
-
-n = 9   :: Int
-e = 0.1 :: Double
+-- glTarget [x, y] =
+--   let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x)
+--       dy = -10 * (y - x ^ 2)
+--   in  [dx, dy]
+
+t0   = [0.0, 0.0] :: [Double]
+r0   = [0.0, 0.0] :: [Double]
+logu = -0.12840   :: Double
+v    = -1         :: Double
+n    = 5          :: Int
+e    = 0.1        :: Double
 
 runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
 runBuildTree g = do
   liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e
 
 main = do
-  test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0
+  test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 20000 0.075 t0
   mapM_ print test