commit 55844abe596de30bd4c7f52f93bf7a59259b410a
parent d76127a13afbcdc63e371d6ff5413cfc5a129270
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sat,  2 May 2015 14:45:50 +1200
Some misc work I had lying around.
Diffstat:
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -14,13 +14,14 @@ type Parameters = [Double]
 type Density    = Parameters -> Double
 type Gradient   = Parameters -> Parameters
 type Particle   = (Parameters, Parameters)
+type StateSpecs = ([Double], [Double], [Double], [Double], [Double], Int, Int)
 
-newtype BuildTree = BuildTree { 
-    getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int)
+newtype StateTree = StateTree { 
+    getStateTree :: StateSpecs
   }
 
-instance Show BuildTree where
-  show (BuildTree (tm, rm, tp, rp, t', n, s)) = 
+instance Show StateTree where
+  show (StateTree (tm, rm, tp, rp, t', n, s)) = 
        "\n" ++ "tm: " ++ show tm 
     ++ "\n" ++ "tp: " ++ show tp
     ++ "\n" ++ "t': " ++ show t'
@@ -207,7 +208,7 @@ buildTree
   -> Double
   -> Int
   -> Double
-  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
+  -> m StateSpecs
 buildTree lTarget glTarget g t r logu v 0 e = do
   let (t0, r0) = leapfrog glTarget (t, r) (v * e)
       joint    = log $ auxilliaryTarget lTarget t0 r0
@@ -291,7 +292,7 @@ buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do
           buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0
         return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
 
-    let p      = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1
+    let p      = fi n2 / max (fi (n1 + n2)) 1
         accept = p > (z :: Double)
         n3     = n1 + n2
         a3     = a1 + a2
diff --git a/src/Numeric/MCMC/NUTS/Examples.hs b/src/Numeric/MCMC/NUTS/Examples.hs
@@ -54,5 +54,5 @@ printTrace :: Show a => [a] -> IO ()
 printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show) 
 
 main :: IO ()
-main = bnnTrace >>= printTrace 
+main = bealeTrace >>= printTrace