commit cc4df614316dd34b94823b42dfd86913d5b5a4af
parent 1c74574cc0576fc48cc213b2566ad7d9df2e2977
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 8 Sep 2013 21:52:34 +1200
Specialize types, get inner loop working.
Diffstat:
A | HMC.hs | | | 224 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
M | NUTS.hs | | | 318 | +++++++++++++++++++++++++++++-------------------------------------------------- |
2 files changed, 340 insertions(+), 202 deletions(-)
diff --git a/HMC.hs b/HMC.hs
@@ -0,0 +1,224 @@
+-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
+-- Lengths in Hamiltonian Monte Carlo.
+
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+import Control.Monad
+import Control.Monad.Loops
+import Control.Monad.Primitive
+import System.Random.MWC
+import System.Random.MWC.Distributions
+
+-- TODO what am i
+dMax :: Num t => t
+dMax = 1000
+
+hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m )
+ => ([Double] -> Double) -- ^ Log target function
+ -> ([Double] -> [Double]) -- ^ Gradient of log target
+ -> [Double] -- ^ Parameters
+ -> a -- ^ Epochs to run the chain
+ -> a -- ^ Number of discretizing steps
+ -> Double -- ^ Step size
+ -> Gen (PrimState m) -- ^ PRNG
+ -> m [[Double]] -- ^ Chain
+hmc lTarget glTarget t n ndisc e g = unfoldrM kernel (n, (t, []))
+ where
+ kernel (m, (p, _)) = do
+ (p1, r1) <- hmcKernel lTarget glTarget p ndisc e g
+ return $ if m <= 0
+ then Nothing
+ else Just (p1, (pred m, (p1, r1)))
+
+hmcKernel :: (Enum a, Eq a, Ord a, Num a, PrimMonad m)
+ => ([Double] -> Double) -- ^ Log target function
+ -> ([Double] -> [Double]) -- ^ Gradient of log target
+ -> [Double] -- ^ Parameters
+ -> a -- ^ Number of discretizing steps
+ -> Double -- ^ Step size
+ -> Gen (PrimState m) -- ^ PRNG
+ -> m ([Double], [Double]) -- ^ m (End params, end momenta)
+hmcKernel lTarget glTarget t0 ndisc e g = do
+ r0 <- replicateM (length t0) (normal 0 1 g)
+ z <- uniform g
+ let (t1, r1) = leapfrog glTarget t0 r0 ndisc e
+ a = min 1 $ hmcAcceptanceRatio lTarget t0 t1 r0 r1
+ final | a > z = (t1, map negate r1)
+ | otherwise = (t0, r0)
+ return final
+
+leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
+ => ([c] -> [c]) -- ^ Gradient of log target function
+ -> [c] -- ^ List of parameters to target
+ -> [c] -- ^ Momentum variables
+ -> a -- ^ Number of discretizing steps
+ -> c -- ^ Step size
+ -> ([c], [c]) -- ^ (End parameters, end momenta)
+leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
+ | otherwise = go t0 r0 ndisc
+ where go t r 0 = (t, r)
+ go t r n = let rm = zipWith (+) r (map (* (0.5 * e)) (glTarget t))
+ tt = zipWith (+) t (map (* e) rm)
+ rt = zipWith (+) rm (map (* (0.5 * e)) (glTarget t))
+ in go tt rt (pred n)
+
+-- | Acceptance ratio for a proposed move. t0/r0 denote the present state of
+-- the parameters and auxilliary variables, and t1/r1 denote the proposed
+-- state.
+hmcAcceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
+hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
+ / auxilliaryTarget lTarget t0 r0
+
+-- | Augment a log target with some auxilliary variables.
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
+findReasonableEpsilon :: PrimMonad m
+ => ([Double] -> Double)
+ -> ([Double] -> [Double])
+ -> [Double]
+ -> Gen (PrimState m)
+ -> m Double
+findReasonableEpsilon lTarget glTarget t0 g = do
+ r0 <- replicateM (length t0) (normal 0 1 g)
+ let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
+
+ a = 2 * indicate (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+
+ go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ let en = 2 ^ a * e
+ (tn, rn) = leapfrog glTarget t r 1 e
+ in go en tn rn
+ | otherwise = e
+
+ return $ go 1.0 t1 r1
+
+-- problem
+buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d
+ , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m , Variate c)
+ => ([t] -> t)
+ -> ([t] -> [t])
+ -> Gen (PrimState m)
+ -> [t]
+ -> [t]
+ -> t
+ -> t
+ -> a
+ -> t
+ -> t1
+ -> [t]
+ -> m ([t], [t], [t], [t], [t], c, d, t, e)
+buildTree lTarget glTarget g = go
+ where
+ go t r u v 0 e _ r0 = return $
+ let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+ n = indicate (u <= auxilliaryTarget lTarget t1 r1)
+ s = indicate (u < exp dMax * auxilliaryTarget lTarget t1 r1)
+ m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0)
+ in (t1, r1, t1, r1, t1, n, s, m, 1)
+
+ go t r u v j e t0 r0 = do
+ z <- uniform g
+ (tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
+
+ if s1 == 1
+ then do
+ (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
+ if v == -1
+ then go tn rn u v (pred j) e t0 r0
+ else go tp rp u v (pred j) e t0 r0
+
+ let p = n2 / (n1 + n2)
+
+ t3 | p > z = t2
+ | otherwise = t1
+
+ a3 = a1 + a2
+ na3 = na1 + na2
+
+ s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+ n3 = n1 + n2
+ return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+ else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+innerNutsKernel :: (PrimMonad m, Variate b)
+ => ([a] -> Double)
+ -> t
+ -> [a]
+ -> c
+ -> Gen (PrimState m)
+ -> m b
+innerNutsKernel lTarget glTarget t e g = do
+ r0 <- replicateM (length t) (normal 0 1 g)
+ u <- uniformR (0, auxilliaryTarget lTarget t r0) g
+
+ let go (tn, tp, rn, rp, j, tm, n, s) a b gen = do
+ vj <- symmetricCategorical [-1, 1] gen
+ z <- uniform gen
+
+ return z
+-- let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
+-- | s == 1 = do
+-- vj <- symmetricCategorical [-1, 1] g
+-- z <- uniform g
+--
+-- (tnn, rnn, tpp, rpp, t1, n1, s, a, na) <-
+-- buildTree lTarget glTarget g tn rn u vj j e t r0 -- FIXME
+--
+-- return $ (t1, a, na)
+
+ go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+
+ -- let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
+ -- | s == 1 = do
+ -- vj <- symmetricCategorical [-1, 1] g
+ -- z <- uniform g
+
+ -- (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
+ -- if vj == -1
+ -- then buildTree lTarget glTarget g tn rn u vj j e t r0
+ -- else buildTree lTarget glTarget g tp rp u vj j e t r0
+
+ -- let t2 | s1 == 1 && (min 1 (fromIntegral n1 / fromIntegral n :: Double) > z) = tnn
+ -- | otherwise = t
+
+ -- n2 = n + n1
+ -- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ -- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ -- j1 = succ j
+
+ -- go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) a na g
+
+ -- | otherwise = return (tm, aOrig, naOrig)
+
+ -- return $ go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+
+
+
+
+
+-- Utilities ------------------------------------------------------------------
+
+innerProduct :: Num a => [a] -> [a] -> a
+innerProduct xs ys = sum $ zipWith (*) xs ys
+
+(.-) :: Num a => [a] -> [a] -> [a]
+xs .- ys = zipWith (-) xs ys
+
+indicate :: Integral a => Bool -> a
+indicate True = 1
+indicate False = 0
+
+-- | Round to a specified number of digits.
+roundTo :: RealFrac a => Int -> a -> a
+roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)
+
+symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
+symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
+symmetricCategorical zs g = do
+ z <- uniform g
+ return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+
diff --git a/NUTS.hs b/NUTS.hs
@@ -1,198 +1,142 @@
-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
-- Lengths in Hamiltonian Monte Carlo.
-{-# OPTIONS_GHC -Wall -fno-warn-type-defaults #-}
-
import Control.Monad
-import Control.Monad.Loops
import Control.Monad.Primitive
-import Data.Hashable
-import Data.HashSet (HashSet)
-import qualified Data.HashSet as HashSet
import System.Random.MWC
import System.Random.MWC.Distributions
+import Statistics.Distribution.Normal
--- TODO what am i
-dMax :: Num t => t
-dMax = 1000
-
-hmc :: (Enum a, Eq a, Ord a, Num a, PrimMonad m )
- => ([Double] -> Double) -- ^ Log target function
- -> ([Double] -> [Double]) -- ^ Gradient of log target
- -> [Double] -- ^ Parameters
- -> a -- ^ Epochs to run the chain
- -> a -- ^ Number of discretizing steps
- -> Double -- ^ Step size
- -> Gen (PrimState m) -- ^ PRNG
- -> m [[Double]] -- ^ Chain
-hmc lTarget glTarget t n ndisc e g = unfoldrM kernel (n, (t, []))
- where
- kernel (m, (p, _)) = do
- (p1, r1) <- hmcKernel lTarget glTarget p ndisc e g
- return $ if m <= 0
- then Nothing
- else Just (p1, (pred m, (p1, r1)))
-
-hmcKernel :: (Enum a, Eq a, Ord a, Num a, PrimMonad m)
- => ([Double] -> Double) -- ^ Log target function
- -> ([Double] -> [Double]) -- ^ Gradient of log target
- -> [Double] -- ^ Parameters
- -> a -- ^ Number of discretizing steps
- -> Double -- ^ Step size
- -> Gen (PrimState m) -- ^ PRNG
- -> m ([Double], [Double]) -- ^ m (End params, end momenta)
-hmcKernel lTarget glTarget t0 ndisc e g = do
- r0 <- replicateM (length t0) (normal 0 1 g)
- z <- uniform g
- let (t1, r1) = leapfrog glTarget t0 r0 ndisc e
- a = min 1 $ hmcAcceptanceRatio lTarget t0 t1 r0 r1
- final | a > z = (t1, map negate r1)
- | otherwise = (t0, r0)
- return final
-
-leapfrog :: (Enum a, Eq a, Ord a, Fractional c, Num a)
- => ([c] -> [c]) -- ^ Gradient of log target function
- -> [c] -- ^ List of parameters to target
- -> [c] -- ^ Momentum variables
- -> a -- ^ Number of discretizing steps
- -> c -- ^ Step size
- -> ([c], [c]) -- ^ (End parameters, end momenta)
-leapfrog glTarget t0 r0 ndisc e | ndisc < 0 = (t0, r0)
- | otherwise = go t0 r0 ndisc
- where go t r 0 = (t, r)
- go t r n = let rm = zipWith (+) r (map (* (0.5 * e)) (glTarget t))
- tt = zipWith (+) t (map (* e) rm)
- rt = zipWith (+) rm (map (* (0.5 * e)) (glTarget t))
- in go tt rt (pred n)
-
--- | Acceptance ratio for a proposed move. t0/r0 denote the present state of
--- the parameters and auxilliary variables, and t1/r1 denote the proposed
--- state.
-hmcAcceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
-hmcAcceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
- / auxilliaryTarget lTarget t0 r0
-
--- | Augment a log target with some auxilliary variables.
-auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
-auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+type Parameters = [Double]
+type Density = Parameters -> Double
+type Gradient = Parameters -> Parameters
+type Particle = (Parameters, Parameters)
+
+leapfrogIntegrator :: Int -> Gradient -> Particle -> Double -> Particle
+leapfrogIntegrator n glTarget particle e = go particle n
+ where go state ndisc
+ | ndisc <= 0 = state
+ | otherwise = go (leapfrog glTarget state e) (pred n)
+
+leapfrog :: Gradient -> Particle -> Double -> Particle
+leapfrog glTarget (t, r) e = (tf, rf)
+ where rm = zipWith (+) r ((e / 2) .* glTarget t)
+ tf = zipWith (+) t (e .* rm)
+ rf = zipWith (+) rm ((e / 2) .* glTarget tf)
findReasonableEpsilon :: PrimMonad m
- => ([Double] -> Double)
- -> ([Double] -> [Double])
- -> [Double]
+ => Density
+ -> Gradient
+ -> Parameters
-> Gen (PrimState m)
-> m Double
findReasonableEpsilon lTarget glTarget t0 g = do
r0 <- replicateM (length t0) (normal 0 1 g)
- let (t1, r1) = leapfrog glTarget t0 r0 1 1.0
-
- a = 2 * indicate (hmcAcceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
+ let (t1, r1) = leapfrog glTarget (t0, r0) 1.0
+ a = 2 * indicate (acceptanceRatio lTarget t0 t1 r0 r1 > 0.5) - 1
- go e t r | (hmcAcceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
- let en = 2 ^ a * e
- (tn, rn) = leapfrog glTarget t r 1 e
- in go en tn rn
+ go e t r | (acceptanceRatio lTarget t0 t r0 r) ^ a > 2 ^ (-a) =
+ let (tn, rn) = leapfrog glTarget (t, r) e
+ in go (2 ^ a * e) tn rn
| otherwise = e
return $ go 1.0 t1 r1
-buildTree :: (Enum a, Eq a, Floating t, Fractional c, Integral c, Integral d
- , Num a, Num e, RealFrac d, RealFrac t, PrimMonad m, Variate c)
- => ([t] -> t)
- -> ([t] -> [t])
+buildTree
+ :: PrimMonad m
+ => Density
+ -> Gradient
-> Gen (PrimState m)
- -> [t]
- -> [t]
- -> t
- -> t
- -> a
- -> t
- -> t1
- -> [t]
- -> m ([t], [t], [t], [t], [t], c, d, t, e)
-buildTree lTarget glTarget g = go
- where
+ -> Parameters
+ -> Parameters
+ -> Double
+ -> Double
+ -> Int
+ -> Double
+ -> Parameters
+ -> Parameters
+ -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
+buildTree lTarget glTarget g = go
+ where
go t r u v 0 e _ r0 = return $
- let (t1, r1) = leapfrog glTarget t r 1 (v * e)
+ let (t1, r1) = leapfrog glTarget (t, r) (v * e)
n = indicate (u <= auxilliaryTarget lTarget t1 r1)
- s = indicate (u < exp dMax * auxilliaryTarget lTarget t1 r1)
- m = min 1 (hmcAcceptanceRatio lTarget t1 r1 r0 r0)
+ s = indicate (u < exp 1000 * auxilliaryTarget lTarget t1 r1)
+ m = min 1 (acceptanceRatio lTarget t1 r1 r0 r0)
in (t1, r1, t1, r1, t1, n, s, m, 1)
go t r u v j e t0 r0 = do
z <- uniform g
(tn, rn, tp, rp, t1, n1, s1, a1, na1) <- go t r u v (pred j) e t0 r0
- if roundTo 6 s1 == 1
+ if s1 == 1
then do
(tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
- if roundTo 6 v == -1
+ if v == -1
then go tn rn u v (pred j) e t0 r0
else go tp rp u v (pred j) e t0 r0
-
- let p = n2 / (n1 + n2)
-
- t3 | p > z = t2
- | otherwise = t1
-
- a3 = a1 + a2
+
+ let p = fromIntegral n2 / fromIntegral (n1 + n2)
+ n3 = n1 + n2
+ t3 | p > (z :: Double) = t2
+ | otherwise = t1
+ a3 = a1 + a2
na3 = na1 + na2
-
s3 = s2 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
-
- n3 = n1 + n2
- return $ (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
- else return $ (tn, rn, tp, rp, t1, n1, s1, a1, na1)
-
-data AdaptiveState = Adapting | Resting deriving (Eq, Show)
-
--- TODO get this compiling
--- nutsKernel lTarget glTarget t d adaptiveState e h0 lEmBar0 g = do
--- r0 <- replicateM (length t) (normal 0 1 g)
--- u <- uniformR (0, auxilliaryTarget t r0) g
---
--- let (tn, tp, rn, rp, j, tm, n, s) = (t, t, r0, r0, 0, t, 1, 1)
---
--- go i tt | i == 1 = do
--- v <- discreteUniform [-1, 1] g
--- z <- uniform g
---
--- (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
--- if v == -1
--- then buildTree tn rn u v j (e * t) r0
--- else buildTree tp rp u v j (e * t) r0
---
---
--- let t2 | min (1) (n1 / n) > z = tnn
--- | otherwise = t
---
--- n2 = n + n1
--- s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
--- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
--- j1 = succ j
---
--- return $ go i t2
---
--- | otherwise = return tt
---
--- tSpun <- go s t
---
--- if adaptiveState == Adapting
--- then let hmBar = (1 - 1 / (m + t0)) * h0 + (1 / (m + t0)) * (d - a / na) -- need iteration counter
--- lEm = mu - (sqrt m / gam) * hmBar
--- lEmBar = m ^ (-kappa) * lEm + (1 - m ^ (-kappa)) * lEmBar0
--- else let em = emAdaptBar
---
--- return $ (tSpun, hmBar,
-
-
-
--- Utilities ------------------------------------------------------------------
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+
+ return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
+ else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)
+
+innerNutsKernel
+ :: PrimMonad m
+ => Density
+ -> Gradient
+ -> Parameters
+ -> Double
+ -> Gen (PrimState m)
+ -> m (Parameters, Double, Int)
+innerNutsKernel lTarget glTarget t e g = do
+ r0 <- replicateM (length t) (normal 0 1 g)
+ u <- uniformR (0, auxilliaryTarget lTarget t r0) g
+
+ let go (tn, tp, rn, rp, j, tm, n, s) aOrig naOrig g
+ | s == 1 = do
+ vj <- symmetricCategorical [-1, 1] g
+ z <- uniform g
+
+ (tnn, rnn, tpp, rpp, t1, n1, s1, a, na) <-
+ if vj == -1
+ then buildTree lTarget glTarget g tn rn u vj j e t r0
+ else buildTree lTarget glTarget g tp rp u vj j e t r0
+
+ let t2 | s1 == 1 && min 1 (fi n1 / fi n :: Double) > z = tnn
+ | otherwise = t
+
+ n2 = n + n1
+ s2 = s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
+ * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
+ j1 = succ j
+
+ go (tnn, rnn, tpp, rpp, j1, t2, n2, s2) a na g
+
+ | otherwise = return (tm, aOrig, naOrig)
+
+ go (t, t, r0, r0, 0, t, 1, 1) 0 0 g
+
+auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
+auxilliaryTarget lTarget t r = exp (lTarget t - 0.5 * innerProduct r r)
+
+acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
+acceptanceRatio lTarget t0 t1 r0 r1 = auxilliaryTarget lTarget t1 r1
+ / auxilliaryTarget lTarget t0 r0
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys
+(.*) :: Num b => b -> [b] -> [b]
+z .* xs = map (* z) xs
+
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys
@@ -200,50 +144,20 @@ indicate :: Integral a => Bool -> a
indicate True = 1
indicate False = 0
--- | Round to a specified number of digits.
-roundTo :: RealFrac a => Int -> a -> a
-roundTo n f = fromIntegral (round $ f * (10 ^ n) :: Int) / (10.0 ^^ n)
-
-discreteUniform :: PrimMonad m => [a] -> Gen (PrimState m) -> a
-discreteUniform [] g = error "discreteUniform: no candidates"
-discreteUniform zs g = do
+symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
+symmetricCategorical [] _ = error "symmetricCategorical: no candidates"
+symmetricCategorical zs g = do
z <- uniform g
- return $ zs !! (truncate $ z * fromIntegral (length zs))
-
-
--- Deprecated -----------------------------------------------------------------
-
-basicBuildTree
- :: (Enum a, Eq a, Floating c, Integral t, Num a, RealFrac c, Hashable c)
- => ([c] -> c) -- ^ Log target
- -> ([c] -> [c]) -- ^ Gradient
- -> [c] -- ^ Position
- -> [c] -- ^ Momentum
- -> c -- ^ Slice variable
- -> c -- ^ Direction (-1, +1)
- -> a -- ^ Depth
- -> c -- ^ Step size
- -> ([c], [c], [c], [c], HashSet ([c], [c]), t)
-basicBuildTree lTarget glTarget = go
- where
- go t r u v 0 e =
- let (t1, r1) = leapfrog glTarget t r 1 (v * e)
- c | u <= auxilliaryTarget lTarget t1 r1 = HashSet.singleton (t1, r1)
- | otherwise = HashSet.empty
- s | u < exp dMax * auxilliaryTarget lTarget t1 r1 = 1
- | otherwise = 0
- in (t1, r1, t1, r1, c, s)
-
- go t r u v j e =
- let (tn, rn, tp, rp, c0, s0) = go t r u v (pred j) e
- (tnn, rnn, tpp, rpp, c1, s1) = if roundTo 6 v == -1
- then go tn rn u v (pred j) e
- else go tp rp u v (pred j) e
-
- s2 = s0 * s1 * indicate ((tpp .- tnn) `innerProduct` rnn >= 0)
- * indicate ((tpp .- tnn) `innerProduct` rpp >= 0)
-
- c2 = c0 `HashSet.union` c1
-
- in (tnn, rnn, tpp, rpp, c2, s2)
+ return $ zs !! truncate (z * fromIntegral (length zs) :: Double)
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+
+-- Testing
+
+f :: Density
+f _ = log $ 1 / 10
+
+g :: Gradient
+g xs = replicate (length xs) 1