commit 4722d49428710580e8d43d30c2faf5fac69ce0b2
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 7 Oct 2015 20:58:23 +1300
Initial commit.
Diffstat:
9 files changed, 341 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -0,0 +1,5 @@
+*swp
+.stack-work
+debug
+*o
+*hi
diff --git a/LICENSE b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2015 Jared Tobin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/Numeric/MCMC/Slice.hs b/Numeric/MCMC/Slice.hs
@@ -0,0 +1,179 @@
+{-# OPTIONS_GHC -Wall #-}
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE TypeFamilies #-}
+
+-- |
+-- Module: Numeric.MCMC.Slice
+-- Copyright: (c) 2015 Jared Tobin
+-- License: MIT
+--
+-- Maintainer: Jared Tobin <jared@jtobin.ca>
+-- Stability: unstable
+-- Portability: ghc
+--
+-- This implementation performs slice sampling by first finding a bracket about
+-- a mode (using a simple doubling heuristic), and then doing rejection
+-- sampling along it. The result is a reliable and computationally inexpensive
+-- sampling routine.
+--
+-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
+-- while the `slice` transition can be used for more flexible purposes, such as
+-- working with samples in memory.
+--
+-- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the
+-- definitive reference of the algorithm.
+
+module Numeric.MCMC.Slice (
+ mcmc
+ , slice
+
+ -- * Re-exported
+ , MWC.create
+ , MWC.createSystemRandom
+ , MWC.withSystemRandom
+ , MWC.asGenIO
+ ) where
+
+import Control.Monad.Trans.State.Strict (put, get, execStateT)
+import Control.Monad.Primitive (PrimMonad, PrimState)
+import Control.Lens hiding (index)
+import GHC.Prim (RealWorld)
+import Data.Maybe (fromMaybe)
+import Data.Sampling.Types
+import Pipes hiding (next)
+import qualified Pipes.Prelude as Pipes
+import System.Random.MWC.Probability (Prob, Gen)
+import qualified System.Random.MWC.Probability as MWC
+
+-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
+--
+-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
+-- -3.854097694213343e-2,0.16688601288358407
+-- -9.310661272172682e-2,0.2562387977415508
+-- -0.48500122500661846,0.46245400501919076
+mcmc
+ :: (Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
+ IxValue (t a) ~ Double)
+ => Int
+ -> Double
+ -> t a
+ -> (t a -> Double)
+ -> Gen RealWorld
+ -> IO ()
+mcmc n radial chainPosition target gen = runEffect $
+ chain radial Chain {..} gen
+ >-> Pipes.take n
+ >-> Pipes.mapM_ print
+ where
+ chainScore = lTarget chainTarget chainPosition
+ chainTunables = Nothing
+ chainTarget = Target target Nothing
+
+-- A Markov chain driven by the slice transition operator.
+chain
+ :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
+ IxValue (t a) ~ Double)
+ => Double
+ -> Chain (t a) b
+ -> Gen (PrimState m)
+ -> Producer (Chain (t a) b) m ()
+chain radial = loop where
+ loop state prng = do
+ next <- lift (MWC.sample (execStateT (slice radial) state) prng)
+ yield next
+ loop next prng
+
+-- | A slice sampling transition operator.
+slice
+ :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
+ IxValue (t a) ~ Double)
+ => Double
+ -> Transition m (Chain (t a) b)
+slice step = do
+ Chain _ _ position _ <- get
+ ifor_ position $ \index _ -> do
+ Chain {..} <- get
+ let bounds = (0, exp (lTarget chainTarget chainPosition))
+ height <- lift (fmap log (MWC.uniformR bounds))
+
+ let bracket =
+ findBracket (lTarget chainTarget) index step height chainPosition
+
+ perturbed <- lift $
+ rejection (lTarget chainTarget) index bracket height chainPosition
+
+ let perturbedScore = lTarget chainTarget perturbed
+ put (Chain chainTarget perturbedScore perturbed chainTunables)
+
+-- Find a bracket by expanding its bounds through powers of 2.
+findBracket
+ :: (Ord a, Ixed s, IxValue s ~ Double)
+ => (s -> a)
+ -> Index s
+ -> Double
+ -> a
+ -> s
+ -> (IxValue s, IxValue s)
+findBracket target index step height xs = go step xs xs where
+ err = error "findBracket: invalid index -- please report this as a bug!"
+ go !e !bl !br
+ | target bl < height && target br < height =
+ let l = fromMaybe err (bl ^? ix index)
+ r = fromMaybe err (br ^? ix index)
+ in (l, r)
+ | target bl < height && target br >= height =
+ let br0 = expandBracketRight index step br
+ in go (2 * e) bl br0
+ | target bl >= height && target br < height =
+ let bl0 = expandBracketLeft index step bl
+ in go (2 * e) bl0 br
+ | otherwise =
+ let bl0 = expandBracketLeft index step bl
+ br0 = expandBracketRight index step br
+ in go (2 * e) bl0 br0
+
+expandBracketLeft
+ :: (Ixed s, IxValue s ~ Double)
+ => Index s
+ -> Double
+ -> s
+ -> s
+expandBracketLeft = expandBracketBy (-)
+
+expandBracketRight
+ :: (Ixed s, IxValue s ~ Double)
+ => Index s
+ -> Double
+ -> s
+ -> s
+expandBracketRight = expandBracketBy (+)
+
+expandBracketBy
+ :: Ixed s
+ => (IxValue s -> Double -> IxValue s)
+ -> Index s
+ -> Double
+ -> s
+ -> s
+expandBracketBy f index step xs = xs & ix index %~ (`f` step )
+
+-- Perform rejection sampling within the supplied bracket.
+rejection
+ :: (Ord a, PrimMonad m, Ixed b, IxValue b ~ Double)
+ => (b -> a)
+ -> Index b
+ -> (Double, Double)
+ -> a
+ -> b
+ -> Prob m b
+rejection target dimension bracket height = go where
+ go zs = do
+ u <- MWC.uniformR bracket
+ let updated = zs & ix dimension .~ u
+ if target updated < height
+ then go updated
+ else return updated
+
diff --git a/README.md b/README.md
@@ -0,0 +1,25 @@
+# speedy-slice [![Build Status](https://secure.travis-ci.org/jtobin/speedy-slice.png)](http://travis-ci.org/jtobin/speedy-slice)
+
+Speedy slice sampling, as per [Neal, 2003](http://people.ee.duke.edu/~lcarin/slice.pdf).
+
+This implementation of the slice sampling algorithm uses `lens` as a means to
+operate over generic indexed traversable functors, so you can expect it to
+work if your target function takes a list, vector, map, sequence, etc. as its
+argument.
+
+Exports a `mcmc` function that prints a trace to stdout, as well as a
+`slice` transition operator that can be used more generally.
+
+ import Numeric.MCMC.Slice
+ import Data.Sequence (Seq, index, fromList)
+
+ bnn :: Seq Double -> Double
+ bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+ x0 = index xs 0
+ x1 = index xs 1
+
+ main :: IO ()
+ main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
+![trace](https://dl.dropboxusercontent.com/spa/u0s6617yxinm2ca/zp-9gl6z.png)
+
diff --git a/Setup.hs b/Setup.hs
@@ -0,0 +1,2 @@
+import Distribution.Simple
+main = defaultMain
diff --git a/speedy-slice.cabal b/speedy-slice.cabal
@@ -0,0 +1,76 @@
+name: speedy-slice
+version: 0.1.0.0
+synopsis: Speedy slice sampling.
+homepage: http://github.com/jtobin/speedy-slice
+license: MIT
+license-file: LICENSE
+author: Jared Tobin
+maintainer: jared@jtobin.ca
+category: Math
+build-type: Simple
+cabal-version: >=1.10
+description:
+ Speedy slice sampling.
+ .
+ This implementation of the slice sampling algorithm uses 'lens' as a means to
+ operate over generic indexed traversable functors, so you can expect it to
+ work if your target function takes a list, vector, map, sequence, etc. as its
+ argument.
+ .
+ Exports a 'mcmc' function that prints a trace to stdout, as well as a
+ 'slice' transition operator that can be used more generally.
+ .
+ > import Numeric.MCMC.Slice
+ > import Data.Sequence (Seq, index, fromList)
+ >
+ > bnn :: Seq Double -> Double
+ > bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+ > x0 = index xs 0
+ > x1 = index xs 1
+ >
+ > main :: IO ()
+ > main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
+Source-repository head
+ Type: git
+ Location: http://github.com/jtobin/speedy-slice.git
+
+library
+ default-language: Haskell2010
+ exposed-modules:
+ Numeric.MCMC.Slice
+ build-depends:
+ base < 5
+ , ghc-prim
+ , lens
+ , primitive
+ , mcmc-types >= 1.0.1
+ , mwc-probability >= 1.0.1
+ , pipes
+ , transformers
+
+Test-suite rosenbrock
+ type: exitcode-stdio-1.0
+ hs-source-dirs: test
+ main-is: Rosenbrock.hs
+ default-language: Haskell2010
+ ghc-options:
+ -rtsopts
+ build-depends:
+ base < 5
+ , mwc-probability >= 1.0.1
+ , speedy-slice
+
+Test-suite bnn
+ type: exitcode-stdio-1.0
+ hs-source-dirs: test
+ main-is: BNN.hs
+ default-language: Haskell2010
+ ghc-options:
+ -rtsopts
+ build-depends:
+ base < 5
+ , containers
+ , mwc-probability >= 1.0.1
+ , speedy-slice
+
diff --git a/stack.yaml b/stack.yaml
@@ -0,0 +1,7 @@
+flags: {}
+packages:
+ - '.'
+extra-deps:
+ - mwc-probability-1.0.1
+ - mcmc-types-1.0.1
+resolver: lts-3.3
diff --git a/test/BNN.hs b/test/BNN.hs
@@ -0,0 +1,15 @@
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+
+module Main where
+
+import Numeric.MCMC.Slice
+import Data.Sequence (Seq, fromList, index)
+
+bnn :: Seq Double -> Double
+bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+ x0 = index xs 0
+ x1 = index xs 1
+
+main :: IO ()
+main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -0,0 +1,13 @@
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+
+module Main where
+
+import Numeric.MCMC.Slice
+
+rosenbrock :: [Double] -> Double
+rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+
+main :: IO ()
+main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] rosenbrock
+
+