commit 8c3ae473a11deec6ad2769b49617062ab759d3e9
parent e223fc4e94e1ebe18d3e3083831a2284c0387e46
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 1 Sep 2013 18:46:08 +1200
Correct buildTree function.
Diffstat:
M | NUTS.hs | | | 56 | ++++++++++++++++++++++++++++++++++++-------------------- |
1 file changed, 36 insertions(+), 20 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,9 +1,11 @@
-{-# OPTIONS_GHC -Wall #-}
+{-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-}
import Control.Monad
import Control.Monad.Loops
import Control.Monad.Primitive
-import Data.List
+import Data.Hashable
+import Data.HashSet (HashSet)
+import qualified Data.HashSet as HashSet
import System.Random.MWC
import System.Random.MWC.Distributions
@@ -44,36 +46,40 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
-- Utilities ------------------------------------------------------------------
-- TODO quickcheck all this
--- change leapfrog to return (parameters, momentum)
-buildTree lTarget glTarget t0 r0 u0 v0 j0 e0 = go t0 r0 u0 v0 j0 e0
+buildTree
+ :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
+ => ([c] -> c)
+ -> ([c] -> [c])
+ -> [c]
+ -> [c]
+ -> c
+ -> c
+ -> a
+ -> c
+ -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
+buildTree lTarget glTarget = go
where
go t r u v 0 e =
let (t1, r1) = leapfrog glTarget t r 1 (v * e)
- c | u <= auxilliaryTarget lTarget t1 r1 = [(t1, r1)] -- only require a set here
- | otherwise = []
+ c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
+ | otherwise = HashSet.empty
s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
| otherwise = 0
in (t1, r1, t1, r1, c, s)
go t r u v j e =
let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e
- (tnn, rnn, tpp, rpp, c1, s1) = if v == -1
+ (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1
then go tn rn u v (pred j) e
else go tp rp u v (pred j) e
- s2 = s0 * s1 * indicator ((tpp - tnn) * rnn >= 0) -- check these
- * indicator ((tpp - tnn) * rpp >= 0)
+ s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicator ((tpp .- tnn) `innerProduct` rpp >= 0)
- c2 = c0 `union` c1
+ c2 = c0 `HashSet.union` c1
in (tnn, rnn, tpp, rpp, c2, s2)
-
-
-
-
-
-
leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
=> ([c] -> [c]) -- ^ Gradient of log target function
-> [c] -- ^ List of parameters to target
@@ -89,8 +95,9 @@ leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
rt = zipWith (+) rm (map (* (0.5 * e)) (glTarget t))
in go tt rt (pred n)
--- | Acceptance ratio. t0/r0 denote the present state of the parameters and
--- auxilliary variables, and t1/r1 denote the proposed state.
+-- | Acceptance ratio for a proposed move. t0/r0 denote the present state of
+-- the parameters and auxilliary variables, and t1/r1 denote the proposed
+-- state.
hmcAcceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
/ auxilliaryTarget lTarget t0 r0
@@ -102,6 +109,15 @@ auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys
-indicator p | p = const 1
- | otherwise = const 0
+-- | Vectorized subtraction.
+(.-) :: Num a => [a] -> [a] -> [a]
+xs .- ys = zipWith (-) xs ys
+
+indicator :: Integral a => Bool -> a
+indicator True = 1
+indicator False = 0
+
+-- | Round to a specified number of digits.
+roundTo :: RealFrac a => Int -> a -> a
+roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)