commit 1acf97eecd44c3b09060f5d9d84ac8a119aeee06
parent f36a49b36964484edf7dcb25ab1b6b116219fe75
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 22 Dec 2016 10:01:54 +1300
Add 'chain' function.
Diffstat:
6 files changed, 53 insertions(+), 12 deletions(-)
diff --git a/CHANGELOG b/CHANGELOG
@@ -1,5 +1,8 @@
# Changelog
+- 0.3.0 (2016-12-22)
+ * Add 'chain' function for working with chains in memory.
+
- 0.2.0 (2016-12-20)
* Generalize base monad requirement to something matching both MonadIO and
PrimState.
diff --git a/Numeric/MCMC/Slice.hs b/Numeric/MCMC/Slice.hs
@@ -28,6 +28,7 @@
module Numeric.MCMC.Slice (
mcmc
+ , chain
, slice
-- * Re-exported
@@ -37,6 +38,8 @@ module Numeric.MCMC.Slice (
, MWC.asGenIO
) where
+import Control.Monad (replicateM)
+import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.Trans.State.Strict (put, get, execStateT)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.IO.Class (MonadIO, liftIO)
@@ -66,7 +69,7 @@ mcmc
-> Gen (PrimState m)
-> m ()
mcmc n radial chainPosition target gen = runEffect $
- chain radial Chain {..} gen
+ drive radial Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
@@ -74,15 +77,44 @@ mcmc n radial chainPosition target gen = runEffect $
chainTunables = Nothing
chainTarget = Target target Nothing
--- A Markov chain driven by the slice transition operator.
+-- | Trace 'n' iterations of a Markov chain and collect them in a list.
+--
+-- >>> results <- withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
chain
+ :: (PrimMonad m, FoldableWithIndex (Index (f a)) f, Ixed (f a)
+ , Variate (IxValue (f a)), Num (IxValue (f a)))
+ => Int
+ -> IxValue (f a)
+ -> f a
+ -> (f a -> Double)
+ -> Gen (PrimState m)
+ -> m [Chain (f a) b]
+chain n radial position target gen = runEffect $
+ drive radial origin gen
+ >-> collect n
+ where
+ ctarget = Target target Nothing
+
+ origin = Chain {
+ chainScore = lTarget ctarget position
+ , chainTunables = Nothing
+ , chainTarget = ctarget
+ , chainPosition = position
+ }
+
+ collect :: Monad m => Int -> Consumer a m [a]
+ collect size = lowerCodensity $
+ replicateM size (lift Pipes.await)
+
+-- A Markov chain driven by the slice transition operator.
+drive
:: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
Num (IxValue (t a)), Variate (IxValue (t a)))
=> IxValue (t a)
-> Chain (t a) b
-> Gen (PrimState m)
- -> Producer (Chain (t a) b) m ()
-chain radial = loop where
+ -> Producer (Chain (t a) b) m c
+drive radial = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT (slice radial) state) prng)
yield next
diff --git a/README.md b/README.md
@@ -14,8 +14,9 @@ argument.
Additionally you can sample over anything that's an instance of both `Num` and
`Variate`, which is useful in the case of discrete parameters.
-Exports a `mcmc` function that prints a trace to stdout, as well as a
-`slice` transition operator that can be used more generally.
+Exports a `mcmc` function that prints a trace to stdout, a `chain` function for
+working with results in memory, and a `slice` transition operator that can be
+used more generally.
import Numeric.MCMC.Slice
import Data.Sequence (Seq, index, fromList)
diff --git a/speedy-slice.cabal b/speedy-slice.cabal
@@ -1,5 +1,5 @@
name: speedy-slice
-version: 0.2.0
+version: 0.3.0
synopsis: Speedy slice sampling.
homepage: http://github.com/jtobin/speedy-slice
license: MIT
@@ -20,8 +20,9 @@ description:
Additionally you can sample over anything that's an instance of both 'Num' and
'Variate', which is useful in the case of discrete parameters.
.
- Exports a 'mcmc' function that prints a trace to stdout, as well as a
- 'slice' transition operator that can be used more generally.
+ Exports a 'mcmc' function that prints a trace to stdout, a 'chain' function
+ for collecting results in memory, and a 'slice' transition operator that can
+ be used more generally.
.
> import Numeric.MCMC.Slice
> import Data.Sequence (Seq, index, fromList)
@@ -44,6 +45,7 @@ library
Numeric.MCMC.Slice
build-depends:
base >= 4 && < 6
+ , kan-extensions >= 5 && < 6
, lens >= 4 && < 5
, primitive >= 0.6 && < 1.0
, mcmc-types >= 1.0.1
diff --git a/test/BNN.hs b/test/BNN.hs
@@ -11,5 +11,7 @@ bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
x1 = index xs 1
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 1 (fromList [0, 0]) bnn
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 50 1 (fromList [0, 0]) bnn gen
+ mcmc 50 1 (fromList [0, 0]) bnn gen
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -8,6 +8,7 @@ rosenbrock :: [Double] -> Double
rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
main :: IO ()
-main = withSystemRandom . asGenIO $ mcmc 100 1 [0, 0] rosenbrock
-
+main = withSystemRandom . asGenIO $ \gen -> do
+ _ <- chain 50 1 [0, 0] rosenbrock gen
+ mcmc 50 1 [0, 0] rosenbrock gen