hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 5aebd6f773fa2bbbff1378fd634c43d82a8ce0bd
parent e96bccefe000f2fd58ce8ee5fa14b83b9efb9ecc
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue, 15 Oct 2013 15:51:53 +1300

Bug fixes, DA dev.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 219+++++++++++++++++++++++++++++++++++++++++++------------------------------------
Mtests/Test.hs | 66+++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
2 files changed, 175 insertions(+), 110 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE BangPatterns #-} + -- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path -- Lengths in Hamiltonian Monte Carlo. @@ -10,6 +12,8 @@ import System.Random.MWC import System.Random.MWC.Distributions hiding (gamma) import Statistics.Distribution.Normal +import Debug.Trace + type Parameters = [Double] type Density = Parameters -> Double type Gradient = Parameters -> Parameters @@ -46,13 +50,12 @@ nuts -> Parameters -> Gen (PrimState m) -> m [Parameters] -nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t) - where - kernel eps (n, t0) = do - t1 <- nutsKernel lTarget glTarget eps t0 g - return $ if n <= 0 - then Nothing - else Just (t0, (pred n, t1)) +nuts lTarget glTarget n e t g = go t 0 [] + where go position j acc + | j >= n = return acc + | otherwise = do + nextPosition <- nutsKernel lTarget glTarget e position g + go nextPosition (succ j) (nextPosition : acc) -- | The NUTS sampler with dual averaging. nutsDualAveraging @@ -66,24 +69,27 @@ nutsDualAveraging -> m [Parameters] nutsDualAveraging lTarget glTarget n nAdapt t g = do e0 <- findReasonableEpsilon lTarget glTarget t g - let daParams = DualAveragingParameters { - mu = log (10 * e0) - , delta = 0.5 - , mAdapt = nAdapt - , gamma = 0.05 - , tau0 = 10 - , kappa = 0.75 - } - - unfoldrM (kernel daParams) (0, e0, 0, 0, t) + let daParams = basicDualAveragingParameters e0 nAdapt + unfoldrM (kernel daParams) (1, e0, 1, 0, t) where kernel params (m, e, eAvg, h, t0) = do (eNext, eAvgNext, hNext, tNext) <- nutsKernelDualAvg lTarget glTarget e eAvg h m params t0 g - return $ if m >= n + return $ if m > n then Nothing else Just (t0, (succ m, eNext, eAvgNext, hNext, tNext)) +-- | Default DA parameters, given a base step size and burn in period. +basicDualAveragingParameters :: Double -> Int -> DualAveragingParameters +basicDualAveragingParameters step burnInPeriod = DualAveragingParameters { + mu = log (10 * step) + , delta = 0.5 + , mAdapt = burnInPeriod + , gamma = 0.05 + , tau0 = 10 + , kappa = 0.75 + } + -- | A single iteration of dual-averaging NUTS. nutsKernelDualAvg :: PrimMonad m @@ -96,13 +102,13 @@ nutsKernelDualAvg -> DualAveragingParameters -> [Double] -> Gen (PrimState m) - -> m (Double, Double, Double, [Double]) + -> m (Double, Double, Double, Parameters) nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do r0 <- replicateM (length t) (normal 0 1 g) z0 <- exponential 1 g - let logu = auxilliaryTarget lTarget t r0 - z0 + let logu = log (auxilliaryTarget lTarget t r0) - z0 - let go (tn, tp, rn, rp, j, tm, n, s, a, na) g + let go (tn, tp, rn, rp, tm, j, n, s, a, na) g | s == 1 = do vj <- symmetricCategorical [-1, 1] g z <- uniform g @@ -118,34 +124,35 @@ nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0 return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1') - let t2 | s1 == 1 - && (fi n1 / fi n :: Double) > z = t1 - | otherwise = tm + let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z n2 = n + n1 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + s2 = s1 * stopCriterion tnn tpp rnn rpp j1 = succ j + t2 | accept = t1 + | otherwise = tm - go (tnn, rnn, tpp, rpp, j1, t2, n2, s2, a1, na1) g + go (tnn, tpp, rnn, rpp, t2, j1, n2, s2, a1, na1) g - | otherwise = return (tm, a) + | otherwise = return (tm, a, na) - (nextPosition, prob) <- go (t, t, r0, r0, 0, t, 1, 1, 0, 0) g + (nextPosition, alpha, nalpha) <- go (t, t, r0, r0, t, 0, 1, 1, 0, 0) g - let (hNext, eNext, eAvgNext) = + let (hNext, eNext, !eAvgNext) = if m <= mAdapt daParams then (hm, exp logEm, exp logEbarM) else (h, eAvg, eAvg) where - hm = (1 - 1 / (fromIntegral m + tau0 daParams)) * h - + (1 / (fromIntegral m + tau0 daParams)) * (delta daParams - prob) + eta = 1 / (fromIntegral m + tau0 daParams) + hm = (1 - eta) * h + + eta * (delta daParams - alpha / fromIntegral nalpha) - logEm = mu daParams - (sqrt (fromIntegral m) / gamma daParams) * hm - logEbarM = fromIntegral m ** (- (kappa daParams)) * logEm - + (1 - fromIntegral m ** (- (kappa daParams))) * (log eAvg) + zeta = fromIntegral m ** (- (kappa daParams)) - return (e, eAvg, h, nextPosition) + logEm = mu daParams - sqrt (fromIntegral m) / gamma daParams * hm + logEbarM = (1 - zeta) * log eAvg + zeta * logEm + + trace (show eAvgNext) $ return (eNext, eAvgNext, hNext, nextPosition) -- | A single iteration of NUTS. nutsKernel @@ -159,14 +166,14 @@ nutsKernel nutsKernel lTarget glTarget e t g = do r0 <- replicateM (length t) (normal 0 1 g) z0 <- exponential 1 g - let logu = auxilliaryTarget lTarget t r0 - z0 + let logu = log (auxilliaryTarget lTarget t r0) - z0 - let go (tn, tp, rn, rp, j, tm, n, s) g + let go (tn, tp, rn, rp, tm, j, n, s) g | s == 1 = do vj <- symmetricCategorical [-1, 1] g z <- uniform g - (tnn, rnn, tpp, rpp, t1, n1, s1) <- + (tnn, rnn, tpp, rpp, t1, n1, s1) <- if vj == -1 then do (tnn', rnn', _, _, t1', n1', s1') <- @@ -177,20 +184,19 @@ nutsKernel lTarget glTarget e t g = do buildTree lTarget glTarget g tp rp logu vj j e return (tn, rn, tpp', rpp', t1', n1', s1') - let t2 | s1 == 1 - && (fi n1 / fi n :: Double) > z = t1 - | otherwise = tm + let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z n2 = n + n1 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + s2 = s1 * stopCriterion tnn tpp rnn rpp j1 = succ j + t2 | accept = t1 + | otherwise = tm - go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) g + go (tnn, tpp, rnn, rpp, t2, j1, n2, s2) g | otherwise = return tm - go (t, t, r0, r0, 0, t, 1, 1) g + go (t, t, r0, r0, t, 0, 1, 1) g -- | Build the 'tree' of candidate states. buildTree @@ -206,10 +212,10 @@ buildTree -> Double -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) buildTree lTarget glTarget g t r logu v 0 e = do - let (t0, r0) = leapfrog glTarget (t, r) (v * e) - lAuxTarget = log $ auxilliaryTarget lTarget t0 r0 - n = indicate (logu < lAuxTarget) - s = indicate (logu - 1000 < lAuxTarget) + let (t0, r0) = leapfrog glTarget (t, r) (v * e) + joint = log $ auxilliaryTarget lTarget t0 r0 + n = indicate (logu < joint) + s = indicate (logu - 1000 < joint) return (t0, r0, t0, r0, t0, n, s) buildTree lTarget glTarget g t r logu v j e = do @@ -232,14 +238,21 @@ buildTree lTarget glTarget g t r logu v j e = do let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double) n2 = n0 + n1 + s2 = s0 * s1 * stopCriterion tnn tpp rnn rpp t2 | accept = t1 | otherwise = t0 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) return (tnn, rnn, tpp, rpp, t2, n2, s2) else return (tn, rn, tp, rp, t0, n0, s0) +-- | Determine whether or not to stop doubling the tree of candidate states. +stopCriterion :: (Integral a, Num b, Ord b) => [b] -> [b] -> [b] -> [b] -> a +stopCriterion tn tp rn rp = + indicate (positionDifference `innerProduct` rn >= 0) + * indicate (positionDifference `innerProduct` rp >= 0) + where + positionDifference = tp .- tn + -- | Build the tree of candidate states under dual averaging. buildTreeDualAvg :: PrimMonad m @@ -255,46 +268,44 @@ buildTreeDualAvg -> Parameters -> Parameters -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int) -buildTreeDualAvg lTarget glTarget g = go - where - go t r logu v 0 e t0 r0 = return $ - let (t1, r1) = leapfrog glTarget (t, r) (v * e) - lAuxTarget = log $ auxilliaryTarget lTarget t1 r1 - n = indicate (logu <= lAuxTarget) - s = indicate (logu - 1000 < lAuxTarget) - m = min 1 (acceptanceRatio lTarget t1 r1 t0 r0) - in (t1, r1, t1, r1, t1, n, s, m, 1) - - go t r u v j e t0 r0 = do - z <- uniform g - (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0 - - if s1 == 1 - then do - (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <- - if v == -1 - then do - (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- - go tn rn u v (pred j) e t0 r0 - return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1') - else do - (_, _, tpp', rpp', t1', n1', s1', a1', na1') <- - go tp rp u v (pred j) e t0 r0 - return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1') - - let p = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1 - n3 = n1 + n2 - a3 = a1 + a2 - na3 = na1 + na2 - - s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) - - t3 | p > (z :: Double) = t2 - | otherwise = t1 - - return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3) - else return (tn, rn, tp, rp, t1, n1, s1, a1, na1) +buildTreeDualAvg lTarget glTarget g t r logu v 0 e t0 r0 = do + let (t1, r1) = leapfrog glTarget (t, r) (v * e) + joint = log $ auxilliaryTarget lTarget t1 r1 + n = indicate (logu <= joint) + s = indicate (logu - 1000 < joint) + a = min 1 (acceptanceRatio lTarget t1 r1 t0 r0) + return (t1, r1, t1, r1, t1, n, s, a, 1) + +buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do + z <- uniform g + (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- + buildTreeDualAvg lTarget glTarget g t r logu v (pred j) e t0 r0 + + if s1 == 1 + then do + (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <- + if v == -1 + then do + (tnn', rnn', _, _, t1', n1', s1', a1', na1') <- + buildTreeDualAvg lTarget glTarget g tn rn logu v (pred j) e t0 r0 + return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1') + else do + (_, _, tpp', rpp', t1', n1', s1', a1', na1') <- + buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0 + return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1') + + let p = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1 + accept = p > (z :: Double) + n3 = n1 + n2 + a3 = a1 + a2 + na3 = na1 + na2 + s3 = s1 * s2 * stopCriterion tnn tpp rnn rpp + + t3 | accept = t2 + | otherwise = t1 + + return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3) + else return (tn, rn, tp, rp, t1, n1, s1, a1, na1) -- | Heuristic for initializing step size. findReasonableEpsilon @@ -309,12 +320,14 @@ findReasonableEpsilon lTarget glTarget t0 g = do let (t1, r1) = leapfrog glTarget (t0, r0) 1.0 a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 - go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = - let (tn, rn) = leapfrog glTarget (t, r) e - in go (2 ^ a * e) tn rn - | otherwise = e + go j e t r + | j <= 0 = e -- no infinite loops + | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) = + let (tn, rn) = leapfrog glTarget (t, r) e + in go (pred j) (2 ^^ a * e) tn rn + | otherwise = e - return $ go 1.0 t1 r1 + return $ go 1000 1.0 t1 r1 -- | Simulate a single step of Hamiltonian dynamics. leapfrog :: Gradient -> Particle -> Double -> Particle @@ -326,11 +339,11 @@ leapfrog glTarget (t, r) e = (tf, rf) -- | Adjust momentum. adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c] -adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t) +adjustMomentum glTarget e t r = r .+ ((e / 2) .* glTarget t) -- | Adjust position. adjustPosition :: Num c => c -> [c] -> [c] -> [c] -adjustPosition e r t = zipWith (+) t (e .* r) +adjustPosition e r t = t .+ (e .* r) -- | The MH acceptance ratio for a given proposal. acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a @@ -353,6 +366,10 @@ z .* xs = map (* z) xs (.-) :: Num a => [a] -> [a] -> [a] xs .- ys = zipWith (-) xs ys +-- | Vectorized addition. +(.+) :: Num a => [a] -> [a] -> [a] +xs .+ ys = zipWith (+) xs ys + -- | Indicator function. indicate :: Integral a => Bool -> a indicate True = 1 @@ -362,8 +379,8 @@ indicate False = 0 symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a symmetricCategorical [] _ = error "symmetricCategorical: no candidates" symmetricCategorical zs g = do - z <- uniform g - return $ zs !! truncate (z * fromIntegral (length zs) :: Double) + j <- uniformR (0, length zs - 1) g + return $ zs !! j -- | Alias for fromIntegral. fi :: (Integral a, Num b) => a -> b diff --git a/tests/Test.hs b/tests/Test.hs @@ -1,3 +1,4 @@ +import Control.Lens import Control.Monad import Control.Monad.Primitive import Data.Vector (singleton) @@ -5,31 +6,78 @@ import Numeric.AD import Numeric.MCMC.NUTS import System.Random.MWC +import Debug.Trace + lTarget :: RealFloat a => [a] -> a lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) -glTarget :: [Double] -> [Double] -glTarget = grad lTarget +-- glTarget :: [Double] -> [Double] +-- glTarget = grad lTarget --- glTarget [x, y] = --- let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x) --- dy = -10 * (y - x ^ 2) --- in [dx, dy] +glTarget [x, y] = + let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x) + dy = -10 * (y - x ^ 2) + in [dx, dy] t0 = [0.0, 0.0] :: [Double] r0 = [0.0, 0.0] :: [Double] logu = -0.12840 :: Double v = -1 :: Double n = 5 :: Int +m = 1 +madapt = 0 e = 0.1 :: Double +eAvg = 1 +h = 0 + +t0da = [0.0, 0.0] :: [Double] +r0da = [0.0, 0.0] :: [Double] runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree runBuildTree g = do liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e -main = do - test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 20000 0.075 t0 - mapM_ print test +getThetas :: IO [Parameters] +getThetas = replicateM 1000 getTheta + +getTheta :: IO Parameters +getTheta = do + bt <- withSystemRandom . asGenIO $ \g -> + buildTree lTarget glTarget g t0 r0 logu v n e + return $ bt^._5 + +getThetasDa :: IO [Parameters] +getThetasDa = replicateM 1000 getThetaDa + +getThetaDa :: IO Parameters +getThetaDa = do + bt <- withSystemRandom . asGenIO $ \g -> + buildTreeDualAvg lTarget glTarget g t0 r0 logu v n e t0da r0da + return $ bt^._5 +genMoveDa :: IO Parameters +genMoveDa = withSystemRandom . asGenIO $ \g -> do + eps <- findReasonableEpsilon lTarget glTarget t0 g + let daParams = basicDualAveragingParameters eps madapt + blah <- nutsKernelDualAvg lTarget glTarget eps eAvg h m daParams t0 g + return $ blah^._4 + +genMovesDa :: IO [Parameters] +genMovesDa = replicateM 1000 genMoveDa + +genMove :: IO Parameters +genMove = withSystemRandom . asGenIO $ nutsKernel lTarget glTarget e t0 + +genMoves :: IO [Parameters] +genMoves = replicateM 1000 genMove + +main = do + test <- withSystemRandom . asGenIO $ + nutsDualAveraging lTarget glTarget 100 10 t0 + -- nuts lTarget glTarget 5000 0.1 t0 + -- genMovesDa + -- genMoves + mapM_ (putStrLn . filter (`notElem` "[]") . show) test + -- mapM_ print test