commit b7b3b030cea290e3a5d1876eab9cc31d75158410
parent 15c3bd89ed542a1520167485c890d5b674829ce6
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 3 Oct 2013 20:21:31 +1300
Fix proposal bug in buildTree.
Diffstat:
2 files changed, 45 insertions(+), 39 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -10,6 +10,8 @@ import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
import System.Random.MWC.Distributions
import Statistics.Distribution.Normal
+import Debug.Trace
+
-- FIXME change to probably api
type Parameters = [Double]
type Density = Parameters -> Double
@@ -30,14 +32,15 @@ instance Show BuildTree where
++ "\n" ++ "n : " ++ show n
++ "\n" ++ "s : " ++ show s
-nuts :: PrimMonad m
- => Density
- -> Gradient
- -> Int
- -> Double
- -> Parameters
- -> Gen (PrimState m)
- -> m [Parameters]
+nuts
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Int
+ -> Double
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m [Parameters]
nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
where
kernel eps (n, t0) = do
@@ -46,16 +49,18 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
then Nothing
else Just (t0, (pred n, t1))
-nutsKernel :: PrimMonad m
- => Density
- -> Gradient
- -> Double
- -> Parameters
- -> Gen (PrimState m)
- -> m Parameters
+nutsKernel
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Double
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Parameters
nutsKernel lTarget glTarget e t g = do
- r0 <- replicateM (length t) (normal 0 1 g)
- u <- uniformR (0, auxilliaryTarget lTarget t r0) g
+ r0 <- replicateM (length t) (normal 0 1 g)
+ z0 <- exponential 1 g
+ let logu = auxilliaryTarget lTarget t r0 - z0
let go (tn, tp, rn, rp, j, tm, n, s) g
| s == 1 = do
@@ -66,14 +71,14 @@ nutsKernel lTarget glTarget e t g = do
if vj == -1
then do
(tnn', rnn', _, _, t1', n1', s1') <-
- buildTree lTarget glTarget g tn rn u vj j e
+ buildTree lTarget glTarget g tn rn logu vj j e
return (tnn', rnn', tp, rp, t1', n1', s1')
else do
(_, _, tpp', rpp', t1', n1', s1') <-
- buildTree lTarget glTarget g tp rp u vj j e
+ buildTree lTarget glTarget g tp rp logu vj j e
return (tn, rn, tpp', rpp', t1', n1', s1')
- let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
+ let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1
| otherwise = t
n2 = n + n1
@@ -101,26 +106,26 @@ buildTree
-> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
buildTree lTarget glTarget g = go
where
- go t r u v 0 e = return $
- let (t0, r0) = leapfrog glTarget (t, r) (v * e)
- auxTgt = auxilliaryTarget lTarget t0 r0
- n = indicate (u <= auxTgt)
- s = indicate (auxTgt > log u - 1000)
+ go t r logu v 0 e = return $
+ let (t0, r0) = leapfrog glTarget (t, r) (v * e)
+ auxTarget = auxilliaryTarget lTarget t0 r0
+ n = indicate (logu < auxTarget)
+ s = indicate (logu - 1000 < auxTarget)
in (t0, r0, t0, r0, t0, n, s)
- go t r u v j e = do
+ go t r logu v j e = do
z <- uniform g
- (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e
+ (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e
if s0 == 1
then do
(tnn, rnn, tpp, rpp, t1, n1, s1) <-
if v == -1
then do
- (tnn', rnn', _, _, t1', n1', s1') <- go tn rn u v (pred j) e
+ (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e
return (tnn', rnn', tp, rp, t1', n1', s1')
else do
- ( _, _, tpp', rpp', t1', n1', s1') <- go tp rp u v (pred j) e
+ (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e
return (tn, rn, tpp', rpp', t1', n1', s1')
let p = fromIntegral n1 / fromIntegral (n0 + n1)
@@ -133,12 +138,13 @@ buildTree lTarget glTarget g = go
return (tnn, rnn, tpp, rpp, t2, n2, s2)
else return (tn, rn, tp, rp, t0, n0, s0)
-findReasonableEpsilon :: PrimMonad m
- => Density
- -> Gradient
- -> Parameters
- -> Gen (PrimState m)
- -> m Double
+findReasonableEpsilon
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Double
findReasonableEpsilon lTarget glTarget t0 g = do
r0 <- replicateM (length t0) (normal 0 1 g)
let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -12,7 +12,7 @@ glTarget :: [Double] -> [Double]
glTarget = grad lTarget
t0 :: [Double]
-t0 = [1.0, 1.0]
+t0 = [0.0, 0.0]
r0 :: [Double]
r0 = [0.0, 0.0]
@@ -21,15 +21,15 @@ logu = -0.12840 -- from octave
u = exp logu
v = -1 :: Double
-n = 20 :: Int
+n = 9 :: Int
e = 0.1 :: Double
runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
runBuildTree g = do
- liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e
+ liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e
main = do
- test <- create >>= nuts lTarget glTarget 1000 0.1 t0
+ test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0
mapM_ print test