commit 1c74574cc0576fc48cc213b2566ad7d9df2e2977
parent 7db25113e0a796e022dcd9acbbab3a33b26812e0
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 1 Sep 2013 22:03:40 +1200
Add NUTS kernel skeleton.
Diffstat:
M | NUTS.hs | | | 70 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- |
1 file changed, 56 insertions(+), 14 deletions(-)
diff --git a/NUTS.hs b/NUTS.hs
@@ -12,6 +12,10 @@ import qualified Data.HashSet as HashSet
import System.Random.MWC
import System.Random.MWC.Distributions
+-- TODO what am i
+dMax :: Num t => t
+dMax = 1000
+
hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m )
=> ([Double] -> Double) -- ^ Log target function
-> ([Double] -> [Double]) -- ^ Gradient of log target
@@ -92,9 +96,6 @@ findReasonableEpsilon lTarget glTarget t0 g = do
return $ go 1.0 t1 r1
--- go needs to return in some monad
-
-
buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d
, Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c)
=> ([t] -> t)
@@ -114,7 +115,7 @@ buildTree lTarget glTarget g = go
go t r u v 0 e _ r0 = return $
let (t1, r1) = leapfrog glTarget t r 1 (v * e)
n = indicate (u <= auxilliaryTarget lTarget t1 r1)
- s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1)
+ s = indicate (u < exp dMax * auxilliaryTarget lTarget t1 r1)
m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0)
in (t1, r1, t1, r1, t1, n, s, m, 1)
@@ -137,19 +138,53 @@ buildTree lTarget glTarget g = go
a3 = a1 + a2
na3 = na1 + na2
- s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
- n3 = n1 + n2
+ n3 = n1 + n2
return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1)
-
-
-
-
-
-
+data AdaptiveState = Adapting | Resting deriving (Eq, Show)
+
+-- TODO get this compiling
+-- nutsKernel lTarget glTarget t d adaptiveState e h0 lEmBar0 g = do
+-- r0 <- replicateM (length t) (normal 0 1 g)
+-- u <- uniformR (0, auxilliaryTarget t r0) g
+--
+-- let (tn, tp, rn, rp, j, tm, n, s) = (t, t, r0, r0, 0, t, 1, 1)
+--
+-- go i tt | i == 1 = do
+-- v <- discreteUniform [-1, 1] g
+-- z <- uniform g
+--
+-- (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
+-- if v == -1
+-- then buildTree tn rn u v j (e * t) r0
+-- else buildTree tp rp u v j (e * t) r0
+--
+--
+-- let t2 | min (1) (n1 / n) > z = tnn
+-- | otherwise = t
+--
+-- n2 = n + n1
+-- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+-- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+-- j1 = succ j
+--
+-- return $ go i t2
+--
+-- | otherwise = return tt
+--
+-- tSpun <- go s t
+--
+-- if adaptiveState == Adapting
+-- then let hmBar = (1 - 1 / (m + t0)) * h0 + (1 / (m + t0)) * (d - a / na) -- need iteration counter
+-- lEm = mu - (sqrt m / gam) * hmBar
+-- lEmBar = m ^ (-kappa) * lEm + (1 - m ^ (-kappa)) * lEmBar0
+-- else let em = emAdaptBar
+--
+-- return $ (tSpun, hmBar,
@@ -169,6 +204,13 @@ indicate False = 0
roundTo :: RealFrac a => Int -> a -> a
roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)
+discreteUniform :: PrimMonad m => [a] -> Gen (PrimState m) -> a
+discreteUniform [] g = error "discreteUniform: no candidates"
+discreteUniform zs g = do
+ z <- uniform g
+ return $ zs !! (truncate $ z * fromIntegral (length zs))
+
+
-- Deprecated -----------------------------------------------------------------
basicBuildTree
@@ -188,7 +230,7 @@ basicBuildTree lTarget glTarget = go
let (t1, r1) = leapfrog glTarget t r 1 (v * e)
c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
| otherwise = HashSet.empty
- s | u < exp 1000 * auxilliaryTarget lTarget t1 r1 = 1
+ s | u < exp dMax * auxilliaryTarget lTarget t1 r1 = 1
| otherwise = 0
in (t1, r1, t1, r1, c, s)