commit b7b3b030cea290e3a5d1876eab9cc31d75158410
parent 15c3bd89ed542a1520167485c890d5b674829ce6
Author: Jared Tobin <jared@jtobin.ca>
Date:   Thu,  3 Oct 2013 20:21:31 +1300
Fix proposal bug in buildTree.
Diffstat:
2 files changed, 45 insertions(+), 39 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -10,6 +10,8 @@ import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
 import System.Random.MWC.Distributions
 import Statistics.Distribution.Normal
 
+import Debug.Trace
+
 -- FIXME change to probably api
 type Parameters = [Double] 
 type Density    = Parameters -> Double
@@ -30,14 +32,15 @@ instance Show BuildTree where
     ++ "\n" ++ "n : " ++ show n
     ++ "\n" ++ "s : " ++ show s
 
-nuts :: PrimMonad m
-     => Density
-     -> Gradient
-     -> Int
-     -> Double
-     -> Parameters
-     -> Gen (PrimState m)
-     -> m [Parameters]
+nuts 
+  :: PrimMonad m
+  => Density
+  -> Gradient
+  -> Int
+  -> Double
+  -> Parameters
+  -> Gen (PrimState m)
+  -> m [Parameters]
 nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
   where
     kernel eps (n, t0) = do
@@ -46,16 +49,18 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
                then Nothing
                else Just (t0, (pred n, t1))
 
-nutsKernel :: PrimMonad m 
-           => Density
-           -> Gradient
-           -> Double
-           -> Parameters
-           -> Gen (PrimState m)
-           -> m Parameters
+nutsKernel 
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Double
+  -> Parameters
+  -> Gen (PrimState m)
+  -> m Parameters
 nutsKernel lTarget glTarget e t g = do
-  r0 <- replicateM (length t) (normal 0 1 g)
-  u  <- uniformR (0, auxilliaryTarget lTarget t r0) g
+  r0   <- replicateM (length t) (normal 0 1 g)
+  z0   <- exponential 1 g
+  let logu = auxilliaryTarget lTarget t r0 - z0
 
   let go (tn, tp, rn, rp, j, tm, n, s) g
         | s == 1 = do
@@ -66,14 +71,14 @@ nutsKernel lTarget glTarget e t g = do
               if   vj == -1
               then do
                 (tnn', rnn', _, _, t1', n1', s1') <- 
-                  buildTree lTarget glTarget g tn rn u vj j e
+                  buildTree lTarget glTarget g tn rn logu vj j e
                 return (tnn', rnn', tp, rp, t1', n1', s1')
               else do
                 (_, _, tpp', rpp', t1', n1', s1') <- 
-                  buildTree lTarget glTarget g tp rp u vj j e
+                  buildTree lTarget glTarget g tp rp logu vj j e
                 return (tn, rn, tpp', rpp', t1', n1', s1')
 
-            let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
+            let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1
                    | otherwise = t
 
                 n2 = n + n1
@@ -101,26 +106,26 @@ buildTree
   -> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
 buildTree lTarget glTarget g = go
   where
-    go t r u v 0 e = return $
-      let (t0, r0) = leapfrog glTarget (t, r) (v * e)
-          auxTgt   = auxilliaryTarget lTarget t0 r0
-          n        = indicate (u <= auxTgt)
-          s        = indicate (auxTgt > log u - 1000)
+    go t r logu v 0 e  = return $
+      let (t0, r0)  = leapfrog glTarget (t, r) (v * e)
+          auxTarget = auxilliaryTarget lTarget t0 r0
+          n         = indicate (logu < auxTarget)
+          s         = indicate (logu - 1000 < auxTarget)
       in  (t0, r0, t0, r0, t0, n, s)
 
-    go t r u v j e = do
+    go t r logu v j e = do
       z <- uniform g
-      (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e
+      (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e
 
       if   s0 == 1
       then do
         (tnn, rnn, tpp, rpp, t1, n1, s1) <- 
           if   v == -1
           then do
-            (tnn', rnn', _, _, t1', n1', s1') <- go tn rn u v (pred j) e
+            (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e
             return (tnn', rnn', tp, rp, t1', n1', s1')
           else do
-            (   _,    _, tpp', rpp', t1', n1', s1') <- go tp rp u v (pred j) e
+            (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e
             return (tn, rn, tpp', rpp', t1', n1', s1')
 
         let p  = fromIntegral n1 / fromIntegral (n0 + n1)
@@ -133,12 +138,13 @@ buildTree lTarget glTarget g = go
         return (tnn, rnn, tpp, rpp, t2, n2, s2)
       else return (tn, rn, tp, rp, t0, n0, s0)
 
-findReasonableEpsilon :: PrimMonad m 
-                      => Density
-                      -> Gradient
-                      -> Parameters
-                      -> Gen (PrimState m) 
-                      -> m Double
+findReasonableEpsilon
+  :: PrimMonad m 
+  => Density
+  -> Gradient
+  -> Parameters
+  -> Gen (PrimState m) 
+  -> m Double
 findReasonableEpsilon lTarget glTarget t0 g = do
   r0 <- replicateM (length t0) (normal 0 1 g)
   let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -12,7 +12,7 @@ glTarget :: [Double] -> [Double]
 glTarget = grad lTarget
 
 t0 :: [Double]
-t0 = [1.0, 1.0]
+t0 = [0.0, 0.0]
 
 r0 :: [Double]
 r0 = [0.0, 0.0]
@@ -21,15 +21,15 @@ logu = -0.12840 -- from octave
 u    = exp logu
 v    = -1 :: Double
 
-n = 20   :: Int
+n = 9   :: Int
 e = 0.1 :: Double
 
 runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
 runBuildTree g = do
-  liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e
+  liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e
 
 main = do
-  test <- create >>= nuts lTarget glTarget 1000 0.1 t0
+  test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0
   mapM_ print test