hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit da4ff176919afb192278fb43fff40dff2dae54d4
parent 8c3ae473a11deec6ad2769b49617062ab759d3e9
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 20:15:26 +1200

Add findReasonableEpsilon.

Diffstat:
MNUTS.hs | 55++++++++++++++++++++++++++++++++++++++++---------------
MREADME.md | 2+-
2 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/NUTS.hs b/NUTS.hs @@ -1,3 +1,6 @@ +-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path +-- Lengths in Hamiltonian Monte Carlo. + {-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-} import Control.Monad @@ -15,7 +18,7 @@ hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m ) -> [Double] -- ^ Parameters -> a -- ^ Epochs to run the chain -> a -- ^ Number of discretizing steps - -> Double -- ^ Tolerance + -> Double -- ^ Step size -> Gen (PrimState m) -- ^ PRNG -> m [[Double]] -- ^ Chain hmc lTarget glTarget t n ndisc e g = unfoldrM kernel (n, (t, [])) @@ -31,7 +34,7 @@ hmcKernel :: (Enum a, Eq a, Ord a, Num a, PrimMonad m) -> ([Double] -> [Double]) -- ^ Gradient of log target -> [Double] -- ^ Parameters -> a -- ^ Number of discretizing steps - -> Double -- ^ Tolerance + -> Double -- ^ Step size -> Gen (PrimState m) -- ^ PRNG -> m ([Double], [Double]) -- ^ m (End params, end momenta) hmcKernel lTarget glTarget t0 ndisc e g = do @@ -43,19 +46,17 @@ hmcKernel lTarget glTarget t0 ndisc e g = do | otherwise = (t0, r0) return final --- Utilities ------------------------------------------------------------------ - --- TODO quickcheck all this +-- note that this is not the greatest buildTree we could use buildTree :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c) - => ([c] -> c) - -> ([c] -> [c]) - -> [c] - -> [c] - -> c - -> c - -> a - -> c + => ([c] -> c) -- ^ Log target + -> ([c] -> [c]) -- ^ Gradient + -> [c] -- ^ Position + -> [c] -- ^ Momentum + -> c -- ^ Slice variable + -> c -- ^ Direction (-1, +1) + -> a -- ^ Depth + -> c -- ^ Step size -> ([c], [c], [c], [c], HashSet ([c], [c]), t) buildTree lTarget glTarget = go where @@ -85,7 +86,7 @@ leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a) -> [c] -- ^ List of parameters to target -> [c] -- ^ Momentum variables -> a -- ^ Number of discretizing steps - -> c -- ^ Tolerance + -> c -- ^ Step size -> ([c], [c]) -- ^ (End parameters, end momenta) leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0) | otherwise = go t0 r0 ndisc @@ -106,10 +107,34 @@ hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1 auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r) +findReasonableEpsilon :: PrimMonad m + => ([Double] -> Double) + -> ([Double] -> [Double]) + -> [Double] + -> Gen (PrimState m) + -> m Double +findReasonableEpsilon lTarget glTarget t0 g = do + r0 <- replicateM (length t0) (normal 0 1 g) + let (t1, r1) = leapfrog glTarget t0 r0 1 1.0 + + a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 + + go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = + let en = 2 ^ a * e + (tn, rn) = leapfrog glTarget t r 1 e + in go en tn rn + | otherwise = e + + return $ go 1.0 t1 r1 + + + + +-- Utilities ------------------------------------------------------------------ + innerProduct :: Num a => [a] -> [a] -> a innerProduct xs ys = sum $ zipWith (*) xs ys --- | Vectorized subtraction. (.-) :: Num a => [a] -> [a] -> [a] xs .- ys = zipWith (-) xs ys diff --git a/README.md b/README.md @@ -1,6 +1,6 @@ hnuts ----- -See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Settings Path +See: Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.