commit 710caac18b002d4541189d8b0975ff175608365c
parent 84720a3c176a17f52410a3b140855aad814ce735
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 16 Sep 2013 12:42:25 +1200
Add some code for testing.
Diffstat:
4 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -1,3 +1,4 @@
*swp
*hi
+reference
*o
diff --git a/reference/HoffmanGelman2011_NUTS.pdf b/reference/HoffmanGelman2011_NUTS.pdf
Binary files differ.
diff --git a/src/Numeric/MCMC/Examples/Rosenbrock.hs b/src/Numeric/MCMC/Examples/Rosenbrock.hs
@@ -1,4 +1,5 @@
import Numeric.AD
+import Numeric.MCMC.NUTS
import System.Random.MWC
lTarget :: RealFloat a => [a] -> a
@@ -8,7 +9,7 @@ glTarget :: [Double] -> [Double]
glTarget = grad lTarget
inits :: [Double]
-inits = [5.0, 5.0]
+inits = [0.0, 0.0]
epochs :: Int
epochs = 100
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -21,11 +21,12 @@ nuts :: PrimMonad m
=> Density
-> Gradient
-> Int
+ -> Double
-> Parameters
-> Gen (PrimState m)
-> m Parameters
-nuts lTarget glTarget m t g = do
- e <- findReasonableEpsilon lTarget glTarget t g
+nuts lTarget glTarget m e t g = do
+ -- e <- findReasonableEpsilon lTarget glTarget t g
let go 0 t0 = return t0
go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
@@ -119,9 +120,9 @@ findReasonableEpsilon lTarget glTarget t0 g = do
let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
- go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
let (tn, rn) = leapfrog glTarget (t, r) e
- in go (2 ^ a * e) tn rn
+ in go (2 ^^ a * e) tn rn
| otherwise = e
return $ go 1.0 t1 r1