declarative

DIY Markov Chains
Log | Files | Refs | README | LICENSE

commit 94a6b08622908a5a2d76939fcec82200bebb9b29
parent 6ab413d748616ae48e0ddaf659f5e2728c7d3424
Author: Jared Tobin <jared@jtobin.ca>
Date:   Thu, 22 Dec 2016 10:08:19 +1300

Add 'chain' function.

Diffstat:
MCHANGELOG | 3+++
Mdeclarative.cabal | 3++-
Mlib/Numeric/MCMC.hs | 31++++++++++++++++++++++++++++---
Mtest/Rosenbrock.hs | 4+++-
4 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG b/CHANGELOG @@ -1,5 +1,8 @@ # Changelog +- 0.5.0 (2016-12-22) + * Add 'chain' function for working with chains in memory. + - 0.4.0 (2016-12-20) * Generalize base monad requirement to something matching both MonadIO and PrimState. diff --git a/declarative.cabal b/declarative.cabal @@ -1,5 +1,5 @@ name: declarative -version: 0.4.0 +version: 0.5.0 synopsis: DIY Markov Chains. homepage: http://github.com/jtobin/declarative license: MIT @@ -62,6 +62,7 @@ library , Numeric.MCMC.Anneal build-depends: base >= 4 && < 6 + , kan-extensions >= 5 && < 6 , mcmc-types >= 1.0.1 , mwc-probability >= 1.0.1 , mighty-metropolis >= 1.0.1 diff --git a/lib/Numeric/MCMC.hs b/lib/Numeric/MCMC.hs @@ -69,6 +69,7 @@ module Numeric.MCMC ( , frequency , anneal , mcmc + , chain -- * Re-exported , module Data.Sampling.Types @@ -87,6 +88,8 @@ module Numeric.MCMC ( , RealWorld ) where +import Control.Monad (replicateM) +import Control.Monad.Codensity (lowerCodensity) import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld) import Control.Monad.Trans.State.Strict (execStateT) @@ -161,21 +164,43 @@ mcmc -> Gen (PrimState m) -> m () mcmc n chainPosition transition chainTarget gen = runEffect $ - chain transition Chain {..} gen + drive transition Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ (liftIO . print) where chainScore = lTarget chainTarget chainPosition chainTunables = Nothing --- A Markov chain driven by an arbitrary transition operator. +-- | Trace 'n' iterations of a Markov chain and collect them in a list. +-- +-- >>> results <- withSystemRandom . asGenIO $ chain 3 [0, 0] (metropolis 0.5) rosenbrock chain + :: (MonadIO m, PrimMonad m) + => Int + -> t a + -> Transition m (Chain (t a) b) + -> Target (t a) + -> Gen (PrimState m) + -> m [Chain (t a) b] +chain n chainPosition transition chainTarget gen = runEffect $ + drive transition Chain {..} gen + >-> collect n + where + chainScore = lTarget chainTarget chainPosition + chainTunables = Nothing + + collect :: Monad m => Int -> Consumer a m [a] + collect size = lowerCodensity $ + replicateM size (lift Pipes.await) + +-- A Markov chain driven by an arbitrary transition operator. +drive :: PrimMonad m => Transition m b -> b -> Gen (PrimState m) -> Producer b m a -chain transition = loop where +drive transition = loop where loop state prng = do next <- lift (MWC.sample (execStateT transition state) prng) yield next diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs @@ -17,5 +17,7 @@ transition = (sampleT (slice 2.0) (slice 3.0)) main :: IO () -main = withSystemRandom . asGenIO $ mcmc 100 [0, 0] transition rosenbrock +main = withSystemRandom . asGenIO $ \gen -> do + _ <- chain 100 [0, 0] transition rosenbrock gen + mcmc 100 [0, 0] transition rosenbrock gen