commit ad7925d2aa33a7ed0f6822370ddc12a746e40f03
parent b7b3b030cea290e3a5d1876eab9cc31d75158410
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 13 Oct 2013 21:47:15 +1300
Switch buildTree to a fully recursive version.
Diffstat:
2 files changed, 68 insertions(+), 73 deletions(-)
diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs
@@ -6,13 +6,10 @@ module Numeric.MCMC.NUTS where
import Control.Monad
import Control.Monad.Loops
import Control.Monad.Primitive
-import System.Random.MWC -- FIXME change to Prob monad (Mersenne64)
+import System.Random.MWC
import System.Random.MWC.Distributions
import Statistics.Distribution.Normal
-import Debug.Trace
-
--- FIXME change to probably api
type Parameters = [Double]
type Density = Parameters -> Double
type Gradient = Parameters -> Parameters
@@ -25,13 +22,14 @@ newtype BuildTree = BuildTree {
instance Show BuildTree where
show (BuildTree (tm, rm, tp, rp, t', n, s)) =
"\n" ++ "tm: " ++ show tm
- ++ "\n" ++ "rm: " ++ show rm
+ -- ++ "\n" ++ "rm: " ++ show rm
++ "\n" ++ "tp: " ++ show tp
- ++ "\n" ++ "rp: " ++ show rp
+ -- ++ "\n" ++ "rp: " ++ show rp
++ "\n" ++ "t': " ++ show t'
++ "\n" ++ "n : " ++ show n
++ "\n" ++ "s : " ++ show s
+-- | The NUTS sampler.
nuts
:: PrimMonad m
=> Density
@@ -49,6 +47,7 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t)
then Nothing
else Just (t0, (pred n, t1))
+-- | A single iteration of NUTS.
nutsKernel
:: PrimMonad m
=> Density
@@ -78,8 +77,9 @@ nutsKernel lTarget glTarget e t g = do
buildTree lTarget glTarget g tp rp logu vj j e
return (tn, rn, tpp', rpp', t1', n1', s1')
- let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1
- | otherwise = t
+ let t2 | s1 == 1
+ && (fi n1 / fi n :: Double) > z = t1
+ | otherwise = tm
n2 = n + n1
s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
@@ -92,6 +92,7 @@ nutsKernel lTarget glTarget e t g = do
go (t, t, r0, r0, 0, t, 1, 1) g
+-- | Build the 'tree' of candidate states.
buildTree
:: PrimMonad m
=> Density
@@ -104,103 +105,98 @@ buildTree
-> Int
-> Double
-> m ([Double], [Double], [Double], [Double], [Double], Int, Int)
-buildTree lTarget glTarget g = go
- where
- go t r logu v 0 e = return $
- let (t0, r0) = leapfrog glTarget (t, r) (v * e)
- auxTarget = auxilliaryTarget lTarget t0 r0
- n = indicate (logu < auxTarget)
- s = indicate (logu - 1000 < auxTarget)
- in (t0, r0, t0, r0, t0, n, s)
-
- go t r logu v j e = do
- z <- uniform g
- (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e
-
- if s0 == 1
+buildTree lTarget glTarget g t r logu v 0 e = do
+ let (t0, r0) = leapfrog glTarget (t, r) (v * e)
+ lAuxTarget = log $ auxilliaryTarget lTarget t0 r0
+ n = indicate (logu < lAuxTarget)
+ s = indicate (logu - 1000 < lAuxTarget)
+ return (t0, r0, t0, r0, t0, n, s)
+
+buildTree lTarget glTarget g t r logu v j e = do
+ z <- uniform g
+ (tn, rn, tp, rp, t0, n0, s0) <-
+ buildTree lTarget glTarget g t r logu v (pred j) e
+
+ if s0 == 1
+ then do
+ (tnn, rnn, tpp, rpp, t1, n1, s1) <-
+ if v == -1
then do
- (tnn, rnn, tpp, rpp, t1, n1, s1) <-
- if v == -1
- then do
- (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e
- return (tnn', rnn', tp, rp, t1', n1', s1')
- else do
- (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e
- return (tn, rn, tpp', rpp', t1', n1', s1')
-
- let p = fromIntegral n1 / fromIntegral (n0 + n1)
- n2 = n0 + n1
- t2 | p > (z :: Double) = t1
- | otherwise = t0
- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ (tnn', rnn', _, _, t1', n1', s1') <-
+ buildTree lTarget glTarget g tn rn logu v (pred j) e
+ return (tnn', rnn', tp, rp, t1', n1', s1')
+ else do
+ (_, _, tpp', rpp', t1', n1', s1') <-
+ buildTree lTarget glTarget g tp rp logu v (pred j) e
+ return (tn, rn, tpp', rpp', t1', n1', s1')
+
+ let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double)
+ n2 = n0 + n1
+ t2 | accept = t1
+ | otherwise = t0
+ s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
* indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
- return (tnn, rnn, tpp, rpp, t2, n2, s2)
- else return (tn, rn, tp, rp, t0, n0, s0)
-
-findReasonableEpsilon
- :: PrimMonad m
- => Density
- -> Gradient
- -> Parameters
- -> Gen (PrimState m)
- -> m Double
-findReasonableEpsilon lTarget glTarget t0 g = do
- r0 <- replicateM (length t0) (normal 0 1 g)
- let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
- a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
-
- go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) =
- let (tn, rn) = leapfrog glTarget (t, r) e
- in go (2 ^^ a * e) tn rn
- | otherwise = e
-
- return $ go 1.0 t1 r1
+ return (tnn, rnn, tpp, rpp, t2, n2, s2)
+ else return (tn, rn, tp, rp, t0, n0, s0)
+-- | Simulate Hamiltonian dynamics for n steps.
leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
leapfrogIntegrator n glTarget particle e = go particle n
where go state ndisc
| ndisc <= 0 = state
| otherwise = go (leapfrog glTarget state e) (pred n)
+-- | Simulate a single step of Hamiltonian dynamics.
leapfrog :: Gradient -> Particle -> Double -> Particle
leapfrog glTarget (t, r) e = (tf, rf)
- where rm = adjustMomentum glTarget e t r
- tf = adjustPosition e rm t
- rf = adjustMomentum glTarget e tf rm
+ where
+ rm = adjustMomentum glTarget e t r
+ tf = adjustPosition e rm t
+ rf = adjustMomentum glTarget e tf rm
+-- | Adjust momentum.
adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t)
+-- | Adjust position.
adjustPosition :: Num c => c -> [c] -> [c] -> [c]
adjustPosition e r t = zipWith (+) t (e .* r)
+-- | The MH acceptance ratio for a given proposal.
acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
/ auxilliaryTarget lTarget t0 r0
+-- | The negative potential.
auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+-- | Simple inner product.
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys
+-- | Vectorized multiplication.
(.*) :: Num b => b -> [b] -> [b]
z .* xs = map (* z) xs
+-- | Vectorized subtraction.
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys
+-- | Indicator function.
indicate :: Integral a => Bool -> a
indicate True = 1
indicate False = 0
+-- | A symmetric categorical (discrete uniform) distribution.
symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
symmetricCategorical zs g = do
z <- uniform g
return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+-- | Alias for fromIntegral.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
diff --git a/tests/Test.hs b/tests/Test.hs
@@ -11,25 +11,24 @@ lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
glTarget :: [Double] -> [Double]
glTarget = grad lTarget
-t0 :: [Double]
-t0 = [0.0, 0.0]
-
-r0 :: [Double]
-r0 = [0.0, 0.0]
-
-logu = -0.12840 -- from octave
-u = exp logu
-v = -1 :: Double
-
-n = 9 :: Int
-e = 0.1 :: Double
+-- glTarget [x, y] =
+-- let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x)
+-- dy = -10 * (y - x ^ 2)
+-- in [dx, dy]
+
+t0 = [0.0, 0.0] :: [Double]
+r0 = [0.0, 0.0] :: [Double]
+logu = -0.12840 :: Double
+v = -1 :: Double
+n = 5 :: Int
+e = 0.1 :: Double
runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree
runBuildTree g = do
liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e
main = do
- test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0
+ test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 20000 0.075 t0
mapM_ print test