hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 1c74574cc0576fc48cc213b2566ad7d9df2e2977
parent 7db25113e0a796e022dcd9acbbab3a33b26812e0
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 22:03:40 +1200

Add NUTS kernel skeleton.

Diffstat:
MNUTS.hs | 70++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 56 insertions(+), 14 deletions(-)

diff --git a/NUTS.hs b/NUTS.hs @@ -12,6 +12,10 @@ import qualified Data.HashSet as HashSet import System.Random.MWC import System.Random.MWC.Distributions +-- TODO what am i +dMax :: Num t => t +dMax = 1000 + hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m ) => ([Double] -> Double) -- ^ Log target function -> ([Double] -> [Double]) -- ^ Gradient of log target @@ -92,9 +96,6 @@ findReasonableEpsilon lTarget glTarget t0 g = do return $ go 1.0 t1 r1 --- go needs to return in some monad - - buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c) => ([t] -> t) @@ -114,7 +115,7 @@ buildTree lTarget glTarget g = go go t r u v 0 e _ r0 = return $ let (t1, r1) = leapfrog glTarget t r 1 (v * e) n = indicate (u <= auxilliaryTarget lTarget t1 r1) - s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1) + s = indicate (u < exp dMax * auxilliaryTarget lTarget t1 r1) m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0) in (t1, r1, t1, r1, t1, n, s, m, 1) @@ -137,19 +138,53 @@ buildTree lTarget glTarget g = go a3 = a1 + a2 na3 = na1 + na2 - s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) - n3 = n1 + n2 + n3 = n1 + n2 return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3) else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1) - - - - - - +data AdaptiveState = Adapting | Resting deriving (Eq, Show) + +-- TODO get this compiling +-- nutsKernel lTarget glTarget t d adaptiveState e h0 lEmBar0 g = do +-- r0 <- replicateM (length t) (normal 0 1 g) +-- u <- uniformR (0, auxilliaryTarget t r0) g +-- +-- let (tn, tp, rn, rp, j, tm, n, s) = (t, t, r0, r0, 0, t, 1, 1) +-- +-- go i tt | i == 1 = do +-- v <- discreteUniform [-1, 1] g +-- z <- uniform g +-- +-- (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <- +-- if v == -1 +-- then buildTree tn rn u v j (e * t) r0 +-- else buildTree tp rp u v j (e * t) r0 +-- +-- +-- let t2 | min (1) (n1 / n) > z = tnn +-- | otherwise = t +-- +-- n2 = n + n1 +-- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) +-- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) +-- j1 = succ j +-- +-- return $ go i t2 +-- +-- | otherwise = return tt +-- +-- tSpun <- go s t +-- +-- if adaptiveState == Adapting +-- then let hmBar = (1 - 1 / (m + t0)) * h0 + (1 / (m + t0)) * (d - a / na) -- need iteration counter +-- lEm = mu - (sqrt m / gam) * hmBar +-- lEmBar = m ^ (-kappa) * lEm + (1 - m ^ (-kappa)) * lEmBar0 +-- else let em = emAdaptBar +-- +-- return $ (tSpun, hmBar, @@ -169,6 +204,13 @@ indicate False = 0 roundTo :: RealFrac a => Int -> a -> a roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n) +discreteUniform :: PrimMonad m => [a] -> Gen (PrimState m) -> a +discreteUniform [] g = error "discreteUniform: no candidates" +discreteUniform zs g = do + z <- uniform g + return $ zs !! (truncate $ z * fromIntegral (length zs)) + + -- Deprecated ----------------------------------------------------------------- basicBuildTree @@ -188,7 +230,7 @@ basicBuildTree lTarget glTarget = go let (t1, r1) = leapfrog glTarget t r 1 (v * e) c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1) | otherwise = HashSet.empty - s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1 + s | u < exp dMax * auxilliaryTarget lTarget t1 r1 = 1 | otherwise = 0 in (t1, r1, t1, r1, c, s)