hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 8c3ae473a11deec6ad2769b49617062ab759d3e9
parent e223fc4e94e1ebe18d3e3083831a2284c0387e46
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 18:46:08 +1200

Correct buildTree function.

Diffstat:
MNUTS.hs | 56++++++++++++++++++++++++++++++++++++--------------------
1 file changed, 36 insertions(+), 20 deletions(-)

diff --git a/NUTS.hs b/NUTS.hs @@ -1,9 +1,11 @@ -{-# OPTIONS_GHC -Wall #-} +{-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-} import Control.Monad import Control.Monad.Loops import Control.Monad.Primitive -import Data.List +import Data.Hashable +import Data.HashSet (HashSet) +import qualified Data.HashSet as HashSet import System.Random.MWC import System.Random.MWC.Distributions @@ -44,36 +46,40 @@ hmcKernel lTarget glTarget t0 ndisc e g = do -- Utilities ------------------------------------------------------------------ -- TODO quickcheck all this --- change leapfrog to return (parameters, momentum) -buildTree lTarget glTarget t0 r0 u0 v0 j0 e0 = go t0 r0 u0 v0 j0 e0 +buildTree + :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c) + => ([c] -> c) + -> ([c] -> [c]) + -> [c] + -> [c] + -> c + -> c + -> a + -> c + -> ([c], [c], [c], [c], HashSet ([c], [c]), t) +buildTree lTarget glTarget = go where go t r u v 0 e = let (t1, r1) = leapfrog glTarget t r 1 (v * e) - c | u <= auxilliaryTarget lTarget t1 r1 = [(t1, r1)] -- only require a set here - | otherwise = [] + c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1) + | otherwise = HashSet.empty s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1 | otherwise = 0 in (t1, r1, t1, r1, c, s) go t r u v j e = let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e - (tnn, rnn, tpp, rpp, c1, s1) = if v == -1 + (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1 then go tn rn u v (pred j) e else go tp rp u v (pred j) e - s2 = s0 * s1 * indicator ((tpp - tnn) * rnn >= 0) -- check these - * indicator ((tpp - tnn) * rpp >= 0) + s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicator ((tpp .- tnn) `innerProduct` rpp >= 0) - c2 = c0 `union` c1 + c2 = c0 `HashSet.union` c1 in (tnn, rnn, tpp, rpp, c2, s2) - - - - - - leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a) => ([c] -> [c]) -- ^ Gradient of log target function -> [c] -- ^ List of parameters to target @@ -89,8 +95,9 @@ leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0) rt = zipWith (+) rm (map (* (0.5 * e)) (glTarget t)) in go tt rt (pred n) --- | Acceptance ratio. t0/r0 denote the present state of the parameters and --- auxilliary variables, and t1/r1 denote the proposed state. +-- | Acceptance ratio for a proposed move. t0/r0 denote the present state of +-- the parameters and auxilliary variables, and t1/r1 denote the proposed +-- state. hmcAcceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1 / auxilliaryTarget lTarget t0 r0 @@ -102,6 +109,15 @@ auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r) innerProduct :: Num a => [a] -> [a] -> a innerProduct xs ys = sum $ zipWith (*) xs ys -indicator p | p = const 1 - | otherwise = const 0 +-- | Vectorized subtraction. +(.-) :: Num a => [a] -> [a] -> [a] +xs .- ys = zipWith (-) xs ys + +indicator :: Integral a => Bool -> a +indicator True = 1 +indicator False = 0 + +-- | Round to a specified number of digits. +roundTo :: RealFrac a => Int -> a -> a +roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)