hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 15c3bd89ed542a1520167485c890d5b674829ce6
parent 069d9dbefe1c36247ab05c1b8dd6842b1133b93f
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun, 22 Sep 2013 21:57:42 +1200

Add testing code.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 48+++++++++++++++++++++++++-----------------------
Atests/Test.hs | 36++++++++++++++++++++++++++++++++++++
2 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -4,8 +4,9 @@ module Numeric.MCMC.NUTS where import Control.Monad +import Control.Monad.Loops import Control.Monad.Primitive -import System.Random.MWC -- FIXME change to Prob monad +import System.Random.MWC -- FIXME change to Prob monad (Mersenne64) import System.Random.MWC.Distributions import Statistics.Distribution.Normal @@ -15,20 +16,35 @@ type Density = Parameters -> Double type Gradient = Parameters -> Parameters type Particle = (Parameters, Parameters) --- FIXME must be streaming +newtype BuildTree = BuildTree { + getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int) + } + +instance Show BuildTree where + show (BuildTree (tm, rm, tp, rp, t', n, s)) = + "\n" ++ "tm: " ++ show tm + ++ "\n" ++ "rm: " ++ show rm + ++ "\n" ++ "tp: " ++ show tp + ++ "\n" ++ "rp: " ++ show rp + ++ "\n" ++ "t': " ++ show t' + ++ "\n" ++ "n : " ++ show n + ++ "\n" ++ "s : " ++ show s + nuts :: PrimMonad m => Density -> Gradient -> Int + -> Double -> Parameters -> Gen (PrimState m) - -> m Parameters -nuts lTarget glTarget m t g = do - e <- findReasonableEpsilon lTarget glTarget t g - let go 0 t0 = return t0 - go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n) - - go m t + -> m [Parameters] +nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t) + where + kernel eps (n, t0) = do + t1 <- nutsKernel lTarget glTarget eps t0 g + return $ if n <= 0 + then Nothing + else Just (t0, (pred n, t1)) nutsKernel :: PrimMonad m => Density @@ -71,20 +87,6 @@ nutsKernel lTarget glTarget e t g = do go (t, t, r0, r0, 0, t, 1, 1) g -newtype BuildTree = BuildTree { - getBuildTree :: ([Double], [Double], [Double], [Double], [Double], Int, Int) - } - -instance Show BuildTree where - show (BuildTree (tm, rm, tp, rp, t', n, s)) = - "\n" ++ "tm: " ++ show tm - ++ "\n" ++ "rm: " ++ show rm - ++ "\n" ++ "tp: " ++ show tp - ++ "\n" ++ "rp: " ++ show rp - ++ "\n" ++ "t': " ++ show t' - ++ "\n" ++ "n : " ++ show n - ++ "\n" ++ "s : " ++ show s - buildTree :: PrimMonad m => Density diff --git a/tests/Test.hs b/tests/Test.hs @@ -0,0 +1,36 @@ +import Control.Monad +import Control.Monad.Primitive +import Data.Vector (singleton) +import Numeric.AD +import Numeric.MCMC.NUTS +import System.Random.MWC + +lTarget :: RealFloat a => [a] -> a +lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) + +glTarget :: [Double] -> [Double] +glTarget = grad lTarget + +t0 :: [Double] +t0 = [1.0, 1.0] + +r0 :: [Double] +r0 = [0.0, 0.0] + +logu = -0.12840 -- from octave +u = exp logu +v = -1 :: Double + +n = 20 :: Int +e = 0.1 :: Double + +runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree +runBuildTree g = do + liftM BuildTree $ buildTree lTarget glTarget g t0 r0 u v n e + +main = do + test <- create >>= nuts lTarget glTarget 1000 0.1 t0 + mapM_ print test + + +