commit da4ff176919afb192278fb43fff40dff2dae54d4
parent 8c3ae473a11deec6ad2769b49617062ab759d3e9
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 1 Sep 2013 20:15:26 +1200
Add findReasonableEpsilon.
Diffstat:
M | NUTS.hs | | | 55 | ++++++++++++++++++++++++++++++++++++++++--------------- |
M | README.md | | | 2 | +- |
2 files changed, 41 insertions(+), 16 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,3 +1,6 @@
+-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
+-- Lengths in Hamiltonian Monte Carlo.
+
{-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-}
import Control.Monad
@@ -15,7 +18,7 @@ hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m )
-> [Double] -- ^ Parameters
-> a -- ^ Epochs to run the chain
-> a -- ^ Number of discretizing steps
- -> Double -- ^ Tolerance
+ -> Double -- ^ Step size
-> Gen (PrimState m) -- ^ PRNG
-> m [[Double]] -- ^ Chain
hmc lTarget glTarget t n ndisc e g = unfoldrM kernel (n, (t, []))
@@ -31,7 +34,7 @@ hmcKernel :: (Enum a, Eq a, Ord a, Num a, PrimMonad m)
-> ([Double] -> [Double]) -- ^ Gradient of log target
-> [Double] -- ^ Parameters
-> a -- ^ Number of discretizing steps
- -> Double -- ^ Tolerance
+ -> Double -- ^ Step size
-> Gen (PrimState m) -- ^ PRNG
-> m ([Double], [Double]) -- ^ m (End params, end momenta)
hmcKernel lTarget glTarget t0 ndisc e g = do
@@ -43,19 +46,17 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
| otherwise = (t0, r0)
return final
--- Utilities ------------------------------------------------------------------
-
--- TODO quickcheck all this
+-- note that this is not the greatest buildTree we could use
buildTree
:: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
- => ([c] -> c)
- -> ([c] -> [c])
- -> [c]
- -> [c]
- -> c
- -> c
- -> a
- -> c
+ => ([c] -> c) -- ^ Log target
+ -> ([c] -> [c]) -- ^ Gradient
+ -> [c] -- ^ Position
+ -> [c] -- ^ Momentum
+ -> c -- ^ Slice variable
+ -> c -- ^ Direction (-1, +1)
+ -> a -- ^ Depth
+ -> c -- ^ Step size
-> ([c], [c], [c], [c], HashSet ([c], [c]), t)
buildTree lTarget glTarget = go
where
@@ -85,7 +86,7 @@ leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
-> [c] -- ^ List of parameters to target
-> [c] -- ^ Momentum variables
-> a -- ^ Number of discretizing steps
- -> c -- ^ Tolerance
+ -> c -- ^ Step size
-> ([c], [c]) -- ^ (End parameters, end momenta)
leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
| otherwise = go t0 r0 ndisc
@@ -106,10 +107,34 @@ hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+findReasonableEpsilon :: PrimMonad m
+ => ([Double] -> Double)
+ -> ([Double] -> [Double])
+ -> [Double]
+ -> Gen (PrimState m)
+ -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+ r0 <- replicateM (length t0) (normal 0 1 g)
+ let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
+
+ a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+ go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ let en = 2 ^ a * e
+ (tn, rn) = leapfrog glTarget t r 1 e
+ in go en tn rn
+ | otherwise = e
+
+ return $ go 1.0 t1 r1
+
+
+
+
+-- Utilities ------------------------------------------------------------------
+
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys
--- | Vectorized subtraction.
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys
diff --git a/README.md b/README.md
@@ -1,6 +1,6 @@
hnuts
-----
-See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Settings Path
+See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
Lengths in Hamiltonian Monte Carlo.