commit 15c3bd89ed542a1520167485c890d5b674829ce6
parent 069d9dbefe1c36247ab05c1b8dd6842b1133b93f
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 22 Sep 2013 21:57:42 +1200
Add testing code.
Diffstat:
2 files changed, 61 insertions(+), 23 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -4,8 +4,9 @@
module Numeric.MCMC.NUTS where
import Control.Monad
+import Control.Monad.Loops
import Control.Monad.Primitive
-import System.Random.MWC -- FIXME change to Prob monad
+import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
import System.Random.MWC.Distributions
import Statistics.Distribution.Normal
@@ -15,20 +16,35 @@ type Density = Parameters -> Double
type Gradient = Parameters -> Parameters
type Particle = (Parameters, Parameters)
--- FIXME must be streaming
+newtype BuildTree = BuildTree {
+ getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int)
+ }
+
+instance Show BuildTree where
+ show (BuildTree (tm, rm, tp, rp, t', n, s)) =
+ "\n" ++ "tm: " ++ show tm
+ ++ "\n" ++ "rm: " ++ show rm
+ ++ "\n" ++ "tp: " ++ show tp
+ ++ "\n" ++ "rp: " ++ show rp
+ ++ "\n" ++ "t': " ++ show t'
+ ++ "\n" ++ "n : " ++ show n
+ ++ "\n" ++ "s : " ++ show s
+
nuts :: PrimMonad m
=> Density
-> Gradient
-> Int
+ -> Double
-> Parameters
-> Gen (PrimState m)
- -> m Parameters
-nuts lTarget glTarget m t g = do
- e <- findReasonableEpsilon lTarget glTarget t g
- let go 0 t0 = return t0
- go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
-
- go m t
+ -> m [Parameters]
+nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
+ where
+ kernel eps (n, t0) = do
+ t1 <- nutsKernel lTarget glTarget eps t0 g
+ return $ if n <= 0
+ then Nothing
+ else Just (t0, (pred n, t1))
nutsKernel :: PrimMonad m
=> Density
@@ -71,20 +87,6 @@ nutsKernel lTarget glTarget e t g = do
go (t, t, r0, r0, 0, t, 1, 1) g
-newtype BuildTree = BuildTree {
- getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int)
- }
-
-instance Show BuildTree where
- show (BuildTree (tm, rm, tp, rp, t', n, s)) =
- "\n" ++ "tm: " ++ show tm
- ++ "\n" ++ "rm: " ++ show rm
- ++ "\n" ++ "tp: " ++ show tp
- ++ "\n" ++ "rp: " ++ show rp
- ++ "\n" ++ "t': " ++ show t'
- ++ "\n" ++ "n : " ++ show n
- ++ "\n" ++ "s : " ++ show s
-
buildTree
:: PrimMonad m
=> Density
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -0,0 +1,36 @@
+import Control.Monad
+import Control.Monad.Primitive
+import Data.Vector (singleton)
+import Numeric.AD
+import Numeric.MCMC.NUTS
+import System.Random.MWC
+
+lTarget :: RealFloat a => [a] -> a
+lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+
+glTarget :: [Double] -> [Double]
+glTarget = grad lTarget
+
+t0 :: [Double]
+t0 = [1.0, 1.0]
+
+r0 :: [Double]
+r0 = [0.0, 0.0]
+
+logu = -0.12840 -- from octave
+u = exp logu
+v = -1 :: Double
+
+n = 20 :: Int
+e = 0.1 :: Double
+
+runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
+runBuildTree g = do
+ liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e
+
+main = do
+ test <- create >>= nuts lTarget glTarget 1000 0.1 t0
+ mapM_ print test
+
+
+