commit 5aebd6f773fa2bbbff1378fd634c43d82a8ce0bd
parent e96bccefe000f2fd58ce8ee5fa14b83b9efb9ecc
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 15 Oct 2013 15:51:53 +1300
Bug fixes; dual-averaging (DA) development.
Diffstat:
M | src/Numeric/MCMC/NUTS.hs | | | 219 | +++++++++++++++++++++++++++++++++++++++++++------------------------------------ |
M | tests/Test.hs | | | 66 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------- |
2 files changed, 175 insertions(+), 110 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -1,3 +1,5 @@
+{-# LANGUAGE BangPatterns #-}
+
-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
-- Lengths in Hamiltonian Monte Carlo.
@@ -10,6 +12,8 @@ import System.Random.MWC
import System.Random.MWC.Distributions hiding (gamma)
import Statistics.Distribution.Normal
+import Debug.Trace
+
type Parameters = [Double]
type Density = Parameters -> Double
type Gradient = Parameters -> Parameters
@@ -46,13 +50,12 @@ nuts
-> Parameters
-> Gen (PrimState m)
-> m [Parameters]
-nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
- where
- kernel eps (n, t0) = do
- t1 <- nutsKernel lTarget glTarget eps t0 g
- return $ if n <= 0
- then Nothing
- else Just (t0, (pred n, t1))
+-- Runs 'n' transitions of 'nutsKernel' with the fixed step size 'e', starting
+-- from position 't' and drawing randomness from 'g'.  The starting position
+-- itself is not part of the result.  Samples are returned in chronological
+-- order (first generated sample first).
+nuts lTarget glTarget n e t g = go t 0 []
+  where
+    -- 'acc' is built newest-first for O(1) prepends, so it is reversed exactly
+    -- once when the requested number of transitions has been reached.
+    go position j acc
+      | j >= n    = return (reverse acc)
+      | otherwise = do
+          nextPosition <- nutsKernel lTarget glTarget e position g
+          go nextPosition (succ j) (nextPosition : acc)
-- | The NUTS sampler with dual averaging.
nutsDualAveraging
@@ -66,24 +69,27 @@ nutsDualAveraging
-> m [Parameters]
nutsDualAveraging lTarget glTarget n nAdapt t g = do
e0 <- findReasonableEpsilon lTarget glTarget t g
- let daParams = DualAveragingParameters {
- mu = log (10 * e0)
- , delta = 0.5
- , mAdapt = nAdapt
- , gamma = 0.05
- , tau0 = 10
- , kappa = 0.75
- }
-
- unfoldrM (kernel daParams) (0, e0, 0, 0, t)
+ let daParams = basicDualAveragingParameters e0 nAdapt
+ unfoldrM (kernel daParams) (1, e0, 1, 0, t)
where
kernel params (m, e, eAvg, h, t0) = do
(eNext, eAvgNext, hNext, tNext) <-
nutsKernelDualAvg lTarget glTarget e eAvg h m params t0 g
- return $ if m >= n
+ return $ if m > n
then Nothing
else Just (t0, (succ m, eNext, eAvgNext, hNext, tNext))
+-- | Stock dual-averaging configuration.  Only the base step size and the
+--   length of the adaptation (burn-in) period are supplied by the caller;
+--   every other tuning constant is fixed here.
+basicDualAveragingParameters :: Double -> Int -> DualAveragingParameters
+basicDualAveragingParameters baseStep adaptSteps =
+  DualAveragingParameters
+    { mAdapt = adaptSteps
+    , mu     = shrinkagePoint
+    , delta  = 0.5
+    , gamma  = 0.05
+    , tau0   = 10
+    , kappa  = 0.75
+    }
+  where
+    -- log of ten times the base step size
+    shrinkagePoint = log (10 * baseStep)
+
-- | A single iteration of dual-averaging NUTS.
nutsKernelDualAvg
:: PrimMonad m
@@ -96,13 +102,13 @@ nutsKernelDualAvg
-> DualAveragingParameters
-> [Double]
-> Gen (PrimState m)
- -> m (Double, Double, Double, [Double])
+ -> m (Double, Double, Double, Parameters)
nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
r0 <- replicateM (length t) (normal 0 1 g)
z0 <- exponential 1 g
- let logu = auxilliaryTarget lTarget t r0 - z0
+ let logu = log (auxilliaryTarget lTarget t r0) - z0
- let go (tn, tp, rn, rp, j, tm, n, s, a, na) g
+ let go (tn, tp, rn, rp, tm, j, n, s, a, na) g
| s == 1 = do
vj <- symmetricCategorical [-1, 1] g
z <- uniform g
@@ -118,34 +124,35 @@ nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0
return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
- let t2 | s1 == 1
- && (fi n1 / fi n :: Double) > z = t1
- | otherwise = tm
+ let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z
n2 = n + n1
- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ s2 = s1 * stopCriterion tnn tpp rnn rpp
j1 = succ j
+ t2 | accept = t1
+ | otherwise = tm
- go (tnn, rnn, tpp, rpp, j1, t2, n2, s2, a1, na1) g
+ go (tnn, tpp, rnn, rpp, t2, j1, n2, s2, a1, na1) g
- | otherwise = return (tm, a)
+ | otherwise = return (tm, a, na)
- (nextPosition, prob) <- go (t, t, r0, r0, 0, t, 1, 1, 0, 0) g
+ (nextPosition, alpha, nalpha) <- go (t, t, r0, r0, t, 0, 1, 1, 0, 0) g
- let (hNext, eNext, eAvgNext) =
+ let (hNext, eNext, !eAvgNext) =
if m <= mAdapt daParams
then (hm, exp logEm, exp logEbarM)
else (h, eAvg, eAvg)
where
- hm = (1 - 1 / (fromIntegral m + tau0 daParams)) * h
- + (1 / (fromIntegral m + tau0 daParams)) * (delta daParams - prob)
+ eta = 1 / (fromIntegral m + tau0 daParams)
+ hm = (1 - eta) * h
+ + eta * (delta daParams - alpha / fromIntegral nalpha)
- logEm = mu daParams - (sqrt (fromIntegral m) / gamma daParams) * hm
- logEbarM = fromIntegral m ** (- (kappa daParams)) * logEm
- + (1 - fromIntegral m ** (- (kappa daParams))) * (log eAvg)
+ zeta = fromIntegral m ** (- (kappa daParams))
- return (e, eAvg, h, nextPosition)
+ logEm = mu daParams - sqrt (fromIntegral m) / gamma daParams * hm
+ logEbarM = (1 - zeta) * log eAvg + zeta * logEm
+
+ trace (show eAvgNext) $ return (eNext, eAvgNext, hNext, nextPosition)
-- | A single iteration of NUTS.
nutsKernel
@@ -159,14 +166,14 @@ nutsKernel
nutsKernel lTarget glTarget e t g = do
r0 <- replicateM (length t) (normal 0 1 g)
z0 <- exponential 1 g
- let logu = auxilliaryTarget lTarget t r0 - z0
+ let logu = log (auxilliaryTarget lTarget t r0) - z0
- let go (tn, tp, rn, rp, j, tm, n, s) g
+ let go (tn, tp, rn, rp, tm, j, n, s) g
| s == 1 = do
vj <- symmetricCategorical [-1, 1] g
z <- uniform g
- (tnn, rnn, tpp, rpp, t1, n1, s1) <-
+ (tnn, rnn, tpp, rpp, t1, n1, s1) <-
if vj == -1
then do
(tnn', rnn', _, _, t1', n1', s1') <-
@@ -177,20 +184,19 @@ nutsKernel lTarget glTarget e t g = do
buildTree lTarget glTarget g tp rp logu vj j e
return (tn, rn, tpp', rpp', t1', n1', s1')
- let t2 | s1 == 1
- && (fi n1 / fi n :: Double) > z = t1
- | otherwise = tm
+ let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z
n2 = n + n1
- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ s2 = s1 * stopCriterion tnn tpp rnn rpp
j1 = succ j
+ t2 | accept = t1
+ | otherwise = tm
- go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) g
+ go (tnn, tpp, rnn, rpp, t2, j1, n2, s2) g
| otherwise = return tm
- go (t, t, r0, r0, 0, t, 1, 1) g
+ go (t, t, r0, r0, t, 0, 1, 1) g
-- | Build the 'tree' of candidate states.
buildTree
@@ -206,10 +212,10 @@ buildTree
-> Double
-> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
buildTree lTarget glTarget g t r logu v 0 e = do
- let (t0, r0) = leapfrog glTarget (t, r) (v * e)
- lAuxTarget = log $ auxilliaryTarget lTarget t0 r0
- n = indicate (logu < lAuxTarget)
- s = indicate (logu - 1000 < lAuxTarget)
+ let (t0, r0) = leapfrog glTarget (t, r) (v * e)
+ joint = log $ auxilliaryTarget lTarget t0 r0
+ n = indicate (logu < joint)
+ s = indicate (logu - 1000 < joint)
return (t0, r0, t0, r0, t0, n, s)
buildTree lTarget glTarget g t r logu v j e = do
@@ -232,14 +238,21 @@ buildTree lTarget glTarget g t r logu v j e = do
let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double)
n2 = n0 + n1
+ s2 = s0 * s1 * stopCriterion tnn tpp rnn rpp
t2 | accept = t1
| otherwise = t0
- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
return (tnn, rnn, tpp, rpp, t2, n2, s2)
else return (tn, rn, tp, rp, t0, n0, s0)
+-- | Indicator used to decide whether tree doubling may go on.  The result is
+--   1 when the difference @tp - tn@ has a non-negative inner product with
+--   both momenta 'rn' and 'rp', and 0 as soon as either product is negative.
+stopCriterion :: (Integral a, Num b, Ord b) => [b] -> [b] -> [b] -> [b] -> a
+stopCriterion tn tp rn rp = product (map pointsAlong [rn, rp])
+  where
+    displacement         = tp .- tn
+    pointsAlong momentum = indicate (innerProduct displacement momentum >= 0)
+
-- | Build the tree of candidate states under dual averaging.
buildTreeDualAvg
:: PrimMonad m
@@ -255,46 +268,44 @@ buildTreeDualAvg
-> Parameters
-> Parameters
-> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
-buildTreeDualAvg lTarget glTarget g = go
- where
- go t r logu v 0 e t0 r0 = return $
- let (t1, r1) = leapfrog glTarget (t, r) (v * e)
- lAuxTarget = log $ auxilliaryTarget lTarget t1 r1
- n = indicate (logu <= lAuxTarget)
- s = indicate (logu - 1000 < lAuxTarget)
- m = min 1 (acceptanceRatio lTarget t1 r1 t0 r0)
- in (t1, r1, t1, r1, t1, n, s, m, 1)
-
- go t r u v j e t0 r0 = do
- z <- uniform g
- (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
-
- if s1 == 1
- then do
- (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
- if v == -1
- then do
- (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
- go tn rn u v (pred j) e t0 r0
- return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
- else do
- (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
- go tp rp u v (pred j) e t0 r0
- return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
-
- let p = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1
- n3 = n1 + n2
- a3 = a1 + a2
- na3 = na1 + na2
-
- s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
-
- t3 | p > (z :: Double) = t2
- | otherwise = t1
-
- return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
- else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+-- Recursively builds the tree of candidate states for dual-averaging NUTS.
+--
+-- After the log-target, its gradient and the generator, the arguments are:
+-- the position 't' and momentum 'r' to extend from, the log slice value
+-- 'logu', the direction 'v', the tree depth 'j', the step size 'e', and the
+-- reference state ('t0', 'r0') against which acceptance ratios are taken.
+--
+-- The result is @(tn, rn, tp, rp, proposal, n, s, a, na)@: the two ends of
+-- the trajectory (the @(tn, rn)@ end is the one replaced when @v == -1@, the
+-- @(tp, rp)@ end otherwise), the proposed position, the count of admissible
+-- states, the continue flag (1 = keep going), the summed acceptance
+-- probabilities and the number of terms in that sum.
+buildTreeDualAvg lTarget glTarget g t r logu v 0 e t0 r0 = do
+  let (t1, r1) = leapfrog glTarget (t, r) (v * e)
+      joint    = log $ auxilliaryTarget lTarget t1 r1
+      n        = indicate (logu <= joint)
+      s        = indicate (logu - 1000 < joint)
+      -- 'acceptanceRatio' takes both positions first and both momenta second
+      -- (see its signature and its calls in 'findReasonableEpsilon'), so the
+      -- reference state and the new state are passed as t0 t1 r0 r1.  Passing
+      -- them interleaved type-checks, because positions and momenta are both
+      -- lists of doubles, but mixes positions with momenta.
+      a        = min 1 (acceptanceRatio lTarget t0 t1 r0 r1)
+  return (t1, r1, t1, r1, t1, n, s, a, 1)
+
+buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do
+  z <- uniform g
+  -- first half-tree of depth (j - 1)
+  (tn, rn, tp, rp, t1, n1, s1, a1, na1) <-
+    buildTreeDualAvg lTarget glTarget g t r logu v (pred j) e t0 r0
+
+  if s1 == 1
+    then do
+      -- second half-tree, grown from whichever end the direction selects
+      (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+        if v == -1
+          then do
+            (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
+              buildTreeDualAvg lTarget glTarget g tn rn logu v (pred j) e t0 r0
+            return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
+          else do
+            (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
+              buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0
+            return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')
+
+      let p      = fromIntegral n2 / max (fromIntegral (n1 + n2)) 1
+          accept = p > (z :: Double)
+          n3     = n1 + n2
+          a3     = a1 + a2
+          na3    = na1 + na2
+          s3     = s1 * s2 * stopCriterion tnn tpp rnn rpp
+
+          t3 | accept    = t2
+             | otherwise = t1
+
+      return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+    else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
-- | Heuristic for initializing step size.
findReasonableEpsilon
@@ -309,12 +320,14 @@ findReasonableEpsilon lTarget glTarget t0 g = do
let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
- go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
- let (tn, rn) = leapfrog glTarget (t, r) e
- in go (2 ^ a * e) tn rn
- | otherwise = e
+ go j e t r
+ | j <= 0 = e -- no infinite loops
+ | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
+ let (tn, rn) = leapfrog glTarget (t, r) e
+ in go (pred j) (2 ^^ a * e) tn rn
+ | otherwise = e
- return $ go 1.0 t1 r1
+ return $ go 1000 1.0 t1 r1
-- | Simulate a single step of Hamiltonian dynamics.
leapfrog :: Gradient -> Particle -> Double -> Particle
@@ -326,11 +339,11 @@ leapfrog glTarget (t, r) e = (tf, rf)
-- | Adjust momentum.
adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
-adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t)
+adjustMomentum glTarget e t r = r .+ ((e / 2) .* glTarget t)
-- | Adjust position.
adjustPosition :: Num c => c -> [c] -> [c] -> [c]
-adjustPosition e r t = zipWith (+) t (e .* r)
+adjustPosition e r t = t .+ (e .* r)
-- | The MH acceptance ratio for a given proposal.
acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
@@ -353,6 +366,10 @@ z .* xs = map (* z) xs
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys
+-- | Element-wise sum of two lists.  As with 'zipWith', the result is only as
+--   long as the shorter argument.
+(.+) :: Num a => [a] -> [a] -> [a]
+(.+) = zipWith (+)
+
-- | Indicator function.
indicate :: Integral a => Bool -> a
indicate True = 1
@@ -362,8 +379,8 @@ indicate False = 0
symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
symmetricCategorical zs g = do
- z <- uniform g
- return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+ j <- uniformR (0, length zs - 1) g
+ return $ zs !! j
-- | Alias for fromIntegral.
fi :: (Integral a, Num b) => a -> b
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -1,3 +1,4 @@
+import Control.Lens
import Control.Monad
import Control.Monad.Primitive
import Data.Vector (singleton)
@@ -5,31 +6,78 @@ import Numeric.AD
import Numeric.MCMC.NUTS
import System.Random.MWC
+import Debug.Trace
+
lTarget :: RealFloat a => [a] -> a
lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
-glTarget :: [Double] -> [Double]
-glTarget = grad lTarget
+-- glTarget :: [Double] -> [Double]
+-- glTarget = grad lTarget
--- glTarget [x, y] =
--- let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x)
--- dy = -10 * (y - x ^ 2)
--- in [dx, dy]
+-- Hand-written gradient of 'lTarget' for a two-element parameter list.
+glTarget [x, y] = [dx, dy]
+  where
+    -- the residual shared by both partial derivatives
+    gap = y - x ^ 2
+    dx  = 20 * x * gap + 0.1 * (1 - x)
+    dy  = -10 * gap
t0 = [0.0, 0.0] :: [Double]
r0 = [0.0, 0.0] :: [Double]
logu = -0.12840 :: Double
v = -1 :: Double
n = 5 :: Int
+m = 1
+madapt = 0
e = 0.1 :: Double
+eAvg = 1
+h = 0
+
+t0da = [0.0, 0.0] :: [Double]
+r0da = [0.0, 0.0] :: [Double]
runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
runBuildTree g = do
liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e
-main = do
- test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 20000 0.075 t0
- mapM_ print test
+-- | Run 'getTheta' 1000 times and collect the results in order.
+getThetas :: IO [Parameters]
+getThetas = sequence (replicate 1000 getTheta)
+
+-- | Call 'buildTree' once on the module-level fixtures with a freshly
+--   system-seeded generator and return the fifth component of its result.
+getTheta :: IO Parameters
+getTheta = do
+  (_, _, _, _, theta, _, _) <- withSystemRandom . asGenIO $ \gen ->
+    buildTree lTarget glTarget gen t0 r0 logu v n e
+  return theta
+
+-- | Run 'getThetaDa' 1000 times and collect the results in order.
+getThetasDa :: IO [Parameters]
+getThetasDa = sequence (replicate 1000 getThetaDa)
+
+-- | Call 'buildTreeDualAvg' once on the module-level fixtures with a freshly
+--   system-seeded generator and return the fifth component of its result.
+getThetaDa :: IO Parameters
+getThetaDa = do
+  (_, _, _, _, theta, _, _, _, _) <- withSystemRandom . asGenIO $ \gen ->
+    buildTreeDualAvg lTarget glTarget gen t0 r0 logu v n e t0da r0da
+  return theta
+-- | Perform one dual-averaging NUTS transition from 't0': pick a step size
+--   with 'findReasonableEpsilon', run 'nutsKernelDualAvg' once, and return
+--   the position (fourth component) it produces.
+genMoveDa :: IO Parameters
+genMoveDa = withSystemRandom . asGenIO $ \gen -> do
+  eps <- findReasonableEpsilon lTarget glTarget t0 gen
+  (_, _, _, position) <-
+    nutsKernelDualAvg lTarget glTarget eps eAvg h m
+      (basicDualAveragingParameters eps madapt) t0 gen
+  return position
+
+-- | Run 'genMoveDa' 1000 times and collect the results in order.
+genMovesDa :: IO [Parameters]
+genMovesDa = sequence (replicate 1000 genMoveDa)
+
+-- | Perform one fixed-step NUTS transition from 't0' with step size 'e',
+--   using a freshly system-seeded generator.
+genMove :: IO Parameters
+genMove = withSystemRandom (asGenIO (\gen -> nutsKernel lTarget glTarget e t0 gen))
+
+-- | Run 'genMove' 1000 times and collect the results in order.
+genMoves :: IO [Parameters]
+genMoves = sequence (replicate 1000 genMove)
+
+-- Sample with 'nutsDualAveraging' (n = 100, nAdapt = 10) starting from 't0'
+-- and print one sample per line with the list brackets removed.
+main = do
+  samples <- withSystemRandom (asGenIO sampler)
+  -- alternatives to the sampler above:
+  -- nuts lTarget glTarget 5000 0.1 t0
+  -- genMovesDa
+  -- genMoves
+  forM_ samples (putStrLn . withoutBrackets . show)
+  -- mapM_ print test
+  where
+    sampler         = nutsDualAveraging lTarget glTarget 100 10 t0
+    withoutBrackets = filter (`notElem` "[]")