commit e8f3fd5ec4e36750716ffd4b6c84541a75d948a6
parent 6f61629be2d820bcfd3f7212898866d19508f967
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 22 Dec 2016 09:49:47 +1300
Add 'chain' function.
Diffstat:
5 files changed, 52 insertions(+), 12 deletions(-)
diff --git a/CHANGELOG b/CHANGELOG
@@ -1,5 +1,8 @@
# Changelog
+- 1.3.0 (2016-12-22)
+ * Add 'chain' function for working with chains in memory.
+
- 1.2.0 (2016-12-20)
* Generalize base monad requirement to something matching both MonadIO and
PrimState.
diff --git a/Numeric/MCMC/Hamiltonian.hs b/Numeric/MCMC/Hamiltonian.hs
@@ -24,6 +24,7 @@
module Numeric.MCMC.Hamiltonian (
mcmc
+ , chain
, hamiltonian
-- * Re-exported
@@ -35,6 +36,8 @@ module Numeric.MCMC.Hamiltonian (
) where
import Control.Lens hiding (index)
+import Control.Monad (replicateM)
+import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Primitive (PrimState, PrimMonad)
import Control.Monad.Trans.State.Strict hiding (state)
@@ -49,7 +52,7 @@ import qualified System.Random.MWC.Probability as MWC
-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--
--- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] target
+-- >>> withSystemRandom . asGenIO $ mcmc 10000 0.05 20 [0, 0] target
mcmc
:: ( MonadIO m, PrimMonad m
, Num (IxValue (t Double)), Show (t Double), Traversable t
@@ -63,15 +66,44 @@ mcmc
-> Gen (PrimState m)
-> m ()
mcmc n step leaps chainPosition chainTarget gen = runEffect $
- chain step leaps Chain {..} gen
+ drive step leaps Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
chainScore = lTarget chainTarget chainPosition
chainTunables = Nothing
--- A Markov chain driven by the Metropolis transition operator.
+-- | Trace 'n' iterations of a Markov chain and collect the results in a list.
+--
+-- >>> results <- withSystemRandom . asGenIO $ chain 1000 0.05 20 [0, 0] target
chain
+ :: (PrimMonad m, Traversable f
+ , FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
+ , IxValue (f Double) ~ Double)
+ => Int
+ -> Double
+ -> Int
+ -> f Double
+ -> Target (f Double)
+ -> Gen (PrimState m)
+ -> m [Chain (f Double) b]
+chain n step leaps position target gen = runEffect $
+ drive step leaps origin gen
+ >-> collect n
+ where
+ origin = Chain {
+ chainScore = lTarget target position
+ , chainTunables = Nothing
+ , chainTarget = target
+ , chainPosition = position
+ }
+
+ collect :: Monad m => Int -> Consumer a m [a]
+ collect size = lowerCodensity $
+ replicateM size (lift Pipes.await)
+
+-- Drive a Markov chain.
+drive
:: (Num (IxValue (t Double)), Traversable t
, FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
, PrimMonad m, IxValue (t Double) ~ Double)
@@ -79,8 +111,8 @@ chain
-> Int
-> Chain (t Double) b
-> Gen (PrimState m)
- -> Producer (Chain (t Double) b) m ()
-chain step leaps = loop where
+ -> Producer (Chain (t Double) b) m c
+drive step leaps = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT (hamiltonian step leaps) state) prng)
yield next
diff --git a/README.md b/README.md
@@ -7,8 +7,9 @@
Speedy, gradient-based traversal through parameter space.
-Exports a `mcmc` function that prints a trace to stdout, as well as a
-`hamiltonian` transition operator that can be used more generally.
+Exports a `mcmc` function that prints a trace to stdout, a `chain` function for
+collecting results in memory, and a `hamiltonian` transition operator that can
+be used more generally.
If you don't want to calculate your gradients by hand you can use the handy
[ad](https://hackage.haskell.org/package/ad) library for automatic
diff --git a/hasty-hamiltonian.cabal b/hasty-hamiltonian.cabal
@@ -1,5 +1,5 @@
name: hasty-hamiltonian
-version: 1.2.0
+version: 1.3.0
synopsis: Speedy traversal through parameter space.
homepage: http://github.com/jtobin/hasty-hamiltonian
license: MIT
@@ -8,7 +8,7 @@ author: Jared Tobin
maintainer: jared@jtobin.ca
category: Numeric
build-type: Simple
-cabal-version: >=1.10
+cabal-version: >= 1.10
Description:
Gradient-based traversal through parameter space.
.
@@ -20,8 +20,9 @@ Description:
handy <https://hackage.haskell.org/package/ad ad> library for automatic
differentiation.
.
- Exports a 'mcmc' function that prints a trace to stdout, as well as a
- 'hamiltonian' transition operator that can be used more generally.
+ Exports a 'mcmc' function that prints a trace to stdout, a 'chain' function
+ for collecting results in memory, and a 'hamiltonian' transition operator
+ that can be used more generally.
.
> import Numeric.AD (grad)
> import Numeric.MCMC.Hamiltonian
@@ -50,6 +51,7 @@ library
Numeric.MCMC.Hamiltonian
build-depends:
base >= 4 && < 6
+ , kan-extensions >= 5 && < 6
, mcmc-types >= 1.0.1
, mwc-probability >= 1.0.1
, lens >= 4 && < 5
diff --git a/test/Booth.hs b/test/Booth.hs
@@ -13,5 +13,7 @@ booth :: Target [Double]
booth = Target target (Just gTarget)
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 0.05 20 [0, 0] booth
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 100 0.05 20 [0, 0] booth gen
+ mcmc 100 0.05 20 [0, 0] booth gen