commit 26038098d57eb64c8afcac8fc2dacc1e3f3927d0
parent cc4df614316dd34b94823b42dfd86913d5b5a4af
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 9 Sep 2013 12:19:01 +1200
Add reference, split between NUTS and dual-averaging NUTS.
Diffstat:
A | HoffmanGelman2011_NUTS.pdf | | | 0 | |
M | NUTS.hs | | | 201 | ++++++++++++++++++++++++++++++++++++++++++------------------------------------- |
A | daNUTS.hs | | | 199 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
3 files changed, 305 insertions(+), 95 deletions(-)
diff --git a/HoffmanGelman2011_NUTS.pdf b/HoffmanGelman2011_NUTS.pdf
Binary files differ.
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,46 +1,69 @@
-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
-- Lengths in Hamiltonian Monte Carlo.
+module NUTS where
+
import Control.Monad
import Control.Monad.Primitive
-import System.Random.MWC
+import System.Random.MWC -- FIXME change to Prob monad
import System.Random.MWC.Distributions
import Statistics.Distribution.Normal
-type Parameters = [Double]
+-- FIXME change to probably api
+type Parameters = [Double]
type Density = Parameters -> Double
type Gradient = Parameters -> Parameters
type Particle = (Parameters, Parameters)
-leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
-leapfrogIntegrator n glTarget particle e = go particle n
- where go state ndisc
- | ndisc <= 0 = state
- | otherwise = go (leapfrog glTarget state e) (pred n)
+-- FIXME must be streaming
+nuts :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Int
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Parameters
+nuts lTarget glTarget m t g = do
+ e <- findReasonableEpsilon lTarget glTarget t g
+ let go 0 t0 = return t0
+ go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n)
+
+ go m t
+
+nutsKernel :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Double
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Parameters
+nutsKernel lTarget glTarget e t g = do
+ r0 <- replicateM (length t) (normal 0 1 g)
+ u <- uniformR (0, auxilliaryTarget lTarget t r0) g
-leapfrog :: Gradient -> Particle -> Double -> Particle
-leapfrog glTarget (t, r) e = (tf, rf)
- where rm = zipWith (+) r ((e / 2) .* glTarget t)
- tf = zipWith (+) t (e .* rm)
- rf = zipWith (+) rm ((e / 2) .* glTarget tf)
+ let go (tn, tp, rn, rp, j, tm, n, s) g
+ | s == 1 = do
+ vj <- symmetricCategorical [-1, 1] g
+ z <- uniform g
-findReasonableEpsilon :: PrimMonad m
- => Density
- -> Gradient
- -> Parameters
- -> Gen (PrimState m)
- -> m Double
-findReasonableEpsilon lTarget glTarget t0 g = do
- r0 <- replicateM (length t0) (normal 0 1 g)
- let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
- a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+ (tnn, rnn, tpp, rpp, t1, n1, s1) <-
+ if vj == -1
+ then buildTree lTarget glTarget g tn rn u vj j e
+ else buildTree lTarget glTarget g tp rp u vj j e
- go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
- let (tn, rn) = leapfrog glTarget (t, r) e
- in go (2 ^ a * e) tn rn
- | otherwise = e
+ let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
+ | otherwise = t
- return $ go 1.0 t1 r1
+ n2 = n + n1
+ s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ j1 = succ j
+
+ go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) g
+
+ | otherwise = return tm
+
+ go (t, t, r0, r0, 0, t, 1, 1) g
buildTree
:: PrimMonad m
@@ -53,84 +76,80 @@ buildTree
-> Double
-> Int
-> Double
- -> Parameters
- -> Parameters
- -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+ -> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
buildTree lTarget glTarget g = go
- where
- go t r u v 0 e _ r0 = return $
- let (t1, r1) = leapfrog glTarget (t, r) (v * e)
- n = indicate (u <= auxilliaryTarget lTarget t1 r1)
- s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1)
- m = min 1 (acceptanceRatio lTarget t1 r1 r0 r0)
- in (t1, r1, t1, r1, t1, n, s, m, 1)
-
- go t r u v j e t0 r0 = do
+ where
+ go t r u v 0 e = return $
+ let (t0, r0) = leapfrog glTarget (t, r) (v * e)
+ auxTgt = auxilliaryTarget lTarget t0 r0
+ n = indicate (u <= auxTgt)
+ s = indicate (auxTgt > log u - 1000)
+ in (t0, r0, t0, r0, t, n, s)
+
+ go t r u v j e = do
z <- uniform g
- (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+ (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e
- if s1 == 1
+ if s0 == 1
then do
- (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+ (tnn, rnn, tpp, rpp, t1, n1, s1) <-
if v == -1
- then go tn rn u v (pred j) e t0 r0
- else go tp rp u v (pred j) e t0 r0
-
- let p = fromIntegral n2 / fromIntegral (n1 + n2)
- n3 = n1 + n2
- t3 | p > (z :: Double) = t2
- | otherwise = t1
- a3 = a1 + a2
- na3 = na1 + na2
- s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
-
- return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
- else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
-
-innerNutsKernel
- :: PrimMonad m
- => Density
- -> Gradient
- -> Parameters
- -> Double
- -> Gen (PrimState m)
- -> m (Parameters, Double, Int)
-innerNutsKernel lTarget glTarget t e g = do
- r0 <- replicateM (length t) (normal 0 1 g)
- u <- uniformR (0, auxilliaryTarget lTarget t r0) g
+ then go tn rn u v (pred j) e
+ else go tp rp u v (pred j) e
- let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
- | s == 1 = do
- vj <- symmetricCategorical [-1, 1] g
- z <- uniform g
+ let p = fromIntegral n1 / fromIntegral (n0 + n1)
+ n2 = n0 + n1
+ t2 | p > (z :: Double) = t1
+ | otherwise = t0
+ s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
- (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
- if vj == -1
- then buildTree lTarget glTarget g tn rn u vj j e t r0
- else buildTree lTarget glTarget g tp rp u vj j e t r0
+ return (tnn, rnn, tpp, rpp, t2, n2, s2)
+ else return (tn, rn, tp, rp, t0, n0, s0)
- let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
- | otherwise = t
+findReasonableEpsilon :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+ r0 <- replicateM (length t0) (normal 0 1 g)
+ let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+ a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
- n2 = n + n1
- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
- j1 = succ j
+ go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ let (tn, rn) = leapfrog glTarget (t, r) e
+ in go (2 ^ a * e) tn rn
+ | otherwise = e
+
+ return $ go 1.0 t1 r1
- go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) a na g
+leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
+leapfrogIntegrator n glTarget particle e = go particle n
+ where go state ndisc
+ | ndisc <= 0 = state
+ | otherwise = go (leapfrog glTarget state e) (pred n)
- | otherwise = return (tm, aOrig, naOrig)
+leapfrog :: Gradient -> Particle -> Double -> Particle
+leapfrog glTarget (t, r) e = (tf, rf)
+ where rm = adjustMomentum glTarget e t r
+ tf = adjustPosition e rm t
+ rf = adjustMomentum glTarget e tf rm
- go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
+adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t)
-auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
-auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+adjustPosition :: Num c => c -> [c] -> [c] -> [c]
+adjustPosition e r t = zipWith (+) t (e .* r)
acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
/ auxilliaryTarget lTarget t0 r0
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys
@@ -153,11 +172,3 @@ symmetricCategorical zs g = do
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
--- Testing
-
-f :: Density
-f _ = log $ 1 / 10
-
-g :: Gradient
-g xs = replicate (length xs) 1
-
diff --git a/daNUTS.hs b/daNUTS.hs
@@ -0,0 +1,199 @@
+-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
+-- Lengths in Hamiltonian Monte Carlo.
+
+import Control.Monad
+import Control.Monad.Primitive
+import System.Random.MWC
+import System.Random.MWC.Distributions
+import Statistics.Distribution.Normal
+
+type Parameters = [Double]
+type Density = Parameters -> Double
+type Gradient = Parameters -> Parameters
+type Particle = (Parameters, Parameters)
+
+leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
+leapfrogIntegrator n glTarget particle e = go particle n
+ where go state ndisc
+ | ndisc <= 0 = state
+ | otherwise = go (leapfrog glTarget state e) (pred n)
+
+leapfrog :: Gradient -> Particle -> Double -> Particle
+leapfrog glTarget (t, r) e = (tf, rf)
+ where rm = zipWith (+) r ((e / 2) .* glTarget t)
+ tf = zipWith (+) t (e .* rm)
+ rf = zipWith (+) rm ((e / 2) .* glTarget tf)
+
+findReasonableEpsilon :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Parameters
+ -> Gen (PrimState m)
+ -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+ r0 <- replicateM (length t0) (normal 0 1 g)
+ let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+ a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+ go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ let (tn, rn) = leapfrog glTarget (t, r) e
+ in go (2 ^ a * e) tn rn
+ | otherwise = e
+
+ return $ go 1.0 t1 r1
+
+-- this is the dual averaging buildTree
+buildTree
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Gen (PrimState m)
+ -> Parameters
+ -> Parameters
+ -> Double
+ -> Double
+ -> Int
+ -> Double
+ -> Parameters
+ -> Parameters
+ -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+buildTree lTarget glTarget g = go
+ where
+ go t r u v 0 e _ r0 = return $
+ let (t1, r1) = leapfrog glTarget (t, r) (v * e)
+ n = indicate (u <= auxilliaryTarget lTarget t1 r1)
+ s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1)
+ m = min 1 (acceptanceRatio lTarget t1 r1 r0 r0)
+ in (t1, r1, t1, r1, t1, n, s, m, 1)
+
+ go t r u v j e t0 r0 = do
+ z <- uniform g
+ (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+ if s1 == 1
+ then do
+ (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+ if v == -1
+ then go tn rn u v (pred j) e t0 r0
+ else go tp rp u v (pred j) e t0 r0
+
+ let p = fromIntegral n2 / fromIntegral (n1 + n2)
+ n3 = n1 + n2
+ t3 | p > (z :: Double) = t2
+ | otherwise = t1
+ a3 = a1 + a2
+ na3 = na1 + na2
+ s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+ return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+ else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+relaxingNuts = undefined
+
+-- better idea: wrap this dual averaging scheme around the actual nuts
+-- kernel itself. in fact you'd like to just be able to loosely
+-- add dual-averaging to any procedure.
+--
+-- adaptingNutsKenel lTarget glTarget t m g = do
+-- e0 <- findReasonableEpsilon lTarget glTarget t g
+--
+-- let mu = log (10 * e)
+-- epsBar0 = 0
+-- h0Bar = 0
+-- gamma = 0.05
+-- delta = 0.45 -- target mean acceptance probability
+-- tau0 = 10
+-- kappa = 0.75
+--
+-- go hBar eNext logEpsBar tToReturn n
+-- | n <= 0 = return (tToReturn, logEpsBar,
+--
+-- | otherwise = do
+-- (t0, a, na) <- innerNutsKernel lTarget glTarget t e g
+-- let hBarNext = (1 - 1 / (m - n + tau0)) * hBar
+-- + (1 / (m - n + tau0)) * (delta - a)
+--
+-- logEpsNext = mu - ((sqrt (m - n)) / gamma) * hmBar
+--
+-- logEpsBarNext = (m - n) ^ (-kappa) * logEpsNext
+-- + (1 - (m - n) ^ (-kappa)) * logEpsBar
+--
+-- go hBarNext logEpsBarNext t0 (pred n)
+
+
+
+
+innerNutsKernel
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Parameters
+ -> Double
+ -> Gen (PrimState m)
+ -> m (Parameters, Double, Int)
+innerNutsKernel lTarget glTarget t e g = do
+ r0 <- replicateM (length t) (normal 0 1 g)
+ u <- uniformR (0, auxilliaryTarget lTarget t r0) g
+
+ let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
+ | s == 1 = do
+ vj <- symmetricCategorical [-1, 1] g
+ z <- uniform g
+
+ (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
+ if vj == -1
+ then buildTree lTarget glTarget g tn rn u vj j e t r0
+ else buildTree lTarget glTarget g tp rp u vj j e t r0
+
+ let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
+ | otherwise = t
+
+ n2 = n + n1
+ s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ j1 = succ j
+
+ go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) a na g
+
+ | otherwise = return (tm, aOrig, naOrig)
+
+ go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
+acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
+acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
+ / auxilliaryTarget lTarget t0 r0
+
+innerProduct :: Num a => [a] -> [a] -> a
+innerProduct xs ys = sum $ zipWith (*) xs ys
+
+(.*) :: Num b => b -> [b] -> [b]
+z .* xs = map (* z) xs
+
+(.-) :: Num a => [a] -> [a] -> [a]
+xs .- ys = zipWith (-) xs ys
+
+indicate :: Integral a => Bool -> a
+indicate True = 1
+indicate False = 0
+
+symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
+symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
+symmetricCategorical zs g = do
+ z <- uniform g
+ return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+
+-- Testing
+
+f :: Density
+f _ = log $ 1 / 10
+
+g :: Gradient
+g xs = replicate (length xs) 0
+