commit 710caac18b002d4541189d8b0975ff175608365c
parent 84720a3c176a17f52410a3b140855aad814ce735
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon, 16 Sep 2013 12:42:25 +1200
Add some code for testing.
Diffstat:
4 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -1,3 +1,4 @@
 *swp
 *hi
+reference
 *o
diff --git a/reference/HoffmanGelman2011_NUTS.pdf b/reference/HoffmanGelman2011_NUTS.pdf
Binary files differ.
diff --git a/src/Numeric/MCMC/Examples/Rosenbrock.hs b/src/Numeric/MCMC/Examples/Rosenbrock.hs
@@ -1,4 +1,5 @@
 import Numeric.AD
+import Numeric.MCMC.NUTS
 import System.Random.MWC
 
 lTarget :: RealFloat a => [a] -> a
@@ -8,7 +9,7 @@ glTarget :: [Double] -> [Double]
 glTarget = grad lTarget
 
 inits :: [Double]
-inits = [5.0, 5.0]
+inits = [0.0, 0.0]
 
 epochs :: Int
 epochs = 100
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -21,11 +21,12 @@ nuts :: PrimMonad m
      => Density
      -> Gradient
      -> Int
+     -> Double
      -> Parameters
      -> Gen (PrimState m)
      -> m Parameters
-nuts lTarget glTarget m t g = do
-  e <- findReasonableEpsilon lTarget glTarget t g
+nuts lTarget glTarget m e t g = do
+  -- e <- findReasonableEpsilon lTarget glTarget t g
   let go 0 t0 = return t0
       go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
 
@@ -119,9 +120,9 @@ findReasonableEpsilon lTarget glTarget t0 g = do
   let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
       a        = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
 
-      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = 
+      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) = 
                    let (tn, rn) = leapfrog glTarget (t, r) e
-                   in  go (2 ^ a * e) tn rn 
+                   in  go (2 ^^ a * e) tn rn 
                | otherwise = e
 
   return $ go 1.0 t1 r1