commit da4ff176919afb192278fb43fff40dff2dae54d4
parent 8c3ae473a11deec6ad2769b49617062ab759d3e9
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 20:15:26 +1200
Add findReasonableEpsilon.
Diffstat:
| M | NUTS.hs | | | 55 | ++++++++++++++++++++++++++++++++++++++++--------------- | 
| M | README.md | | | 2 | +- | 
2 files changed, 41 insertions(+), 16 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,3 +1,6 @@
+-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
+--   Lengths in Hamiltonian Monte Carlo.
+
 {-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-}
 
 import Control.Monad
@@ -15,7 +18,7 @@ hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m )
     -> [Double]               -- ^ Parameters
     -> a                      -- ^ Epochs to run the chain
     -> a                      -- ^ Number of discretizing steps
-    -> Double                 -- ^ Tolerance
+    -> Double                 -- ^ Step size
     -> Gen (PrimState m)      -- ^ PRNG
     -> m [[Double]]           -- ^ Chain
 hmc lTarget glTarget t n ndisc e g = unfoldrM kernel (n, (t, []))
@@ -31,7 +34,7 @@ hmcKernel :: (Enum a, Eq a, Ord a, Num a, PrimMonad m)
           -> ([Double] -> [Double]) -- ^ Gradient of log target
           -> [Double]               -- ^ Parameters
           -> a                      -- ^ Number of discretizing steps
-          -> Double                 -- ^ Tolerance
+          -> Double                 -- ^ Step size
           -> Gen (PrimState m)      -- ^ PRNG
           -> m ([Double], [Double]) -- ^ m (End params, end momenta)
 hmcKernel lTarget glTarget t0 ndisc e g = do
@@ -43,19 +46,17 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
             | otherwise = (t0, r0)
   return final
 
--- Utilities ------------------------------------------------------------------
-
--- TODO quickcheck all this
+-- note that this is not the greatest buildTree we could use
 buildTree
   :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
-  => ([c] -> c)
-  -> ([c] -> [c])
-  -> [c]
-  -> [c]
-  -> c
-  -> c
-  -> a
-  -> c
+  => ([c] -> c)   -- ^ Log target
+  -> ([c] -> [c]) -- ^ Gradient
+  -> [c]          -- ^ Position
+  -> [c]          -- ^ Momentum
+  -> c            -- ^ Slice variable
+  -> c            -- ^ Direction (-1, +1)
+  -> a            -- ^ Depth
+  -> c            -- ^ Step size
   -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
 buildTree lTarget glTarget = go 
   where 
@@ -85,7 +86,7 @@ leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
          -> [c]          -- ^ List of parameters to target
          -> [c]          -- ^ Momentum variables
          -> a            -- ^ Number of discretizing steps
-         -> c            -- ^ Tolerance
+         -> c            -- ^ Step size
          -> ([c], [c])   -- ^ (End parameters, end momenta)
 leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
                                           | otherwise = go t0 r0 ndisc
@@ -106,10 +107,34 @@ hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
 auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
 auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
 
+findReasonableEpsilon :: PrimMonad m 
+                      => ([Double] -> Double) 
+                      -> ([Double] -> [Double]) 
+                      -> [Double] 
+                      -> Gen (PrimState m) 
+                      -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+  r0 <- replicateM (length t0) (normal 0 1 g)
+  let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
+
+      a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+      go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = 
+                   let en       = 2 ^ a * e
+                       (tn, rn) = leapfrog glTarget t r 1 e
+                   in  go en tn rn 
+               | otherwise = e
+
+  return $ go 1.0 t1 r1
+
+
+
+
+-- Utilities ------------------------------------------------------------------
+
 innerProduct :: Num a => [a] -> [a] -> a
 innerProduct xs ys = sum $ zipWith (*) xs ys
 
--- | Vectorized subtraction.
 (.-) :: Num a => [a] -> [a] -> [a]
 xs .- ys = zipWith (-) xs ys
 
diff --git a/README.md b/README.md
@@ -1,6 +1,6 @@
 hnuts
 -----
 
-See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Settings Path 
+See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path 
 Lengths in Hamiltonian Monte Carlo.