commit 7db25113e0a796e022dcd9acbbab3a33b26812e0
parent da4ff176919afb192278fb43fff40dff2dae54d4
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 21:09:06 +1200
Add more sophisticated buildTree.
Diffstat:
| M | NUTS.hs | | | 137 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------- | 
1 file changed, 98 insertions(+), 39 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -46,41 +46,6 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
             | otherwise = (t0, r0)
   return final
 
--- note that this is not the greatest buildTree we could use
-buildTree
-  :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
-  => ([c] -> c)   -- ^ Log target
-  -> ([c] -> [c]) -- ^ Gradient
-  -> [c]          -- ^ Position
-  -> [c]          -- ^ Momentum
-  -> c            -- ^ Slice variable
-  -> c            -- ^ Direction (-1, +1)
-  -> a            -- ^ Depth
-  -> c            -- ^ Step size
-  -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
-buildTree lTarget glTarget = go 
-  where 
-    go t r u v 0 e = 
-      let (t1, r1) = leapfrog glTarget t r 1 (v * e)
-          c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
-            | otherwise                           = HashSet.empty
-          s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
-            | otherwise                                     = 0
-      in  (t1, r1, t1, r1, c, s)
-
-    go t r u v j e = 
-      let (tn, rn, tp, rp, c0, s0)     = go t r u v (pred j) e
-          (tnn, rnn, tpp, rpp, c1, s1) = if   roundTo 6 v == -1
-                                         then go tn rn u v (pred j) e
-                                         else go tp rp u v (pred j) e
-
-          s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0)
-                       * indicator ((tpp .- tnn) `innerProduct` rpp >= 0)
-
-          c2 = c0 `HashSet.union` c1
-
-      in  (tnn, rnn, tpp, rpp, c2, s2) 
-
 leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
          => ([c] -> [c]) -- ^ Gradient of log target function
          -> [c]          -- ^ List of parameters to target
@@ -117,7 +82,7 @@ findReasonableEpsilon lTarget glTarget t0 g = do
   r0 <- replicateM (length t0) (normal 0 1 g)
   let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
 
-      a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+      a = 2 * indicate (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
 
       go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = 
                    let en       = 2 ^ a * e
@@ -127,6 +92,64 @@ findReasonableEpsilon lTarget glTarget t0 g = do
 
   return $ go 1.0 t1 r1
 
+-- go needs to return in some monad
+
+
+buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d
+             , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c) 
+  => ([t] -> t)
+  -> ([t] -> [t])
+  -> Gen (PrimState m)
+  -> [t]
+  -> [t]
+  -> t
+  -> t
+  -> a
+  -> t
+  -> t1
+  -> [t]
+  -> m ([t], [t], [t], [t], [t], c, d, t, e)
+buildTree lTarget glTarget g = go 
+  where
+    go t r u v 0 e _ r0 = return $
+      let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+          n        = indicate (u <= auxilliaryTarget lTarget t1 r1)
+          s        = indicate (u <  exp 1000 * auxilliaryTarget lTarget t1 r1)
+          m        = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0)
+      in  (t1, r1, t1, r1, t1, n, s, m, 1)
+
+    go t r u v j e t0 r0 = do
+      z <- uniform g
+      (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+      if   roundTo 6 s1 == 1
+      then do
+        (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+          if   roundTo 6 v == -1
+          then go tn rn u v (pred j) e t0 r0
+          else go tp rp u v (pred j) e t0 r0
+  
+        let p  = n2 / (n1 + n2)
+  
+            t3 | p > z     = t2
+               | otherwise = t1
+  
+            a3  = a1  + a2
+            na3 = na1 + na2
+  
+            s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                    * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) 
+  
+            n3 = n1 + n2
+        return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+      else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+
+
+
+
+
+
 
 
 
@@ -138,11 +161,47 @@ innerProduct xs ys = sum $ zipWith (*) xs ys
 (.-) :: Num a => [a] -> [a] -> [a]
 xs .- ys = zipWith (-) xs ys
 
-indicator :: Integral a => Bool -> a
-indicator True  = 1
-indicator False = 0
+indicate :: Integral a => Bool -> a
+indicate True  = 1
+indicate False = 0
 
 -- | Round to a specified number of digits.
 roundTo :: RealFrac a => Int -> a -> a
 roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)
 
+-- Deprecated -----------------------------------------------------------------
+
+basicBuildTree
+  :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
+  => ([c] -> c)   -- ^ Log target
+  -> ([c] -> [c]) -- ^ Gradient
+  -> [c]          -- ^ Position
+  -> [c]          -- ^ Momentum
+  -> c            -- ^ Slice variable
+  -> c            -- ^ Direction (-1, +1)
+  -> a            -- ^ Depth
+  -> c            -- ^ Step size
+  -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
+basicBuildTree lTarget glTarget = go 
+  where 
+    go t r u v 0 e = 
+      let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+          c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
+            | otherwise                           = HashSet.empty
+          s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
+            | otherwise                                     = 0
+      in  (t1, r1, t1, r1, c, s)
+
+    go t r u v j e = 
+      let (tn, rn, tp, rp, c0, s0)     = go t r u v (pred j) e
+          (tnn, rnn, tpp, rpp, c1, s1) = if   roundTo 6 v == -1
+                                         then go tn rn u v (pred j) e
+                                         else go tp rp u v (pred j) e
+
+          s2 = s0 * s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+                       * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+          c2 = c0 `HashSet.union` c1
+
+      in  (tnn, rnn, tpp, rpp, c2, s2) 
+