commit 4722d49428710580e8d43d30c2faf5fac69ce0b2
Author: Jared Tobin <jared@jtobin.ca>
Date:   Wed,  7 Oct 2015 20:58:23 +1300
Initial commit.
Diffstat:
9 files changed, 341 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -0,0 +1,5 @@
+*swp
+.stack-work
+debug
+*o
+*hi
diff --git a/LICENSE b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2015 Jared Tobin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/Numeric/MCMC/Slice.hs b/Numeric/MCMC/Slice.hs
@@ -0,0 +1,179 @@
+{-# OPTIONS_GHC -Wall #-}
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE TypeFamilies #-}
+
+-- |
+-- Module: Numeric.MCMC.Slice
+-- Copyright: (c) 2015 Jared Tobin
+-- License: MIT
+--
+-- Maintainer: Jared Tobin <jared@jtobin.ca>
+-- Stability: unstable
+-- Portability: ghc
+--
+-- This implementation performs slice sampling by first finding a bracket about
+-- a mode (using a simple doubling heuristic), and then doing rejection
+-- sampling along it.  The result is a reliable and computationally inexpensive
+-- sampling routine.
+--
+-- The 'mcmc' function streams a trace to stdout to be processed elsewhere,
+-- while the `slice` transition can be used for more flexible purposes, such as
+-- working with samples in memory.
+--
+-- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the
+-- definitive reference of the algorithm.
+
+module Numeric.MCMC.Slice (
+    mcmc
+  , slice
+
+  -- * Re-exported
+  , MWC.create
+  , MWC.createSystemRandom
+  , MWC.withSystemRandom
+  , MWC.asGenIO
+  ) where
+
+import Control.Monad.Trans.State.Strict (put, get, execStateT)
+import Control.Monad.Primitive (PrimMonad, PrimState)
+import Control.Lens hiding (index)
+import GHC.Prim (RealWorld)
+import Data.Maybe (fromMaybe)
+import Data.Sampling.Types
+import Pipes hiding (next)
+import qualified Pipes.Prelude as Pipes
+import System.Random.MWC.Probability (Prob, Gen)
+import qualified System.Random.MWC.Probability as MWC
+
+-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
+--
+-- >>> let rosenbrock [x0, x1] = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock
+-- -3.854097694213343e-2,0.16688601288358407
+-- -9.310661272172682e-2,0.2562387977415508
+-- -0.48500122500661846,0.46245400501919076
+mcmc
+  :: (Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a),
+     IxValue (t a) ~ Double)
+  => Int
+  -> Double
+  -> t a
+  -> (t a -> Double)
+  -> Gen RealWorld
+  -> IO ()
+mcmc n radial chainPosition target gen = runEffect $
+        chain radial Chain {..} gen
+    >-> Pipes.take n
+    >-> Pipes.mapM_ print
+  where
+    chainScore    = lTarget chainTarget chainPosition
+    chainTunables = Nothing
+    chainTarget   = Target target Nothing
+
+-- A Markov chain driven by the slice transition operator.
+chain
+  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
+     IxValue (t a) ~ Double)
+  => Double
+  -> Chain (t a) b
+  -> Gen (PrimState m)
+  -> Producer (Chain (t a) b) m ()
+chain radial = loop where
+  loop state prng = do
+    next <- lift (MWC.sample (execStateT (slice radial) state) prng)
+    yield next
+    loop next prng
+
+-- | A slice sampling transition operator.
+slice
+  :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a),
+      IxValue (t a) ~ Double)
+  => Double
+  -> Transition m (Chain (t a) b)
+slice step = do
+  Chain _ _ position _ <- get
+  ifor_ position $ \index _ -> do
+    Chain {..} <- get
+    let bounds = (0, exp (lTarget chainTarget chainPosition))
+    height    <- lift (fmap log (MWC.uniformR bounds))
+
+    let bracket =
+          findBracket (lTarget chainTarget) index step height chainPosition
+
+    perturbed <- lift $
+      rejection (lTarget chainTarget) index bracket height chainPosition
+
+    let perturbedScore = lTarget chainTarget perturbed
+    put (Chain chainTarget perturbedScore perturbed chainTunables)
+
+-- Find a bracket by expanding its bounds through powers of 2.
+findBracket
+  :: (Ord a, Ixed s, IxValue s ~ Double)
+  => (s -> a)
+  -> Index s
+  -> Double
+  -> a
+  -> s
+  -> (IxValue s, IxValue s)
+findBracket target index step height xs = go step xs xs where
+  err = error "findBracket: invalid index -- please report this as a bug!"
+  go !e !bl !br
+    | target bl < height && target br < height =
+        let l = fromMaybe err (bl ^? ix index)
+            r = fromMaybe err (br ^? ix index)
+        in  (l, r)
+    | target bl < height && target br >= height =
+        let br0 = expandBracketRight index step br
+        in  go (2 * e) bl br0
+    | target bl >= height && target br < height =
+        let bl0 = expandBracketLeft index step bl
+        in  go (2 * e) bl0 br
+    | otherwise =
+        let bl0 = expandBracketLeft index step bl
+            br0 = expandBracketRight index step br
+        in  go (2 * e) bl0 br0
+
+expandBracketLeft
+  :: (Ixed s, IxValue s ~ Double)
+  => Index s
+  -> Double
+  -> s
+  -> s
+expandBracketLeft = expandBracketBy (-)
+
+expandBracketRight
+  :: (Ixed s, IxValue s ~ Double)
+  => Index s
+  -> Double
+  -> s
+  -> s
+expandBracketRight = expandBracketBy (+)
+
+expandBracketBy
+  :: Ixed s
+  => (IxValue s -> Double -> IxValue s)
+  -> Index s
+  -> Double
+  -> s
+  -> s
+expandBracketBy f index step xs = xs & ix index %~ (`f` step )
+
+-- Perform rejection sampling within the supplied bracket.
+rejection
+  :: (Ord a, PrimMonad m, Ixed b, IxValue b ~ Double)
+  => (b -> a)
+  -> Index b
+  -> (Double, Double)
+  -> a
+  -> b
+  -> Prob m b
+rejection target dimension bracket height = go where
+  go zs = do
+    u <- MWC.uniformR bracket
+    let  updated = zs & ix dimension .~ u
+    if   target updated < height
+    then go updated
+    else return updated
+
diff --git a/README.md b/README.md
@@ -0,0 +1,25 @@
+# speedy-slice [](http://travis-ci.org/jtobin/speedy-slice)
+
+Speedy slice sampling, as per [Neal, 2003](http://people.ee.duke.edu/~lcarin/slice.pdf).
+
+This implementation of the slice sampling algorithm uses `lens` as a means to
+operate over generic indexed traversable functors, so you can expect it to
+work if your target function takes a list, vector, map, sequence, etc. as its
+argument.
+
+Exports a `mcmc` function that prints a trace to stdout, as well as a
+`slice` transition operator that can be used more generally.
+
+    import Numeric.MCMC.Slice
+    import Data.Sequence (Seq, index, fromList)
+
+    bnn :: Seq Double -> Double
+    bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+      x0 = index xs 0
+      x1 = index xs 1
+
+    main :: IO ()
+    main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
+
+
diff --git a/Setup.hs b/Setup.hs
@@ -0,0 +1,2 @@
+import Distribution.Simple
+main = defaultMain
diff --git a/speedy-slice.cabal b/speedy-slice.cabal
@@ -0,0 +1,76 @@
+name:                speedy-slice
+version:             0.1.0.0
+synopsis:            Speedy slice sampling.
+homepage:            http://github.com/jtobin/speedy-slice
+license:             MIT
+license-file:        LICENSE
+author:              Jared Tobin
+maintainer:          jared@jtobin.ca
+category:            Math
+build-type:          Simple
+cabal-version:       >=1.10
+description:
+  Speedy slice sampling.
+  .
+  This implementation of the slice sampling algorithm uses 'lens' as a means to
+  operate over generic indexed traversable functors, so you can expect it to
+  work if your target function takes a list, vector, map, sequence, etc. as its
+  argument.
+  .
+  Exports a 'mcmc' function that prints a trace to stdout, as well as a
+  'slice' transition operator that can be used more generally.
+  .
+  > import Numeric.MCMC.Slice
+  > import Data.Sequence (Seq, index, fromList)
+  >
+  > bnn :: Seq Double -> Double
+  > bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+  >   x0 = index xs 0
+  >   x1 = index xs 1
+  >
+  > main :: IO ()
+  > main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
+Source-repository head
+  Type:     git
+  Location: http://github.com/jtobin/speedy-slice.git
+
+library
+  default-language:    Haskell2010
+  exposed-modules:
+      Numeric.MCMC.Slice
+  build-depends:
+      base            <  5
+    , ghc-prim
+    , lens
+    , primitive
+    , mcmc-types      >= 1.0.1
+    , mwc-probability >= 1.0.1
+    , pipes
+    , transformers
+
+Test-suite rosenbrock
+  type:                exitcode-stdio-1.0
+  hs-source-dirs:      test
+  main-is:             Rosenbrock.hs
+  default-language:    Haskell2010
+  ghc-options:
+    -rtsopts
+  build-depends:
+      base              < 5
+    , mwc-probability   >= 1.0.1
+    , speedy-slice
+
+Test-suite bnn
+  type:                exitcode-stdio-1.0
+  hs-source-dirs:      test
+  main-is:             BNN.hs
+  default-language:    Haskell2010
+  ghc-options:
+    -rtsopts
+  build-depends:
+      base              < 5
+    , containers
+    , mwc-probability   >= 1.0.1
+    , speedy-slice
+
diff --git a/stack.yaml b/stack.yaml
@@ -0,0 +1,7 @@
+flags: {}
+packages:
+  - '.'
+extra-deps:
+  - mwc-probability-1.0.1
+  - mcmc-types-1.0.1
+resolver: lts-3.3
diff --git a/test/BNN.hs b/test/BNN.hs
@@ -0,0 +1,15 @@
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+
+module Main where
+
+import Numeric.MCMC.Slice
+import Data.Sequence (Seq, fromList, index)
+
+bnn :: Seq Double -> Double
+bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where
+  x0 = index xs 0
+  x1 = index xs 1
+
+main :: IO ()
+main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn
+
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -0,0 +1,13 @@
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+
+module Main where
+
+import Numeric.MCMC.Slice
+
+rosenbrock :: [Double] -> Double
+rosenbrock [x0, x1] = negate (5  *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
+
+main :: IO ()
+main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] rosenbrock
+
+