commit 8c3ae473a11deec6ad2769b49617062ab759d3e9
parent e223fc4e94e1ebe18d3e3083831a2284c0387e46
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 18:46:08 +1200
Correct buildTree function.
Diffstat:
| M | NUTS.hs | | | 56 | ++++++++++++++++++++++++++++++++++++-------------------- | 
1 file changed, 36 insertions(+), 20 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,9 +1,11 @@
-{-# OPTIONS_GHC -Wall #-}
+{-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-}
 
 import Control.Monad
 import Control.Monad.Loops
 import Control.Monad.Primitive
-import Data.List
+import Data.Hashable
+import Data.HashSet (HashSet)
+import qualified Data.HashSet as HashSet
 import System.Random.MWC
 import System.Random.MWC.Distributions
 
@@ -44,36 +46,40 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
 -- Utilities ------------------------------------------------------------------
 
 -- TODO quickcheck all this
---   change leapfrog to return (parameters, momentum)
-buildTree lTarget glTarget t0 r0 u0 v0 j0 e0 = go t0 r0 u0 v0 j0 e0
+buildTree
+  :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
+  => ([c] -> c)
+  -> ([c] -> [c])
+  -> [c]
+  -> [c]
+  -> c
+  -> c
+  -> a
+  -> c
+  -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
+buildTree lTarget glTarget = go 
   where 
     go t r u v 0 e = 
       let (t1, r1) = leapfrog glTarget t r 1 (v * e)
-          c | u <= auxilliaryTarget lTarget t1 r1 = [(t1, r1)] -- only require a set here
-            | otherwise                           = []
+          c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
+            | otherwise                           = HashSet.empty
           s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
             | otherwise                                     = 0
       in  (t1, r1, t1, r1, c, s)
 
     go t r u v j e = 
       let (tn, rn, tp, rp, c0, s0)     = go t r u v (pred j) e
-          (tnn, rnn, tpp, rpp, c1, s1) = if   v == -1
+          (tnn, rnn, tpp, rpp, c1, s1) = if   roundTo 6 v == -1
                                          then go tn rn u v (pred j) e
                                          else go tp rp u v (pred j) e
 
-          s2 = s0 * s1 * indicator ((tpp - tnn) * rnn >= 0) -- check these
-                       * indicator ((tpp - tnn) * rpp >= 0)
+          s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0)
+                       * indicator ((tpp .- tnn) `innerProduct` rpp >= 0)
 
-          c2 = c0 `union` c1
+          c2 = c0 `HashSet.union` c1
 
       in  (tnn, rnn, tpp, rpp, c2, s2) 
 
-
-
-
-
-
-
 leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
          => ([c] -> [c]) -- ^ Gradient of log target function
          -> [c]          -- ^ List of parameters to target
@@ -89,8 +95,9 @@ leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
                        rt = zipWith (+) rm (map (* (0.5 * e)) (glTarget t))
                    in  go tt rt (pred n)
 
--- | Acceptance ratio.  t0/r0 denote the present state of the parameters and
---   auxilliary variables, and t1/r1 denote the proposed state.
+-- | Acceptance ratio for a proposed move.  t0/r0 denote the present state of 
+--   the parameters and auxilliary variables, and t1/r1 denote the proposed 
+--   state.
 hmcAcceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
 hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
                                        / auxilliaryTarget lTarget t0 r0
@@ -102,6 +109,15 @@ auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
 innerProduct :: Num a => [a] -> [a] -> a
 innerProduct xs ys = sum $ zipWith (*) xs ys
 
-indicator p | p         = const 1
-            | otherwise = const 0
+-- | Vectorized subtraction.
+(.-) :: Num a => [a] -> [a] -> [a]
+xs .- ys = zipWith (-) xs ys
+
+indicator :: Integral a => Bool -> a
+indicator True  = 1
+indicator False = 0
+
+-- | Round to a specified number of digits.
+roundTo :: RealFrac a => Int -> a -> a
+roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)