commit 069d9dbefe1c36247ab05c1b8dd6842b1133b93f
parent e66e88022afcabc67aae30ce697dabb64e785b93
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 22 Sep 2013 20:26:36 +1200
Fix bug in nutsKernel.
Diffstat:
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -20,12 +20,11 @@ nuts :: PrimMonad m
=> Density
-> Gradient
-> Int
- -> Double
-> Parameters
-> Gen (PrimState m)
-> m Parameters
-nuts lTarget glTarget m e t g = do
- -- e <- findReasonableEpsilon lTarget glTarget t g
+nuts lTarget glTarget m t g = do
+ e <- findReasonableEpsilon lTarget glTarget t g
let go 0 t0 = return t0
go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
@@ -49,8 +48,14 @@ nutsKernel lTarget glTarget e t g = do
(tnn, rnn, tpp, rpp, t1, n1, s1) <-
if vj == -1
- then buildTree lTarget glTarget g tn rn u vj j e
- else buildTree lTarget glTarget g tp rp u vj j e
+ then do
+ (tnn', rnn', _, _, t1', n1', s1') <-
+ buildTree lTarget glTarget g tn rn u vj j e
+ return (tnn', rnn', tp, rp, t1', n1', s1')
+ else do
+ (_, _, tpp', rpp', t1', n1', s1') <-
+ buildTree lTarget glTarget g tp rp u vj j e
+ return (tn, rn, tpp', rpp', t1', n1', s1')
let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
| otherwise = t