commit dbcb22fa16814592f8d45f8754174724047bcefc
parent 7304eeaf6f3f44f8902e8e3b156402fa830f0897
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 22 Dec 2016 09:14:01 +1300
Add 'chain' for working with results in memory.
Resolves #2.
Diffstat:
7 files changed, 60 insertions(+), 17 deletions(-)
diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs
@@ -21,6 +21,7 @@
module Numeric.MCMC.Metropolis (
mcmc
+ , chain
, metropolis
-- * Re-exported
@@ -31,7 +32,8 @@ module Numeric.MCMC.Metropolis (
, MWC.asGenIO
) where
-import Control.Monad (when)
+import Control.Monad (when, replicateM)
+import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (execStateT, get, put)
@@ -40,8 +42,8 @@ import Data.Sampling.Types (Target(..), Chain(..), Transition)
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (Traversable, traverse)
#endif
-import Pipes (Producer, yield, (>->), runEffect)
-import qualified Pipes.Prelude as Pipes (mapM_, take)
+import Pipes (Producer, Consumer, yield, (>->), runEffect, await)
+import qualified Pipes.Prelude as Pipes (mapM_, take, map)
import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
@@ -69,19 +71,53 @@ metropolis radial = do
accept <- lift (MWC.bernoulli acceptProb)
when accept (put (Chain chainTarget proposalScore proposal chainTunables))
--- A Markov chain driven by the Metropolis transition operator.
-chain
+-- Drive a Markov chain via the Metropolis transition operator.
+drive
:: (Traversable f, PrimMonad m)
=> Double
-> Chain (f Double) b
-> Gen (PrimState m)
- -> Producer (Chain (f Double) b) m ()
-chain radial = loop where
+ -> Producer (Chain (f Double) b) m c
+drive radial = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT (metropolis radial) state) prng)
yield next
loop next prng
+-- | Trace 'n' iterations of a Markov chain and collect the results in a list.
+--
+-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+-- >>> results <- withSystemRandom . asGenIO $ chain 3 1 [0, 0] rosenbrock
+-- >>> mapM_ print results
+-- [0.0,0.0]
+-- [1.4754117657794871e-2,0.5033208261760778]
+-- [3.8379699517007895e-3,0.24627131099479127]
+chain
+ :: (PrimMonad m, Traversable f)
+ => Int
+ -> Double
+ -> f Double
+ -> (f Double -> Double)
+ -> Gen (PrimState m)
+ -> m [f Double]
+chain n radial position target gen = runEffect $
+ drive radial origin gen
+ >-> Pipes.map chainPosition
+ >-> collect n
+ where
+ ctarget = Target target Nothing
+
+ origin = Chain {
+ chainScore = lTarget ctarget position
+ , chainTunables = Nothing
+ , chainTarget = ctarget
+ , chainPosition = position
+ }
+
+ collect :: Monad m => Int -> Consumer a m [a]
+ collect size = lowerCodensity $
+ replicateM size (lift Pipes.await)
+
-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--
-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
@@ -98,7 +134,7 @@ mcmc
-> Gen (PrimState m)
-> m ()
mcmc n radial chainPosition target gen = runEffect $
- chain radial Chain {..} gen
+ drive radial Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
diff --git a/README.md b/README.md
@@ -7,8 +7,9 @@
The classic Metropolis algorithm. Wander around parameter space according to a
simple spherical Gaussian distribution.
-Exports a `mcmc` function that prints a trace to stdout, as well as a
-`metropolis` transition operator that can be used more generally.
+Exports a `mcmc` function that prints a trace to stdout, a `chain` function for
+collecting results in memory, and a `metropolis` transition operator that can
+be used more generally.
See the *test* directory for example usage.
diff --git a/mighty-metropolis.cabal b/mighty-metropolis.cabal
@@ -1,5 +1,5 @@
name: mighty-metropolis
-version: 1.1.0
+version: 1.2.0
synopsis: The Metropolis algorithm.
homepage: http://github.com/jtobin/mighty-metropolis
license: MIT
@@ -15,8 +15,9 @@ description:
Wander around parameter space according to a simple spherical Gaussian
distribution.
.
- Exports a 'mcmc' function that prints a trace to stdout, as well as a
- 'metropolis' transition operator that can be used more generally.
+ Exports a 'mcmc' function that prints a trace to stdout, a 'chain' function
+ for collecting results in-memory, and a 'metropolis' transition operator that
+ can be used more generally.
.
> import Numeric.MCMC.Metropolis
>
@@ -38,6 +39,7 @@ library
Numeric.MCMC.Metropolis
build-depends:
base >= 4 && < 6
+ , kan-extensions >= 5 && < 6
, pipes >= 4 && < 5
, primitive >= 0.6 && < 1.0
, mcmc-types >= 1.0.1
diff --git a/stack-travis.yaml b/stack-travis.yaml
@@ -2,7 +2,7 @@ flags: {}
packages:
- '.'
extra-deps: []
-resolver: lts-7.11
+resolver: lts-7.14
compiler: ghc-8.0.1
system-ghc: false
install-ghc: true
diff --git a/stack.yaml b/stack.yaml
@@ -2,4 +2,4 @@ flags: {}
packages:
- '.'
extra-deps: []
-resolver: lts-7.11
+resolver: lts-7.14
diff --git a/test/BNN.hs b/test/BNN.hs
@@ -11,5 +11,7 @@ bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
x1 = index xs 1
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 1 (fromList [0, 0]) bnn
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 50 1 (fromList [0, 0]) bnn gen
+ mcmc 50 1 (fromList [0, 0]) bnn gen
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -8,5 +8,7 @@ rosenbrock :: [Double] -> Double
rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 1 [0, 0] rosenbrock
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 50 1 [0, 0] rosenbrock gen
+ mcmc 50 1 [0, 0] rosenbrock gen