hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 55844abe596de30bd4c7f52f93bf7a59259b410a
parent d76127a13afbcdc63e371d6ff5413cfc5a129270
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sat,  2 May 2015 14:45:50 +1200

Some misc work I had lying around.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 13+++++++------
Msrc/Numeric/MCMC/NUTS/Examples.hs | 2+-
2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -14,13 +14,14 @@ type Parameters = [Double] type Density = Parameters -> Double type Gradient = Parameters -> Parameters type Particle = (Parameters, Parameters) +type StateSpecs = ([Double], [Double], [Double], [Double], [Double], Int, Int) -newtype BuildTree = BuildTree { - getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int) +newtype StateTree = StateTree { + getStateTree :: StateSpecs } -instance Show BuildTree where - show (BuildTree (tm, rm, tp, rp, t', n, s)) = +instance Show StateTree where + show (StateTree (tm, rm, tp, rp, t', n, s)) = "\n" ++ "tm: " ++ show tm ++ "\n" ++ "tp: " ++ show tp ++ "\n" ++ "t': " ++ show t' @@ -207,7 +208,7 @@ buildTree -> Double -> Int -> Double - -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) + -> m StateSpecs buildTree lTarget glTarget g t r logu v 0 e = do let (t0, r0) = leapfrog glTarget (t, r) (v * e) joint = log $ auxilliaryTarget lTarget t0 r0 @@ -291,7 +292,7 @@ buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0 return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1') - let p = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1 + let p = fi n2 / max (fi (n1 + n2)) 1 accept = p > (z :: Double) n3 = n1 + n2 a3 = a1 + a2 diff --git a/src/Numeric/MCMC/NUTS/Examples.hs b/src/Numeric/MCMC/NUTS/Examples.hs @@ -54,5 +54,5 @@ printTrace :: Show a => [a] -> IO () printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show) main :: IO () -main = bnnTrace >>= printTrace +main = bealeTrace >>= printTrace