commit 15c3bd89ed542a1520167485c890d5b674829ce6
parent 069d9dbefe1c36247ab05c1b8dd6842b1133b93f
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun, 22 Sep 2013 21:57:42 +1200
Add testing code.
Diffstat:
2 files changed, 61 insertions(+), 23 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -4,8 +4,9 @@
 module Numeric.MCMC.NUTS where
 
 import Control.Monad
+import Control.Monad.Loops
 import Control.Monad.Primitive
-import System.Random.MWC -- FIXME change to Prob monad
+import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
 import System.Random.MWC.Distributions
 import Statistics.Distribution.Normal
 
@@ -15,20 +16,35 @@ type Density    = Parameters -> Double
 type Gradient   = Parameters -> Parameters
 type Particle   = (Parameters, Parameters)
 
--- FIXME must be streaming
+newtype BuildTree = BuildTree { 
+    getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int)
+  }
+
+instance Show BuildTree where
+  show (BuildTree (tm, rm, tp, rp, t', n, s)) = 
+       "\n" ++ "tm: " ++ show tm 
+    ++ "\n" ++ "rm: " ++ show rm
+    ++ "\n" ++ "tp: " ++ show tp
+    ++ "\n" ++ "rp: " ++ show rp
+    ++ "\n" ++ "t': " ++ show t'
+    ++ "\n" ++ "n : " ++ show n
+    ++ "\n" ++ "s : " ++ show s
+
 nuts :: PrimMonad m
      => Density
      -> Gradient
      -> Int
+     -> Double
      -> Parameters
      -> Gen (PrimState m)
-     -> m Parameters
-nuts lTarget glTarget m t g = do
-  e <- findReasonableEpsilon lTarget glTarget t g
-  let go 0 t0 = return t0
-      go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
-
-  go m t
+     -> m [Parameters]
+nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
+  where
+    kernel eps (n, t0) = do
+      t1 <- nutsKernel lTarget glTarget eps t0 g
+      return $ if   n <= 0
+               then Nothing
+               else Just (t0, (pred n, t1))
 
 nutsKernel :: PrimMonad m 
            => Density
@@ -71,20 +87,6 @@ nutsKernel lTarget glTarget e t g = do
 
   go (t, t, r0, r0, 0, t, 1, 1) g
 
-newtype BuildTree = BuildTree { 
-    getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int)
-  }
-
-instance Show BuildTree where
-  show (BuildTree (tm, rm, tp, rp, t', n, s)) = 
-       "\n" ++ "tm: " ++ show tm 
-    ++ "\n" ++ "rm: " ++ show rm
-    ++ "\n" ++ "tp: " ++ show tp
-    ++ "\n" ++ "rp: " ++ show rp
-    ++ "\n" ++ "t': " ++ show t'
-    ++ "\n" ++ "n : " ++ show n
-    ++ "\n" ++ "s : " ++ show s
-
 buildTree 
   :: PrimMonad m 
   => Density
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -0,0 +1,36 @@
+import Control.Monad
+import Control.Monad.Primitive
+import Data.Vector (singleton)
+import Numeric.AD
+import Numeric.MCMC.NUTS
+import System.Random.MWC
+
+lTarget :: RealFloat a => [a] -> a
+lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+
+glTarget :: [Double] -> [Double]
+glTarget = grad lTarget
+
+t0 :: [Double]
+t0 = [1.0, 1.0]
+
+r0 :: [Double]
+r0 = [0.0, 0.0]
+
+logu = -0.12840 -- from octave
+u    = exp logu
+v    = -1 :: Double
+
+n = 20   :: Int
+e = 0.1 :: Double
+
+runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
+runBuildTree g = do
+  liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e
+
+main = do
+  test <- create >>= nuts lTarget glTarget 1000 0.1 t0
+  mapM_ print test
+
+
+