hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit e96bccefe000f2fd58ce8ee5fa14b83b9efb9ecc
parent 7807cf266fc3ea20248a1aee62b49bee3acfb64d
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon, 14 Oct 2013 10:41:02 +1300

Add dual-averaging code to module.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 182+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 179 insertions(+), 3 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -7,7 +7,7 @@
 import Control.Monad
 import Control.Monad.Loops
 import Control.Monad.Primitive
 import System.Random.MWC
-import System.Random.MWC.Distributions
+import System.Random.MWC.Distributions hiding (gamma)
 import Statistics.Distribution.Normal

 type Parameters = [Double]
@@ -22,13 +22,20 @@
 newtype BuildTree = BuildTree {
 instance Show BuildTree where
   show (BuildTree (tm, rm, tp, rp, t', n, s)) =
            "\n" ++ "tm: " ++ show tm
-    -- ++ "\n" ++ "rm: " ++ show rm
    ++ "\n" ++ "tp: " ++ show tp
-    -- ++ "\n" ++ "rp: " ++ show rp
    ++ "\n" ++ "t': " ++ show t'
    ++ "\n" ++ "n : " ++ show n
    ++ "\n" ++ "s : " ++ show s

+-- | Tuning constants for the dual-averaging (Nesterov-style) step size
+--   adaptation of NUTS:
+--
+--   * mAdapt -- number of adaptation (burn-in) iterations
+--   * delta  -- target mean acceptance probability
+--   * mu     -- point the log step size iterates shrink towards
+--   * gamma  -- shrinkage strength
+--   * tau0   -- iteration offset stabilizing early adaptation
+--   * kappa  -- decay exponent of the averaging schedule
+--
+--   The field 'gamma' is why the sampler import above hides 'gamma'.
+data DualAveragingParameters = DualAveragingParameters {
+    mAdapt :: Int
+  , delta :: Double
+  , mu :: Double
+  , gamma :: Double
+  , tau0 :: Double
+  , kappa :: Double
+  } deriving Show
+
 -- | The NUTS sampler.
 nuts :: PrimMonad m
@@ -47,6 +54,99 @@
 nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
       then Nothing
       else Just (t0, (pred n, t1))

+-- | The NUTS sampler with dual averaging: the step size is initialized
+--   heuristically and then adapted over the first @nAdapt@ iterations.
+nutsDualAveraging
+  :: PrimMonad m
+  => Density            -- ^ log target density
+  -> Gradient           -- ^ gradient of the log target
+  -> Int                -- ^ number of iterations
+  -> Int                -- ^ number of adaptation iterations
+  -> Parameters         -- ^ initial position
+  -> Gen (PrimState m)  -- ^ PRNG state
+  -> m [Parameters]
+nutsDualAveraging lTarget glTarget n nAdapt t g = do
+  -- Initialize epsilon with the doubling/halving heuristic.
+  e0 <- findReasonableEpsilon lTarget glTarget t g
+  let daParams = DualAveragingParameters {
+        mu     = log (10 * e0)
+      , delta  = 0.5
+      , mAdapt = nAdapt
+      , gamma  = 0.05
+      , tau0   = 10
+      , kappa  = 0.75
+      }
+
+  -- NOTE(review): the seed starts at m = 0 with eAvg = 0; inside the
+  -- kernel this feeds @0 ** (-kappa)@ and @log 0@ into the averaging
+  -- update, which are non-finite.  Presumably the seed should be
+  -- m = 1, eAvg = 1 -- TODO confirm against the paper's alg. 6.
+  unfoldrM (kernel daParams) (0, e0, 0, 0, t)
+  where
+    -- One NUTS transition plus an update of the adaptation state
+    -- (step size, averaged step size, statistic H).
+    kernel params (m, e, eAvg, h, t0) = do
+      (eNext, eAvgNext, hNext, tNext) <-
+        nutsKernelDualAvg lTarget glTarget e eAvg h m params t0 g
+      return $ if m >= n
+        then Nothing
+        else Just (t0, (succ m, eNext, eAvgNext, hNext, tNext))
+
+-- | A single iteration of dual-averaging NUTS.
-- | A single iteration of dual-averaging NUTS (Hoffman & Gelman, alg. 6).
--
--   Takes the current step size, averaged step size, dual-averaging
--   statistic H, iteration index, adaptation parameters, and position;
--   returns the updated (step size, averaged step size, H, position).
nutsKernelDualAvg
  :: PrimMonad m
  => Density
  -> Gradient
  -> Double                   -- ^ current step size
  -> Double                   -- ^ running averaged step size
  -> Double                   -- ^ dual-averaging statistic H
  -> Int                      -- ^ iteration index m
  -> DualAveragingParameters
  -> [Double]                 -- ^ current position
  -> Gen (PrimState m)
  -> m (Double, Double, Double, [Double])
nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
  -- Resample the momentum and draw the slice variable in log space.
  r0 <- replicateM (length t) (normal 0 1 g)
  z0 <- exponential 1 g
  let logu = auxilliaryTarget lTarget t r0 - z0

  -- Doubling loop.  Accumulator order: (t-, t+, r-, r+, depth,
  -- proposal, n, stop flag, alpha, n_alpha).
  let go (tn, tp, rn, rp, j, tm, n, s, a, na) g
        | s == 1 = do
            vj <- symmetricCategorical [-1, 1] g
            z  <- uniform g

            (tnn, rnn, tpp, rpp, t1, n1, s1, a1, na1) <-
              if   vj == -1
              then do
                (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
                  buildTreeDualAvg lTarget glTarget g tn rn logu vj j e t r0
                return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
              else do
                (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
                  buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0
                return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')

            -- Accept the subtree's proposal with probability n1 / n.
            let t2 | s1 == 1
                  && (fi n1 / fi n :: Double) > z = t1
                   | otherwise                    = tm

                n2 = n + n1
                -- Stop on a U-turn across the full trajectory.
                s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
                        * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
                j1 = succ j

            -- FIX: recurse with arguments in the accumulator's
            -- (t-, t+, r-, r+) order; the previous version passed
            -- (tnn, rnn, tpp, rpp), silently swapping t+ with r-.
            go (tnn, tpp, rnn, rpp, j1, t2, n2, s2, a1, na1) g

        -- FIX: the statistic fed to dual averaging is alpha / n_alpha
        -- (paper, alg. 6), not the raw sum alpha.  n_alpha >= 1 here
        -- because the initial stop flag is 1, so at least one
        -- buildTreeDualAvg call has run.
        | otherwise = return (tm, a / fromIntegral na)

  (nextPosition, prob) <- go (t, t, r0, r0, 0, t, 1, 1, 0, 0) g

  -- NOTE(review): if the caller seeds m = 0 or eAvg = 0, both
  -- @fromIntegral m ** (-kappa)@ and @log eAvg@ below are non-finite;
  -- the caller should presumably start at m = 1, eAvg = 1 -- confirm.
  let (hNext, eNext, eAvgNext) =
        if   m <= mAdapt daParams
        then (hm, exp logEm, exp logEbarM)
        else (h, eAvg, eAvg)
        where
          hm = (1 - 1 / (fromIntegral m + tau0 daParams)) * h
             + (1 / (fromIntegral m + tau0 daParams)) * (delta daParams - prob)

          logEm    = mu daParams - (sqrt (fromIntegral m) / gamma daParams) * hm
          logEbarM = fromIntegral m ** (- (kappa daParams)) * logEm
                   + (1 - fromIntegral m ** (- (kappa daParams))) * log eAvg

  -- FIX: return the *updated* adaptation state; the previous version
  -- computed (eNext, eAvgNext, hNext) above and then discarded them,
  -- returning the stale (e, eAvg, h) and disabling adaptation entirely.
  return (eNext, eAvgNext, hNext, nextPosition)

-- | A single iteration of NUTS.
nutsKernel :: PrimMonad m @@ -140,6 +240,82 @@ buildTree lTarget glTarget g t r logu v j e = do
        return (tnn, rnn, tpp, rpp, t2, n2, s2)
        else return (tn, rn, tp, rp, t0, n0, s0)

-- | Build the tree of candidate states under dual averaging.
--
--   Returns (t-, r-, t+, r+, proposal, n, stop flag, alpha, n_alpha),
--   where alpha / n_alpha is the acceptance statistic used by the
--   dual-averaging step size update.
buildTreeDualAvg
  :: PrimMonad m
  => Density
  -> Gradient
  -> Gen (PrimState m)
  -> Parameters  -- ^ position at the working edge of the tree
  -> Parameters  -- ^ momentum at the working edge of the tree
  -> Double      -- ^ log of the slice variable
  -> Double      -- ^ direction of integration, -1 or 1
  -> Int         -- ^ tree depth j
  -> Double      -- ^ step size epsilon
  -> Parameters  -- ^ initial position, for the acceptance statistic
  -> Parameters  -- ^ initial momentum, for the acceptance statistic
  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
buildTreeDualAvg lTarget glTarget g = go
  where
    -- Base case (j = 0): a single leapfrog step of signed size v * e.
    go t r logu v 0 e t0 r0 = return $
      let (t1, r1)   = leapfrog glTarget (t, r) (v * e)
          lAuxTarget = log $ auxilliaryTarget lTarget t1 r1
          -- Is the new state inside the slice?
          n = indicate (logu <= lAuxTarget)
          -- Keep going only while the integrator has not diverged
          -- (Delta_max = 1000).
          s = indicate (logu - 1000 < lAuxTarget)
          -- Leaf contribution to the acceptance statistic.
          -- NOTE(review): argument order here is (t1 r1 t0 r0) while
          -- findReasonableEpsilon calls acceptanceRatio as
          -- (t0 t1 r0 r1) -- confirm the intended parameter order.
          m = min 1 (acceptanceRatio lTarget t1 r1 t0 r0)
      in (t1, r1, t1, r1, t1, n, s, m, 1)

    -- Recursive case: build a depth (j - 1) subtree, and if it did not
    -- terminate, extend it with a second depth (j - 1) subtree in the
    -- direction v, then join the two.
    go t r u v j e t0 r0 = do
      z <- uniform g
      (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0

      if s1 == 1
        then do
          (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
            if v == -1
              then do
                (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
                  go tn rn u v (pred j) e t0 r0
                return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
              else do
                (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
                  go tp rp u v (pred j) e t0 r0
                return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')

          -- Accept the second subtree's proposal with probability
          -- n2 / (n1 + n2); the max with 1 guards the empty case.
          let p = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1
              n3 = n1 + n2
              a3 = a1 + a2
              na3 = na1 + na2

              -- Terminate on a U-turn across the joined tree.
              s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
                 * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)

              t3 | p > (z :: Double) = t2
                 | otherwise = t1

          return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
        else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)

-- | Heuristic for initializing step size.
-- | Heuristic for initializing the step size (Hoffman & Gelman, alg. 4):
--   starting from epsilon = 1, repeatedly double or halve it until the
--   acceptance ratio of a single leapfrog step crosses 1/2.
findReasonableEpsilon
  :: PrimMonad m
  => Density
  -> Gradient
  -> Parameters        -- ^ initial position
  -> Gen (PrimState m)
  -> m Double
findReasonableEpsilon lTarget glTarget t0 g = do
  r0 <- replicateM (length t0) (normal 0 1 g)
  let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
      -- a = 1 when the trial step is accepted too easily (grow epsilon),
      -- a = -1 otherwise (shrink epsilon).
      a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1

      -- FIX: use (^^) instead of (^).  (^) raises a runtime "Negative
      -- exponent" error for negative exponents, and here the exponent
      -- is negative on one side of the comparison for *every* value of
      -- a (either a = -1 or -a = -1).  (^^) accepts negative integral
      -- exponents on a Fractional base, which is what alg. 4's
      -- (ratio)^a > 2^(-a) test requires.
      go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
                   let (tn, rn) = leapfrog glTarget (t, r) e
                   in  go (2 ^^ a * e) tn rn
               | otherwise = e

  return $ go 1.0 t1 r1

-- | Simulate a single step of Hamiltonian dynamics.
leapfrog :: Gradient -> Particle -> Double -> Particle
leapfrog glTarget (t, r) e = (tf, rf)