commit 7db25113e0a796e022dcd9acbbab3a33b26812e0
parent da4ff176919afb192278fb43fff40dff2dae54d4
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 1 Sep 2013 21:09:06 +1200
Add more sophisticated buildTree.
Diffstat:
M | NUTS.hs | | | 137 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------- |
1 file changed, 98 insertions(+), 39 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -46,41 +46,6 @@ hmcKernel lTarget glTarget t0 ndisc e g = do
| otherwise = (t0, r0)
return final
--- note that this is not the greatest buildTree we could use
-buildTree
- :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
- => ([c] -> c) -- ^ Log target
- -> ([c] -> [c]) -- ^ Gradient
- -> [c] -- ^ Position
- -> [c] -- ^ Momentum
- -> c -- ^ Slice variable
- -> c -- ^ Direction (-1, +1)
- -> a -- ^ Depth
- -> c -- ^ Step size
- -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
-buildTree lTarget glTarget = go
- where
- go t r u v 0 e =
- let (t1, r1) = leapfrog glTarget t r 1 (v * e)
- c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
- | otherwise = HashSet.empty
- s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
- | otherwise = 0
- in (t1, r1, t1, r1, c, s)
-
- go t r u v j e =
- let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e
- (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1
- then go tn rn u v (pred j) e
- else go tp rp u v (pred j) e
-
- s2 = s0 * s1 * indicator ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicator ((tpp .- tnn) `innerProduct` rpp >= 0)
-
- c2 = c0 `HashSet.union` c1
-
- in (tnn, rnn, tpp, rpp, c2, s2)
-
leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
=> ([c] -> [c]) -- ^ Gradient of log target function
-> [c] -- ^ List of parameters to target
@@ -117,7 +82,7 @@ findReasonableEpsilon lTarget glTarget t0 g = do
r0 <- replicateM (length t0) (normal 0 1 g)
let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
- a = 2 * indicator (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+ a = 2 * indicate (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
let en = 2 ^ a * e
@@ -127,6 +92,64 @@ findReasonableEpsilon lTarget glTarget t0 g = do
return $ go 1.0 t1 r1
+-- go needs to return in some monad
+
+
+buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d
+ , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c)
+ => ([t] -> t)
+ -> ([t] -> [t])
+ -> Gen (PrimState m)
+ -> [t]
+ -> [t]
+ -> t
+ -> t
+ -> a
+ -> t
+ -> t1
+ -> [t]
+ -> m ([t], [t], [t], [t], [t], c, d, t, e)
+buildTree lTarget glTarget g = go
+ where
+ go t r u v 0 e _ r0 = return $
+ let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+ n = indicate (u <= auxilliaryTarget lTarget t1 r1)
+ s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1)
+ m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0)
+ in (t1, r1, t1, r1, t1, n, s, m, 1)
+
+ go t r u v j e t0 r0 = do
+ z <- uniform g
+ (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+ if roundTo 6 s1 == 1
+ then do
+ (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+ if roundTo 6 v == -1
+ then go tn rn u v (pred j) e t0 r0
+ else go tp rp u v (pred j) e t0 r0
+
+ let p = n2 / (n1 + n2)
+
+ t3 | p > z = t2
+ | otherwise = t1
+
+ a3 = a1 + a2
+ na3 = na1 + na2
+
+ s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+ n3 = n1 + n2
+ return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+ else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+
+
+
+
+
+
@@ -138,11 +161,47 @@ innerProduct xs ys = sum $ zipWith (*) xs ys
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys
-indicator :: Integral a => Bool -> a
-indicator True = 1
-indicator False = 0
+indicate :: Integral a => Bool -> a
+indicate True = 1
+indicate False = 0
-- | Round to a specified number of digits.
roundTo :: RealFrac a => Int -> a -> a
roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)
+-- Deprecated -----------------------------------------------------------------
+
+basicBuildTree
+ :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
+ => ([c] -> c) -- ^ Log target
+ -> ([c] -> [c]) -- ^ Gradient
+ -> [c] -- ^ Position
+ -> [c] -- ^ Momentum
+ -> c -- ^ Slice variable
+ -> c -- ^ Direction (-1, +1)
+ -> a -- ^ Depth
+ -> c -- ^ Step size
+ -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
+basicBuildTree lTarget glTarget = go
+ where
+ go t r u v 0 e =
+ let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+ c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
+ | otherwise = HashSet.empty
+ s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
+ | otherwise = 0
+ in (t1, r1, t1, r1, c, s)
+
+ go t r u v j e =
+ let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e
+ (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1
+ then go tn rn u v (pred j) e
+ else go tp rp u v (pred j) e
+
+ s2 = s0 * s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+ c2 = c0 `HashSet.union` c1
+
+ in (tnn, rnn, tpp, rpp, c2, s2)
+