hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit b7b3b030cea290e3a5d1876eab9cc31d75158410
parent 15c3bd89ed542a1520167485c890d5b674829ce6
Author: Jared Tobin <jared@jtobin.ca>
Date:   Thu,  3 Oct 2013 20:21:31 +1300

Fix proposal bug in buildTree.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 76+++++++++++++++++++++++++++++++++++++++++-----------------------------------
Mtests/Test.hs | 8++++----
2 files changed, 45 insertions(+), 39 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -10,6 +10,8 @@ import System.Random.MWC -- FIXME change to Prob monad (Mersenne64) import System.Random.MWC.Distributions import Statistics.Distribution.Normal +import Debug.Trace + -- FIXME change to probably api type Parameters = [Double] type Density = Parameters -> Double @@ -30,14 +32,15 @@ instance Show BuildTree where ++ "\n" ++ "n : " ++ show n ++ "\n" ++ "s : " ++ show s -nuts :: PrimMonad m - => Density - -> Gradient - -> Int - -> Double - -> Parameters - -> Gen (PrimState m) - -> m [Parameters] +nuts + :: PrimMonad m + => Density + -> Gradient + -> Int + -> Double + -> Parameters + -> Gen (PrimState m) + -> m [Parameters] nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t) where kernel eps (n, t0) = do @@ -46,16 +49,18 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t) then Nothing else Just (t0, (pred n, t1)) -nutsKernel :: PrimMonad m - => Density - -> Gradient - -> Double - -> Parameters - -> Gen (PrimState m) - -> m Parameters +nutsKernel + :: PrimMonad m + => Density + -> Gradient + -> Double + -> Parameters + -> Gen (PrimState m) + -> m Parameters nutsKernel lTarget glTarget e t g = do - r0 <- replicateM (length t) (normal 0 1 g) - u <- uniformR (0, auxilliaryTarget lTarget t r0) g + r0 <- replicateM (length t) (normal 0 1 g) + z0 <- exponential 1 g + let logu = auxilliaryTarget lTarget t r0 - z0 let go (tn, tp, rn, rp, j, tm, n, s) g | s == 1 = do @@ -66,14 +71,14 @@ nutsKernel lTarget glTarget e t g = do if vj == -1 then do (tnn', rnn', _, _, t1', n1', s1') <- - buildTree lTarget glTarget g tn rn u vj j e + buildTree lTarget glTarget g tn rn logu vj j e return (tnn', rnn', tp, rp, t1', n1', s1') else do (_, _, tpp', rpp', t1', n1', s1') <- - buildTree lTarget glTarget g tp rp u vj j e + buildTree lTarget glTarget g tp rp logu vj j e return (tn, rn, tpp', rpp', t1', n1', s1') - let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn + let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1 | otherwise = t n2 = n + n1 @@ -101,26 +106,26 @@ buildTree -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) buildTree lTarget glTarget g = go where - go t r u v 0 e = return $ - let (t0, r0) = leapfrog glTarget (t, r) (v * e) - auxTgt = auxilliaryTarget lTarget t0 r0 - n = indicate (u <= auxTgt) - s = indicate (auxTgt > log u - 1000) + go t r logu v 0 e = return $ + let (t0, r0) = leapfrog glTarget (t, r) (v * e) + auxTarget = auxilliaryTarget lTarget t0 r0 + n = indicate (logu < auxTarget) + s = indicate (logu - 1000 < auxTarget) in (t0, r0, t0, r0, t0, n, s) - go t r u v j e = do + go t r logu v j e = do z <- uniform g - (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e + (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e if s0 == 1 then do (tnn, rnn, tpp, rpp, t1, n1, s1) <- if v == -1 then do - (tnn', rnn', _, _, t1', n1', s1') <- go tn rn u v (pred j) e + (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e return (tnn', rnn', tp, rp, t1', n1', s1') else do - ( _, _, tpp', rpp', t1', n1', s1') <- go tp rp u v (pred j) e + (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e return (tn, rn, tpp', rpp', t1', n1', s1') let p = fromIntegral n1 / fromIntegral (n0 + n1) @@ -133,12 +138,13 @@ buildTree lTarget glTarget g = go return (tnn, rnn, tpp, rpp, t2, n2, s2) else return (tn, rn, tp, rp, t0, n0, s0) -findReasonableEpsilon :: PrimMonad m - => Density - -> Gradient - -> Parameters - -> Gen (PrimState m) - -> m Double +findReasonableEpsilon + :: PrimMonad m + => Density + -> Gradient + -> Parameters + -> Gen (PrimState m) + -> m Double findReasonableEpsilon lTarget glTarget t0 g = do r0 <- replicateM (length t0) (normal 0 1 g) let (t1, r1) = leapfrog glTarget (t0, r0) 1.0 diff --git a/tests/Test.hs b/tests/Test.hs @@ -12,7 +12,7 @@ glTarget :: [Double] -> [Double] glTarget = grad lTarget t0 :: [Double] -t0 = [1.0, 1.0] +t0 = [0.0, 0.0] r0 :: [Double] r0 = [0.0, 0.0] @@ -21,15 +21,15 @@ logu = -0.12840 -- from octave u = exp logu v = -1 :: Double -n = 20 :: Int +n = 9 :: Int e = 0.1 :: Double runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree runBuildTree g = do - liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e + liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e main = do - test <- create >>= nuts lTarget glTarget 1000 0.1 t0 + test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0 mapM_ print test