commit a697801d23ea539e52a1acb3c58ae9552a193692
parent d078d0f36cca92328ed685ccbc90700ee1e3ce65
Author: Jared Tobin <jared@jtobin.io>
Date: Thu, 21 May 2020 19:07:15 +0400
Formatting nits.
Diffstat:
1 file changed, 19 insertions(+), 18 deletions(-)
diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs
@@ -71,7 +71,9 @@ metropolis radial tunable = do
acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
accept <- lift (MWC.bernoulli acceptProb)
- when accept (put (Chain chainTarget proposalScore proposal (tunable <*> Just proposal)))
+ when accept $ do
+ let tuned = tunable <*> Just proposal
+ put (Chain chainTarget proposalScore proposal tuned)
-- Drive a Markov chain via the Metropolis transition operator.
drive
@@ -83,12 +85,13 @@ drive
-> Producer (Chain (f Double) b) m c
drive radial tunable = loop where
loop state prng = do
- next <- lift (MWC.sample (execStateT (metropolis radial tunable) state) prng)
+ let rvar = execStateT (metropolis radial tunable) state
+ next <- lift (MWC.sample rvar prng)
yield next
loop next prng
-- | Return a list of @Chain@ values potentially with tunable values computed
--- from each position.
+-- from each position.
chain' ::
(PrimMonad m, Traversable f)
=> Int
@@ -99,16 +102,15 @@ chain' ::
-> Gen (PrimState m)
-> m [Chain (f Double) b]
chain' n radial position target tunable gen =
- runEffect $ drive radial tunable origin gen >-> collect n
+ runEffect $ drive radial tunable origin gen >-> collect n
where
ctarget = Target target Nothing
- origin =
- Chain
- { chainScore = lTarget ctarget position
- , chainTunables = tunable <*> Just position
- , chainTarget = ctarget
- , chainPosition = position
- }
+ origin = Chain
+ { chainScore = lTarget ctarget position
+ , chainTunables = tunable <*> Just position
+ , chainTarget = ctarget
+ , chainPosition = position
+ }
collect :: Monad m => Int -> Consumer a m [a]
collect size = lowerCodensity $ replicateM size (lift Pipes.await)
@@ -122,14 +124,13 @@ chain' n radial position target tunable gen =
-- 3.8379699517007895e-3,0.24627131099479127
chain
:: (PrimMonad m, Traversable f)
- => Int -- ^ Number of iterations
- -> Double -- ^ Step standard deviation
- -> f Double -- ^ Starting position
- -> (f Double -> Double) -- ^ The log-density (up to additive constant)
- -> Gen (PrimState m)
+ => Int -- ^ Number of iterations
+ -> Double -- ^ Step standard deviation
+ -> f Double -- ^ Starting position
+ -> (f Double -> Double) -- ^ Log-density (to additive constant)
+ -> Gen (PrimState m) -- ^ PRNG
-> m [Chain (f Double) b]
-chain n radial position target =
- chain' n radial position target Nothing
+chain n radial position target = chain' n radial position target Nothing
-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--