commit 94a6b08622908a5a2d76939fcec82200bebb9b29
parent 6ab413d748616ae48e0ddaf659f5e2728c7d3424
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 22 Dec 2016 10:08:19 +1300
Add 'chain' function.
Diffstat:
4 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/CHANGELOG b/CHANGELOG
@@ -1,5 +1,8 @@
# Changelog
+- 0.5.0 (2016-12-22)
+ * Add 'chain' function for working with chains in memory.
+
- 0.4.0 (2016-12-20)
* Generalize base monad requirement to something matching both MonadIO and
PrimState.
diff --git a/declarative.cabal b/declarative.cabal
@@ -1,5 +1,5 @@
name: declarative
-version: 0.4.0
+version: 0.5.0
synopsis: DIY Markov Chains.
homepage: http://github.com/jtobin/declarative
license: MIT
@@ -62,6 +62,7 @@ library
, Numeric.MCMC.Anneal
build-depends:
base >= 4 && < 6
+ , kan-extensions >= 5 && < 6
, mcmc-types >= 1.0.1
, mwc-probability >= 1.0.1
, mighty-metropolis >= 1.0.1
diff --git a/lib/Numeric/MCMC.hs b/lib/Numeric/MCMC.hs
@@ -69,6 +69,7 @@ module Numeric.MCMC (
, frequency
, anneal
, mcmc
+ , chain
-- * Re-exported
, module Data.Sampling.Types
@@ -87,6 +88,8 @@ module Numeric.MCMC (
, RealWorld
) where
+import Control.Monad (replicateM)
+import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
import Control.Monad.Trans.State.Strict (execStateT)
@@ -161,21 +164,43 @@ mcmc
-> Gen (PrimState m)
-> m ()
mcmc n chainPosition transition chainTarget gen = runEffect $
- chain transition Chain {..} gen
+ drive transition Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
chainScore = lTarget chainTarget chainPosition
chainTunables = Nothing
--- A Markov chain driven by an arbitrary transition operator.
+-- | Trace 'n' iterations of a Markov chain and collect them in a list.
+--
+-- >>> results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock
chain
+ :: (MonadIO m, PrimMonad m)
+ => Int
+ -> t a
+ -> Transition m (Chain (t a) b)
+ -> Target (t a)
+ -> Gen (PrimState m)
+ -> m [Chain (t a) b]
+chain n chainPosition transition chainTarget gen = runEffect $
+ drive transition Chain {..} gen
+ >-> collect n
+ where
+ chainScore = lTarget chainTarget chainPosition
+ chainTunables = Nothing
+
+ collect :: Monad m => Int -> Consumer a m [a]
+ collect size = lowerCodensity $
+ replicateM size (lift Pipes.await)
+
+-- A Markov chain driven by an arbitrary transition operator.
+drive
:: PrimMonad m
=> Transition m b
-> b
-> Gen (PrimState m)
-> Producer b m a
-chain transition = loop where
+drive transition = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT transition state) prng)
yield next
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -17,5 +17,7 @@ transition =
(sampleT (slice 2.0) (slice 3.0))
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 [0, 0] transition rosenbrock
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 100 [0, 0] transition rosenbrock gen
+ mcmc 100 [0, 0] transition rosenbrock gen