hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 84720a3c176a17f52410a3b140855aad814ce735
parent 26038098d57eb64c8afcac8fc2dacc1e3f3927d0
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon, 16 Sep 2013 08:59:47 +1200

Reorganize module.

Diffstat:
DNUTS.hs | 174-------------------------------------------------------------------------------
RHoffmanGelman2011_NUTS.pdf -> reference/HoffmanGelman2011_NUTS.pdf | 0
Asrc/Numeric/MCMC/Examples/Rosenbrock.hs | 16++++++++++++++++
Asrc/Numeric/MCMC/NUTS.hs | 175+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
RHMC.hs -> working/HMC.hs | 0
RdaNUTS.hs -> working/daNUTS.hs | 0
6 files changed, 191 insertions(+), 174 deletions(-)

diff --git a/NUTS.hs b/NUTS.hs @@ -1,174 +0,0 @@ --- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path --- Lengths in Hamiltonian Monte Carlo. - -module NUTS where - -import Control.Monad -import Control.Monad.Primitive -import System.Random.MWC -- FIXME change to Prob monad -import System.Random.MWC.Distributions -import Statistics.Distribution.Normal - --- FIXME change to probably api -type Parameters = [Double] -type Density = Parameters -> Double -type Gradient = Parameters -> Parameters -type Particle = (Parameters, Parameters) - --- FIXME must be streaming -nuts :: PrimMonad m - => Density - -> Gradient - -> Int - -> Parameters - -> Gen (PrimState m) - -> m Parameters -nuts lTarget glTarget m t g = do - e <- findReasonableEpsilon lTarget glTarget t g - let go 0 t0 = return t0 - go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n) - - go m t - -nutsKernel :: PrimMonad m - => Density - -> Gradient - -> Double - -> Parameters - -> Gen (PrimState m) - -> m Parameters -nutsKernel lTarget glTarget e t g = do - r0 <- replicateM (length t) (normal 0 1 g) - u <- uniformR (0, auxilliaryTarget lTarget t r0) g - - let go (tn, tp, rn, rp, j, tm, n, s) g - | s == 1 = do - vj <- symmetricCategorical [-1, 1] g - z <- uniform g - - (tnn, rnn, tpp, rpp, t1, n1, s1) <- - if vj == -1 - then buildTree lTarget glTarget g tn rn u vj j e - else buildTree lTarget glTarget g tp rp u vj j e - - let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn - | otherwise = t - - n2 = n + n1 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) - j1 = succ j - - go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) g - - | otherwise = return tm - - go (t, t, r0, r0, 0, t, 1, 1) g - -buildTree - :: PrimMonad m - => Density - -> Gradient - -> Gen (PrimState m) - -> Parameters - -> Parameters - -> Double - -> Double - -> Int - -> Double - -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) -buildTree lTarget glTarget g = go - where - go t r u v 0 e = return $ - let (t0, r0) = leapfrog glTarget (t, r) (v * e) - auxTgt = auxilliaryTarget lTarget t0 r0 - n = indicate (u <= auxTgt) - s = indicate (auxTgt > log u - 1000) - in (t0, r0, t0, r0, t, n, s) - - go t r u v j e = do - z <- uniform g - (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e - - if s0 == 1 - then do - (tnn, rnn, tpp, rpp, t1, n1, s1) <- - if v == -1 - then go tn rn u v (pred j) e - else go tp rp u v (pred j) e - - let p = fromIntegral n1 / fromIntegral (n0 + n1) - n2 = n0 + n1 - t2 | p > (z :: Double) = t1 - | otherwise = t0 - s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) - - return (tnn, rnn, tpp, rpp, t2, n2, s2) - else return (tn, rn, tp, rp, t0, n0, s0) - -findReasonableEpsilon :: PrimMonad m - => Density - -> Gradient - -> Parameters - -> Gen (PrimState m) - -> m Double -findReasonableEpsilon lTarget glTarget t0 g = do - r0 <- replicateM (length t0) (normal 0 1 g) - let (t1, r1) = leapfrog glTarget (t0, r0) 1.0 - a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 - - go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = - let (tn, rn) = leapfrog glTarget (t, r) e - in go (2 ^ a * e) tn rn - | otherwise = e - - return $ go 1.0 t1 r1 - -leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle -leapfrogIntegrator n glTarget particle e = go particle n - where go state ndisc - | ndisc <= 0 = state - | otherwise = go (leapfrog glTarget state e) (pred n) - -leapfrog :: Gradient -> Particle -> Double -> Particle -leapfrog glTarget (t, r) e = (tf, rf) - where rm = adjustMomentum glTarget e t r - tf = adjustPosition e rm t - rf = adjustMomentum glTarget e tf rm - -adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c] -adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t) - -adjustPosition :: Num c => c -> [c] -> [c] -> [c] -adjustPosition e r t = zipWith (+) t (e .* r) - -acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a -acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1 - / auxilliaryTarget lTarget t0 r0 - -auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a -auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r) - -innerProduct :: Num a => [a] -> [a] -> a -innerProduct xs ys = sum $ zipWith (*) xs ys - -(.*) :: Num b => b -> [b] -> [b] -z .* xs = map (* z) xs - -(.-) :: Num a => [a] -> [a] -> [a] -xs .- ys = zipWith (-) xs ys - -indicate :: Integral a => Bool -> a -indicate True = 1 -indicate False = 0 - -symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a -symmetricCategorical [] _ = error "symmetricCategorical: no candidates" -symmetricCategorical zs g = do - z <- uniform g - return $ zs !! truncate (z * fromIntegral (length zs) :: Double) - -fi :: (Integral a, Num b) => a -> b -fi = fromIntegral - diff --git a/HoffmanGelman2011_NUTS.pdf b/reference/HoffmanGelman2011_NUTS.pdf Binary files differ. diff --git a/src/Numeric/MCMC/Examples/Rosenbrock.hs b/src/Numeric/MCMC/Examples/Rosenbrock.hs @@ -0,0 +1,16 @@ +import Numeric.AD +import System.Random.MWC + +lTarget :: RealFloat a => [a] -> a +lTarget [x0, x1] = (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) + +glTarget :: [Double] -> [Double] +glTarget = grad lTarget + +inits :: [Double] +inits = [5.0, 5.0] + +epochs :: Int +epochs = 100 + + diff --git a/src/Numeric/MCMC/NUTS.hs b/src/Numeric/MCMC/NUTS.hs @@ -0,0 +1,175 @@ +-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path +-- Lengths in Hamiltonian Monte Carlo. + +module Numeric.MCMC.NUTS where + +import Control.Monad +import Control.Monad +import Control.Monad.Primitive +import System.Random.MWC -- FIXME change to Prob monad +import System.Random.MWC.Distributions +import Statistics.Distribution.Normal + +-- FIXME change to probably api +type Parameters = [Double] +type Density = Parameters -> Double +type Gradient = Parameters -> Parameters +type Particle = (Parameters, Parameters) + +-- FIXME must be streaming +nuts :: PrimMonad m + => Density + -> Gradient + -> Int + -> Parameters + -> Gen (PrimState m) + -> m Parameters +nuts lTarget glTarget m t g = do + e <- findReasonableEpsilon lTarget glTarget t g + let go 0 t0 = return t0 + go n t0 = nutsKernel lTarget glTarget e t0 g >>= go (pred n) + + go m t + +nutsKernel :: PrimMonad m + => Density + -> Gradient + -> Double + -> Parameters + -> Gen (PrimState m) + -> m Parameters +nutsKernel lTarget glTarget e t g = do + r0 <- replicateM (length t) (normal 0 1 g) + u <- uniformR (0, auxilliaryTarget lTarget t r0) g + + let go (tn, tp, rn, rp, j, tm, n, s) g + | s == 1 = do + vj <- symmetricCategorical [-1, 1] g + z <- uniform g + + (tnn, rnn, tpp, rpp, t1, n1, s1) <- + if vj == -1 + then buildTree lTarget glTarget g tn rn u vj j e + else buildTree lTarget glTarget g tp rp u vj j e + + let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn + | otherwise = t + + n2 = n + n1 + s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + j1 = succ j + + go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) g + + | otherwise = return tm + + go (t, t, r0, r0, 0, t, 1, 1) g + +buildTree + :: PrimMonad m + => Density + -> Gradient + -> Gen (PrimState m) + -> Parameters + -> Parameters + -> Double + -> Double + -> Int + -> Double + -> m ([Double], [Double], [Double], [Double], [Double], Int, Int) +buildTree lTarget glTarget g = go + where + go t r u v 0 e = return $ + let (t0, r0) = leapfrog glTarget (t, r) (v * e) + auxTgt = auxilliaryTarget lTarget t0 r0 + n = indicate (u <= auxTgt) + s = indicate (auxTgt > log u - 1000) + in (t0, r0, t0, r0, t, n, s) + + go t r u v j e = do + z <- uniform g + (tn, rn, tp, rp, t0, n0, s0) <- go t r u v (pred j) e + + if s0 == 1 + then do + (tnn, rnn, tpp, rpp, t1, n1, s1) <- + if v == -1 + then go tn rn u v (pred j) e + else go tp rp u v (pred j) e + + let p = fromIntegral n1 / fromIntegral (n0 + n1) + n2 = n0 + n1 + t2 | p > (z :: Double) = t1 + | otherwise = t0 + s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + + return (tnn, rnn, tpp, rpp, t2, n2, s2) + else return (tn, rn, tp, rp, t0, n0, s0) + +findReasonableEpsilon :: PrimMonad m + => Density + -> Gradient + -> Parameters + -> Gen (PrimState m) + -> m Double +findReasonableEpsilon lTarget glTarget t0 g = do + r0 <- replicateM (length t0) (normal 0 1 g) + let (t1, r1) = leapfrog glTarget (t0, r0) 1.0 + a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 + + go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = + let (tn, rn) = leapfrog glTarget (t, r) e + in go (2 ^ a * e) tn rn + | otherwise = e + + return $ go 1.0 t1 r1 + +leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle +leapfrogIntegrator n glTarget particle e = go particle n + where go state ndisc + | ndisc <= 0 = state + | otherwise = go (leapfrog glTarget state e) (pred n) + +leapfrog :: Gradient -> Particle -> Double -> Particle +leapfrog glTarget (t, r) e = (tf, rf) + where rm = adjustMomentum glTarget e t r + tf = adjustPosition e rm t + rf = adjustMomentum glTarget e tf rm + +adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c] +adjustMomentum glTarget e t r = zipWith (+) r ((e / 2) .* glTarget t) + +adjustPosition :: Num c => c -> [c] -> [c] -> [c] +adjustPosition e r t = zipWith (+) t (e .* r) + +acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a +acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1 + / auxilliaryTarget lTarget t0 r0 + +auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a +auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r) + +innerProduct :: Num a => [a] -> [a] -> a +innerProduct xs ys = sum $ zipWith (*) xs ys + +(.*) :: Num b => b -> [b] -> [b] +z .* xs = map (* z) xs + +(.-) :: Num a => [a] -> [a] -> [a] +xs .- ys = zipWith (-) xs ys + +indicate :: Integral a => Bool -> a +indicate True = 1 +indicate False = 0 + +symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a +symmetricCategorical [] _ = error "symmetricCategorical: no candidates" +symmetricCategorical zs g = do + z <- uniform g + return $ zs !! truncate (z * fromIntegral (length zs) :: Double) + +fi :: (Integral a, Num b) => a -> b +fi = fromIntegral + diff --git a/HMC.hs b/working/HMC.hs diff --git a/daNUTS.hs b/working/daNUTS.hs