hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 7db25113e0a796e022dcd9acbbab3a33b26812e0
parent da4ff176919afb192278fb43fff40dff2dae54d4
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun,  1 Sep 2013 21:09:06 +1200

Add more sophisticated buildTree.

Diffstat:
MNUTS.hs | 137++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
1 file changed, 98 insertions(+), 39 deletions(-)

diff --git a/NUTS.hs b/NUTS.hs @@ -46,41 +46,6 @@ hmcKernel lTarget glTarget t0 ndisc e g = do | otherwise = (t0, r0) return final --- note that this is not the greatest buildTree we could use -buildTree - :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c) - => ([c] -> c) -- ^ Log target - -> ([c] -> [c]) -- ^ Gradient - -> [c] -- ^ Position - -> [c] -- ^ Momentum - -> c -- ^ Slice variable - -> c -- ^ Direction (-1, +1) - -> a -- ^ Depth - -> c -- ^ Step size - -> ([c], [c], [c], [c], HashSet ([c], [c]), t) -buildTree lTarget glTarget = go - where - go t r u v 0 e = - let (t1, r1) = leapfrog glTarget t r 1 (v * e) - c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1) - | otherwise = HashSet.empty - s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1 - | otherwise = 0 - in (t1, r1, t1, r1, c, s) - - go t r u v j e = - let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e - (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1 - then go tn rn u v (pred j) e - else go tp rp u v (pred j) e - - s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0) - * indicator ((tpp .- tnn) `innerProduct` rpp >= 0) - - c2 = c0 `HashSet.union` c1 - - in (tnn, rnn, tpp, rpp, c2, s2) - leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a) => ([c] -> [c]) -- ^ Gradient of log target function -> [c] -- ^ List of parameters to target @@ -117,7 +82,7 @@ findReasonableEpsilon lTarget glTarget t0 g = do r0 <- replicateM (length t0) (normal 0 1 g) let (t1, r1) = leapfrog glTarget t0 r0 1 1.0 - a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 + a = 2 * indicate (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1 go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) = let en = 2 ^ a * e @@ -127,6 +92,64 @@ findReasonableEpsilon lTarget glTarget t0 g = do return $ go 1.0 t1 r1 +-- go needs to return in some monad + + +buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d + , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c) + => ([t] -> t) + -> ([t] -> [t]) + -> Gen (PrimState m) + -> [t] + -> [t] + -> t + -> t + -> a + -> t + -> t1 + -> [t] + -> m ([t], [t], [t], [t], [t], c, d, t, e) +buildTree lTarget glTarget g = go + where + go t r u v 0 e _ r0 = return $ + let (t1, r1) = leapfrog glTarget t r 1 (v * e) + n = indicate (u <= auxilliaryTarget lTarget t1 r1) + s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1) + m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0) + in (t1, r1, t1, r1, t1, n, s, m, 1) + + go t r u v j e t0 r0 = do + z <- uniform g + (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0 + + if roundTo 6 s1 == 1 + then do + (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <- + if roundTo 6 v == -1 + then go tn rn u v (pred j) e t0 r0 + else go tp rp u v (pred j) e t0 r0 + + let p = n2 / (n1 + n2) + + t3 | p > z = t2 + | otherwise = t1 + + a3 = a1 + a2 + na3 = na1 + na2 + + s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + + n3 = n1 + n2 + return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3) + else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1) + + + + + + + @@ -138,11 +161,47 @@ innerProduct xs ys = sum $ zipWith (*) xs ys (.-) :: Num a => [a] -> [a] -> [a] xs .- ys = zipWith (-) xs ys -indicator :: Integral a => Bool -> a -indicator True = 1 -indicator False = 0 +indicate :: Integral a => Bool -> a +indicate True = 1 +indicate False = 0 -- | Round to a specified number of digits. roundTo :: RealFrac a => Int -> a -> a roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n) +-- Deprecated ----------------------------------------------------------------- + +basicBuildTree + :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c) + => ([c] -> c) -- ^ Log target + -> ([c] -> [c]) -- ^ Gradient + -> [c] -- ^ Position + -> [c] -- ^ Momentum + -> c -- ^ Slice variable + -> c -- ^ Direction (-1, +1) + -> a -- ^ Depth + -> c -- ^ Step size + -> ([c], [c], [c], [c], HashSet ([c], [c]), t) +basicBuildTree lTarget glTarget = go + where + go t r u v 0 e = + let (t1, r1) = leapfrog glTarget t r 1 (v * e) + c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1) + | otherwise = HashSet.empty + s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1 + | otherwise = 0 + in (t1, r1, t1, r1, c, s) + + go t r u v j e = + let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e + (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1 + then go tn rn u v (pred j) e + else go tp rp u v (pred j) e + + s2 = s0 * s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0) + * indicate ((tpp .- tnn) `innerProduct` rpp >= 0) + + c2 = c0 `HashSet.union` c1 + + in (tnn, rnn, tpp, rpp, c2, s2) +