speedy-slice

Speedy slice sampling.
Log | Files | Refs | README | LICENSE

commit 4722d49428710580e8d43d30c2faf5fac69ce0b2
Author: Jared Tobin <jared@jtobin.ca>
Date:   Wed,  7 Oct 2015 20:58:23 +1300

Initial commit.

Diffstat:
A.gitignore | 5+++++
ALICENSE | 19+++++++++++++++++++
ANumeric/MCMC/Slice.hs | 179+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
AREADME.md | 25+++++++++++++++++++++++++
ASetup.hs | 2++
Aspeedy-slice.cabal | 76++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Astack.yaml | 7+++++++
Atest/BNN.hs | 15+++++++++++++++
Atest/Rosenbrock.hs | 13+++++++++++++
9 files changed, 341 insertions(+), 0 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -0,0 +1,5 @@ +*swp +.stack-work +debug +*o +*hi diff --git a/LICENSE b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Jared Tobin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Numeric/MCMC/Slice.hs b/Numeric/MCMC/Slice.hs @@ -0,0 +1,179 @@ +{-# OPTIONS_GHC -Wall #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeFamilies #-} + +-- | +-- Module: Numeric.MCMC.Slice +-- Copyright: (c) 2015 Jared Tobin +-- License: MIT +-- +-- Maintainer: Jared Tobin <jared@jtobin.ca> +-- Stability: unstable +-- Portability: ghc +-- +-- This implementation performs slice sampling by first finding a bracket about +-- a mode (using a simple doubling heuristic), and then doing rejection +-- sampling along it. The result is a reliable and computationally inexpensive +-- sampling routine. +-- +-- The 'mcmc' function streams a trace to stdout to be processed elsewhere, +-- while the `slice` transition can be used for more flexible purposes, such as +-- working with samples in memory. +-- +-- See <http://people.ee.duke.edu/~lcarin/slice.pdf Neal, 2003> for the +-- definitive reference of the algorithm. + +module Numeric.MCMC.Slice ( + mcmc + , slice + + -- * Re-exported + , MWC.create + , MWC.createSystemRandom + , MWC.withSystemRandom + , MWC.asGenIO + ) where + +import Control.Monad.Trans.State.Strict (put, get, execStateT) +import Control.Monad.Primitive (PrimMonad, PrimState) +import Control.Lens hiding (index) +import GHC.Prim (RealWorld) +import Data.Maybe (fromMaybe) +import Data.Sampling.Types +import Pipes hiding (next) +import qualified Pipes.Prelude as Pipes +import System.Random.MWC.Probability (Prob, Gen) +import qualified System.Random.MWC.Probability as MWC + +-- | Trace 'n' iterations of a Markov chain and stream them to stdout. +-- +-- >>> let rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) +-- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] rosenbrock +-- -3.854097694213343e-2,0.16688601288358407 +-- -9.310661272172682e-2,0.2562387977415508 +-- -0.48500122500661846,0.46245400501919076 +mcmc + :: (Show (t a), FoldableWithIndex (Index (t a)) t, Ixed (t a), + IxValue (t a) ~ Double) + => Int + -> Double + -> t a + -> (t a -> Double) + -> Gen RealWorld + -> IO () +mcmc n radial chainPosition target gen = runEffect $ + chain radial Chain {..} gen + >-> Pipes.take n + >-> Pipes.mapM_ print + where + chainScore = lTarget chainTarget chainPosition + chainTunables = Nothing + chainTarget = Target target Nothing + +-- A Markov chain driven by the slice transition operator. +chain + :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a), + IxValue (t a) ~ Double) + => Double + -> Chain (t a) b + -> Gen (PrimState m) + -> Producer (Chain (t a) b) m () +chain radial = loop where + loop state prng = do + next <- lift (MWC.sample (execStateT (slice radial) state) prng) + yield next + loop next prng + +-- | A slice sampling transition operator. +slice + :: (PrimMonad m, FoldableWithIndex (Index (t a)) t, Ixed (t a), + IxValue (t a) ~ Double) + => Double + -> Transition m (Chain (t a) b) +slice step = do + Chain _ _ position _ <- get + ifor_ position $ \index _ -> do + Chain {..} <- get + let bounds = (0, exp (lTarget chainTarget chainPosition)) + height <- lift (fmap log (MWC.uniformR bounds)) + + let bracket = + findBracket (lTarget chainTarget) index step height chainPosition + + perturbed <- lift $ + rejection (lTarget chainTarget) index bracket height chainPosition + + let perturbedScore = lTarget chainTarget perturbed + put (Chain chainTarget perturbedScore perturbed chainTunables) + +-- Find a bracket by expanding its bounds through powers of 2. +findBracket + :: (Ord a, Ixed s, IxValue s ~ Double) + => (s -> a) + -> Index s + -> Double + -> a + -> s + -> (IxValue s, IxValue s) +findBracket target index step height xs = go step xs xs where + err = error "findBracket: invalid index -- please report this as a bug!" + go !e !bl !br + | target bl < height && target br < height = + let l = fromMaybe err (bl ^? ix index) + r = fromMaybe err (br ^? ix index) + in (l, r) + | target bl < height && target br >= height = + let br0 = expandBracketRight index step br + in go (2 * e) bl br0 + | target bl >= height && target br < height = + let bl0 = expandBracketLeft index step bl + in go (2 * e) bl0 br + | otherwise = + let bl0 = expandBracketLeft index step bl + br0 = expandBracketRight index step br + in go (2 * e) bl0 br0 + +expandBracketLeft + :: (Ixed s, IxValue s ~ Double) + => Index s + -> Double + -> s + -> s +expandBracketLeft = expandBracketBy (-) + +expandBracketRight + :: (Ixed s, IxValue s ~ Double) + => Index s + -> Double + -> s + -> s +expandBracketRight = expandBracketBy (+) + +expandBracketBy + :: Ixed s + => (IxValue s -> Double -> IxValue s) + -> Index s + -> Double + -> s + -> s +expandBracketBy f index step xs = xs & ix index %~ (`f` step ) + +-- Perform rejection sampling within the supplied bracket. +rejection + :: (Ord a, PrimMonad m, Ixed b, IxValue b ~ Double) + => (b -> a) + -> Index b + -> (Double, Double) + -> a + -> b + -> Prob m b +rejection target dimension bracket height = go where + go zs = do + u <- MWC.uniformR bracket + let updated = zs & ix dimension .~ u + if target updated < height + then go updated + else return updated + diff --git a/README.md b/README.md @@ -0,0 +1,25 @@ +# speedy-slice [![Build Status](https://secure.travis-ci.org/jtobin/speedy-slice.png)](http://travis-ci.org/jtobin/speedy-slice) + +Speedy slice sampling, as per [Neal, 2003](http://people.ee.duke.edu/~lcarin/slice.pdf). + +This implementation of the slice sampling algorithm uses `lens` as a means to +operate over generic indexed traversable functors, so you can expect it to +work if your target function takes a list, vector, map, sequence, etc. as its +argument. + +Exports a `mcmc` function that prints a trace to stdout, as well as a +`slice` transition operator that can be used more generally. + + import Numeric.MCMC.Slice + import Data.Sequence (Seq, index, fromList) + + bnn :: Seq Double -> Double + bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where + x0 = index xs 0 + x1 = index xs 1 + + main :: IO () + main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn + +![trace](https://dl.dropboxusercontent.com/spa/u0s6617yxinm2ca/zp-9gl6z.png) + diff --git a/Setup.hs b/Setup.hs @@ -0,0 +1,2 @@ +import Distribution.Simple +main = defaultMain diff --git a/speedy-slice.cabal b/speedy-slice.cabal @@ -0,0 +1,76 @@ +name: speedy-slice +version: 0.1.0.0 +synopsis: Speedy slice sampling. +homepage: http://github.com/jtobin/speedy-slice +license: MIT +license-file: LICENSE +author: Jared Tobin +maintainer: jared@jtobin.ca +category: Math +build-type: Simple +cabal-version: >=1.10 +description: + Speedy slice sampling. + . + This implementation of the slice sampling algorithm uses 'lens' as a means to + operate over generic indexed traversable functors, so you can expect it to + work if your target function takes a list, vector, map, sequence, etc. as its + argument. + . + Exports a 'mcmc' function that prints a trace to stdout, as well as a + 'slice' transition operator that can be used more generally. + . + > import Numeric.MCMC.Slice + > import Data.Sequence (Seq, index, fromList) + > + > bnn :: Seq Double -> Double + > bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where + > x0 = index xs 0 + > x1 = index xs 1 + > + > main :: IO () + > main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn + +Source-repository head + Type: git + Location: http://github.com/jtobin/speedy-slice.git + +library + default-language: Haskell2010 + exposed-modules: + Numeric.MCMC.Slice + build-depends: + base < 5 + , ghc-prim + , lens + , primitive + , mcmc-types >= 1.0.1 + , mwc-probability >= 1.0.1 + , pipes + , transformers + +Test-suite rosenbrock + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Rosenbrock.hs + default-language: Haskell2010 + ghc-options: + -rtsopts + build-depends: + base < 5 + , mwc-probability >= 1.0.1 + , speedy-slice + +Test-suite bnn + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: BNN.hs + default-language: Haskell2010 + ghc-options: + -rtsopts + build-depends: + base < 5 + , containers + , mwc-probability >= 1.0.1 + , speedy-slice + diff --git a/stack.yaml b/stack.yaml @@ -0,0 +1,7 @@ +flags: {} +packages: + - '.' +extra-deps: + - mwc-probability-1.0.1 + - mcmc-types-1.0.1 +resolver: lts-3.3 diff --git a/test/BNN.hs b/test/BNN.hs @@ -0,0 +1,15 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Main where + +import Numeric.MCMC.Slice +import Data.Sequence (Seq, fromList, index) + +bnn :: Seq Double -> Double +bnn xs = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) where + x0 = index xs 0 + x1 = index xs 1 + +main :: IO () +main = withSystemRandom . asGenIO $ mcmc 10000 1 (fromList [0, 0]) bnn + diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs @@ -0,0 +1,13 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Main where + +import Numeric.MCMC.Slice + +rosenbrock :: [Double] -> Double +rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) + +main :: IO () +main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] rosenbrock + +