mighty-metropolis

The classic Metropolis algorithm.
Log | Files | Refs | README | LICENSE

commit dbcb22fa16814592f8d45f8754174724047bcefc
parent 7304eeaf6f3f44f8902e8e3b156402fa830f0897
Author: Jared Tobin <jared@jtobin.ca>
Date:   Thu, 22 Dec 2016 09:14:01 +1300

Add 'chain' for working with results in memory.

Resolves #2.

Diffstat:
MNumeric/MCMC/Metropolis.hs | 52++++++++++++++++++++++++++++++++++++++++++++--------
MREADME.md | 5+++--
Mmighty-metropolis.cabal | 8+++++---
Mstack-travis.yaml | 2+-
Mstack.yaml | 2+-
Mtest/BNN.hs | 4+++-
Mtest/Rosenbrock.hs | 4+++-
7 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs @@ -21,6 +21,7 @@ module Numeric.MCMC.Metropolis ( mcmc + , chain , metropolis -- * Re-exported @@ -31,7 +32,8 @@ module Numeric.MCMC.Metropolis ( , MWC.asGenIO ) where -import Control.Monad (when) +import Control.Monad (when, replicateM) +import Control.Monad.Codensity (lowerCodensity) import Control.Monad.Primitive (PrimMonad, PrimState) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict (execStateT, get, put) @@ -40,8 +42,8 @@ import Data.Sampling.Types (Target(..), Chain(..), Transition) #if __GLASGOW_HASKELL__ < 710 import Data.Traversable (Traversable, traverse) #endif -import Pipes (Producer, yield, (>->), runEffect) -import qualified Pipes.Prelude as Pipes (mapM_, take) +import Pipes (Producer, Consumer, yield, (>->), runEffect, await) +import qualified Pipes.Prelude as Pipes (mapM_, take, map) import System.Random.MWC.Probability (Gen, Prob) import qualified System.Random.MWC.Probability as MWC @@ -69,19 +71,53 @@ metropolis radial = do accept <- lift (MWC.bernoulli acceptProb) when accept (put (Chain chainTarget proposalScore proposal chainTunables)) --- A Markov chain driven by the Metropolis transition operator. -chain +-- Drive a Markov chain via the Metropolis transition operator. +drive :: (Traversable f, PrimMonad m) => Double -> Chain (f Double) b -> Gen (PrimState m) - -> Producer (Chain (f Double) b) m () -chain radial = loop where + -> Producer (Chain (f Double) b) m c +drive radial = loop where loop state prng = do next <- lift (MWC.sample (execStateT (metropolis radial) state) prng) yield next loop next prng +-- | Trace 'n' iterations of a Markov chain and collect the results in a list. +-- +-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) +-- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock +-- >>> mapM_ print results +-- [0.0,0.0] +-- [1.4754117657794871e-2,0.5033208261760778] +-- [3.8379699517007895e-3,0.24627131099479127] +chain + :: (PrimMonad m, Traversable f) + => Int + -> Double + -> f Double + -> (f Double -> Double) + -> Gen (PrimState m) + -> m [f Double] +chain n radial position target gen = runEffect $ + drive radial origin gen + >-> Pipes.map chainPosition + >-> collect n + where + ctarget = Target target Nothing + + origin = Chain { + chainScore = lTarget ctarget position + , chainTunables = Nothing + , chainTarget = ctarget + , chainPosition = position + } + + collect :: Monad m => Int -> Consumer a m [a] + collect size = lowerCodensity $ + replicateM size (lift Pipes.await) + -- | Trace 'n' iterations of a Markov chain and stream them to stdout. -- -- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) @@ -98,7 +134,7 @@ mcmc -> Gen (PrimState m) -> m () mcmc n radial chainPosition target gen = runEffect $ - chain radial Chain {..} gen + drive radial Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ (liftIO . print) where diff --git a/README.md b/README.md @@ -7,8 +7,9 @@ The classic Metropolis algorithm. Wander around parameter space according to a simple spherical Gaussian distribution. -Exports a `mcmc` function that prints a trace to stdout, as well as a -`metropolis` transition operator that can be used more generally. +Exports a `mcmc` function that prints a trace to stdout, a `chain` function for +collecting results in memory, and a `metropolis` transition operator that can +be used more generally. See the *test* directory for example usage. diff --git a/mighty-metropolis.cabal b/mighty-metropolis.cabal @@ -1,5 +1,5 @@ name: mighty-metropolis -version: 1.1.0 +version: 1.2.0 synopsis: The Metropolis algorithm. homepage: http://github.com/jtobin/mighty-metropolis license: MIT @@ -15,8 +15,9 @@ description: Wander around parameter space according to a simple spherical Gaussian distribution. . - Exports a 'mcmc' function that prints a trace to stdout, as well as a - 'metropolis' transition operator that can be used more generally. + Exports a 'mcmc' function that prints a trace to stdout, a 'chain' function + for collecting results in-memory, and a 'metropolis' transition operator that + can be used more generally. . > import Numeric.MCMC.Metropolis > @@ -38,6 +39,7 @@ library Numeric.MCMC.Metropolis build-depends: base >= 4 && < 6 + , kan-extensions >= 5 && < 6 , pipes >= 4 && < 5 , primitive >= 0.6 && < 1.0 , mcmc-types >= 1.0.1 diff --git a/stack-travis.yaml b/stack-travis.yaml @@ -2,7 +2,7 @@ flags: {} packages: - '.' extra-deps: [] -resolver: lts-7.11 +resolver: lts-7.14 compiler: ghc-8.0.1 system-ghc: false install-ghc: true diff --git a/stack.yaml b/stack.yaml @@ -2,4 +2,4 @@ flags: {} packages: - '.' extra-deps: [] -resolver: lts-7.11 +resolver: lts-7.14 diff --git a/test/BNN.hs b/test/BNN.hs @@ -11,5 +11,7 @@ bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where x1 = index xs 1 main :: IO () -main = withSystemRandom . asGenIO $ mcmc 100 1 (fromList [0, 0]) bnn +main = withSystemRandom . asGenIO $ \gen -> do + _ <- chain 50 1 (fromList [0, 0]) bnn gen + mcmc 50 1 (fromList [0, 0]) bnn gen diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs @@ -8,5 +8,7 @@ rosenbrock :: [Double] -> Double rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) main :: IO () -main = withSystemRandom . asGenIO $ mcmc 100 1 [0, 0] rosenbrock +main = withSystemRandom . asGenIO $ \gen -> do + _ <- chain 50 1 [0, 0] rosenbrock gen + mcmc 50 1 [0, 0] rosenbrock gen