hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit ad7925d2aa33a7ed0f6822370ddc12a746e40f03
parent b7b3b030cea290e3a5d1876eab9cc31d75158410
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun, 13 Oct 2013 21:47:15 +1300

Switch buildTree to a fully recursive version.

Diffstat:
Msrc/Numeric/MCMC/NUTS.hs | 116++++++++++++++++++++++++++++++++++++++-----------------------------------------
Mtests/Test.hs | 25++++++++++++-------------
2 files changed, 68 insertions(+), 73 deletions(-)

diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -6,13 +6,10 @@ module Numeric.MCMC.NUTS where import Control.Monad import Control.Monad.Loops import Control.Monad.Primitive -import System.Random.MWC -- FIXME change to Prob monad (Mersenne64) +import System.Random.MWC import System.Random.MWC.Distributions import Statistics.Distribution.Normal -import Debug.Trace - --- FIXME change to probably api type Parameters = [Double] type Density = Parameters -> Double type Gradient = Parameters -> Parameters @@ -25,13 +22,14 @@ newtype BuildTree = BuildTree { instance Show BuildTree where show (BuildTree (tm, rm, tp, rp, t', n, s)) = "\n" ++ "tm: " ++ show tm - ++ "\n" ++ "rm: " ++ show rm + -- ++ "\n" ++ "rm: " ++ show rm ++ "\n" ++ "tp: " ++ show tp - ++ "\n" ++ "rp: " ++ show rp + -- ++ "\n" ++ "rp: " ++ show rp ++ "\n" ++ "t': " ++ show t' ++ "\n" ++ "n : " ++ show n ++ "\n" ++ "s : " ++ show s +-- | The NUTS sampler. nuts :: PrimMonad m => Density @@ -49,6 +47,7 @@ nuts lTarget glTarget m e t g = unfoldrM (kernel e) (m, t) then Nothing else Just (t0, (pred n, t1)) +-- | A single iteration of NUTS. nutsKernel :: PrimMonad m => Density @@ -78,8 +77,9 @@ nutsKernel lTarget glTarget e t g = do buildTree lTarget glTarget g tp rp logu vj j e return (tn, rn, tpp', rpp', t1', n1', s1') - let t2 | s1 == 1 && (fi n1 / fi n :: Double) > z = t1 - | otherwise = t + let t2 | s1 == 1 + && (fi n1 / fi n :: Double) > z = t1 + | otherwise = tm n2 = n + n1 s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) @@ -92,6 +92,7 @@ nutsKernel lTarget glTarget e t g = do go (t, t, r0, r0, 0, t, 1, 1) g +-- | Build the 'tree' of candidate states. buildTree :: PrimMonad m => Density @@ -104,103 +105,98 @@ buildTree -> Int -> Double -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) -buildTree lTarget glTarget g = go - where - go t r logu v 0 e = return $ - let (t0, r0) = leapfrog glTarget (t, r) (v * e) - auxTarget = auxilliaryTarget lTarget t0 r0 - n = indicate (logu < auxTarget) - s = indicate (logu - 1000 < auxTarget) - in (t0, r0, t0, r0, t0, n, s) - - go t r logu v j e = do - z <- uniform g - (tn, rn, tp, rp, t0, n0, s0) <- go t r logu v (pred j) e - - if s0 == 1 +buildTree lTarget glTarget g t r logu v 0 e = do + let (t0, r0) = leapfrog glTarget (t, r) (v * e) + lAuxTarget = log $ auxilliaryTarget lTarget t0 r0 + n = indicate (logu < lAuxTarget) + s = indicate (logu - 1000 < lAuxTarget) + return (t0, r0, t0, r0, t0, n, s) + +buildTree lTarget glTarget g t r logu v j e = do + z <- uniform g + (tn, rn, tp, rp, t0, n0, s0) <- + buildTree lTarget glTarget g t r logu v (pred j) e + + if s0 == 1 + then do + (tnn, rnn, tpp, rpp, t1, n1, s1) <- + if v == -1 then do - (tnn, rnn, tpp, rpp, t1, n1, s1) <- - if v == -1 - then do - (tnn', rnn', _, _, t1', n1', s1') <- go tn rn logu v (pred j) e - return (tnn', rnn', tp, rp, t1', n1', s1') - else do - (_, _, tpp', rpp', t1', n1', s1') <- go tp rp logu v (pred j) e - return (tn, rn, tpp', rpp', t1', n1', s1') - - let p = fromIntegral n1 / fromIntegral (n0 + n1) - n2 = n0 + n1 - t2 | p > (z :: Double) = t1 - | otherwise = t0 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + (tnn', rnn', _, _, t1', n1', s1') <- + buildTree lTarget glTarget g tn rn logu v (pred j) e + return (tnn', rnn', tp, rp, t1', n1', s1') + else do + (_, _, tpp', rpp', t1', n1', s1') <- + buildTree lTarget glTarget g tp rp logu v (pred j) e + return (tn, rn, tpp', rpp', t1', n1', s1') + + let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double) + n2 = n0 + n1 + t2 | accept = t1 + | otherwise = t0 + s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) - return (tnn, rnn, tpp, rpp, t2, n2, s2) - else return (tn, rn, tp, rp, t0, n0, s0) - -findReasonableEpsilon - :: PrimMonad m - => Density - -> Gradient - -> Parameters - -> Gen (PrimState m) - -> m Double -findReasonableEpsilon lTarget glTarget t0 g = do - r0 <- replicateM (length t0) (normal 0 1 g) - let (t1, r1) = leapfrog glTarget (t0, r0) 1.0 - a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 - - go e t r | (acceptanceRatio lTarget t0 t r0 r) ^^ a > 2 ^^ (-a) = - let (tn, rn) = leapfrog glTarget (t, r) e - in go (2 ^^ a * e) tn rn - | otherwise = e - - return $ go 1.0 t1 r1 + return (tnn, rnn, tpp, rpp, t2, n2, s2) + else return (tn, rn, tp, rp, t0, n0, s0) +-- | Simulate Hamiltonian dynamics for n steps. leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle leapfrogIntegrator n glTarget particle e = go particle n where go state ndisc | ndisc <= 0 = state | otherwise = go (leapfrog glTarget state e) (pred n) +-- | Simulate a single step of Hamiltonian dynamics. leapfrog :: Gradient -> Particle -> Double -> Particle leapfrog glTarget (t, r) e = (tf, rf) - where rm = adjustMomentum glTarget e t r - tf = adjustPosition e rm t - rf = adjustMomentum glTarget e tf rm + where + rm = adjustMomentum glTarget e t r + tf = adjustPosition e rm t + rf = adjustMomentum glTarget e tf rm +-- | Adjust momentum. adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c] adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t) +-- | Adjust position. adjustPosition :: Num c => c -> [c] -> [c] -> [c] adjustPosition e r t = zipWith (+) t (e .* r) +-- | The MH acceptance ratio for a given proposal. acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1 / auxilliaryTarget lTarget t0 r0 +-- | The negative potential. auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r) +-- | Simple inner product. innerProduct :: Num a => [a] -> [a] -> a innerProduct xs ys = sum $ zipWith (*) xs ys +-- | Vectorized multiplication. (.*) :: Num b => b -> [b] -> [b] z .* xs = map (* z) xs +-- | Vectorized subtraction. (.-) :: Num a => [a] -> [a] -> [a] xs .- ys = zipWith (-) xs ys +-- | Indicator function. indicate :: Integral a => Bool -> a indicate True = 1 indicate False = 0 +-- | A symmetric categorical (discrete uniform) distribution. symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a symmetricCategorical [] _ = error "symmetricCategorical: no candidates" symmetricCategorical zs g = do z <- uniform g return $ zs !! truncate (z * fromIntegral (length zs) :: Double) +-- | Alias for fromIntegral. fi :: (Integral a, Num b) => a -> b fi = fromIntegral diff --git a/tests/Test.hs b/tests/Test.hs @@ -11,25 +11,24 @@ lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) glTarget :: [Double] -> [Double] glTarget = grad lTarget -t0 :: [Double] -t0 = [0.0, 0.0] - -r0 :: [Double] -r0 = [0.0, 0.0] - -logu = -0.12840 -- from octave -u = exp logu -v = -1 :: Double - -n = 9 :: Int -e = 0.1 :: Double +-- glTarget [x, y] = +-- let dx = 20 * x * (y - x ^ 2) + 0.1 * (1 - x) +-- dy = -10 * (y - x ^ 2) +-- in [dx, dy] + +t0 = [0.0, 0.0] :: [Double] +r0 = [0.0, 0.0] :: [Double] +logu = -0.12840 :: Double +v = -1 :: Double +n = 5 :: Int +e = 0.1 :: Double runBuildTree :: PrimMonad m => Gen (PrimState m) -> m BuildTree runBuildTree g = do liftM BuildTree $ buildTree lTarget glTarget g t0 r0 logu v n e main = do - test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 1000 0.1 t0 + test <- withSystemRandom . asGenIO $ nuts lTarget glTarget 20000 0.075 t0 mapM_ print test