flat-mcmc

Painless, efficient, general-purpose sampling from continuous distributions.
Log | Files | Refs | README | LICENSE

commit 69efdcdeb47cec6b0cf9813b025b704f3b680d8e
parent 57a93549fd7799d53f59108276bf01b7a28e53de
Author: Jared Tobin <jared@jtobin.ca>
Date:   Sun, 25 May 2014 23:10:15 +1000

Bump to 0.3.0.0.

Diffstat:
M .gitignore | 5 ++++-
D Numeric/MCMC/Flat.hs | 169 -------------------------------------------------------------------------------
M README.md | 2 --
A Setup.hs | 7 +++++++
M flat-mcmc.cabal | 45 ++++++++++++++++++++++++++++++++-------------
A src/Numeric/MCMC/Flat.hs | 142 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A test/Tests.hs | 37 +++++++++++++++++++++++++++++++++++++
7 files changed, 222 insertions(+), 185 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -2,7 +2,6 @@ *swp .DS_Store dist -Setup.hs *.hi *.o data @@ -16,3 +15,7 @@ demos/BNN_Flat demos/Rosenbrock_Flat demos/Himmelblau_Flat demos/SPDE_Flat +sandbox +.cabal-sandbox +cabal.sandbox.config + diff --git a/Numeric/MCMC/Flat.hs b/Numeric/MCMC/Flat.hs @@ -1,169 +0,0 @@ -{-# OPTIONS_GHC -Wall #-} -{-# LANGUAGE BangPatterns #-} - -module Numeric.MCMC.Flat ( - MarkovChain(..), Options(..), Ensemble - , runChain, readInits - ) where - -import Control.Arrow -import Control.Monad -import Control.Monad.Reader -import Control.Monad.Primitive -import System.Random.MWC -import qualified Data.Vector as V -import qualified Data.Vector.Unboxed as U -import Control.Monad.Par (NFData) -import Control.Monad.Par.Scheds.Direct -import Control.Monad.Par.Combinator -import GHC.Float -import System.IO - --- | Parallel map with a specified granularity. -parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b] -parMapChunk n f xs = do - xss <- parMap (map f) (chunk n xs) - return (concat xss) - where chunk _ [] = [] - chunk m ys = let (as, bs) = splitAt m ys - in as : chunk m bs - --- | State of the Markov chain. Current ensemble position is held in 'theta', --- while 'accepts' counts the number of proposals accepted. -data MarkovChain = MarkovChain { ensemble :: Ensemble - , accepts :: {-# UNPACK #-} !Int } - --- | Display the current state. This will be very slow and should be replaced. -instance Show MarkovChain where - show config = filter (`notElem` "[]") $ unlines $ map (show . map double2Float) (V.toList (ensemble config)) - --- | Options for the chain. The target (expected to be a log density), as --- well as the size of the ensemble. The size should be an even number. Also --- holds the specified parallel granularity as 'csize'. 
-data Options = Options { _size :: {-# UNPACK #-} !Int - , _nEpochs :: {-# UNPACK #-} !Int - , _burnIn :: {-# UNPACK #-} !Int - , _thinEvery :: {-# UNPACK #-} !Int - , _csize :: {-# UNPACK #-} !Int } - --- | An ensemble of particles. -type Ensemble = V.Vector [Double] - --- | A result with this type has a view of the chain's options. -type ViewsOptions = ReaderT Options - --- | Generate a random value from a distribution having the property that --- g(1/z) = zg(z). -symmetricVariate :: PrimMonad m => Gen (PrimState m) -> m Double -symmetricVariate g = do - z <- uniformR (0 :: Double, 1 :: Double) g - return $! 0.5*(z + 1)^(2 :: Int) -{-# INLINE symmetricVariate #-} - --- | The result of a single-particle Metropolis accept/reject step. This --- compares a particle's state to a perturbation made by an affine --- transformation based on a complementary particle. Non-monadic to --- more easily be used in the Par monad. -metropolisResult :: [Double] -> [Double] -- Target and alternate particles - -> Double -> Double -- z ~ g(z) and zc ~ rand - -> ([Double] -> Double) -- Target function - -> ([Double], Int) -- Result and accept counter -metropolisResult w0 w1 z zc target = - let val = target proposal - target w0 + (fromIntegral (length w0) - 1) * log z - proposal = zipWith (+) (map (*z) w0) (map (*(1-z)) w1) - in if zc <= min 1 (exp val) then (proposal, 1) else (w0, 0) -{-# INLINE metropolisResult #-} - --- | Execute Metropolis steps on the particles of a sub-ensemble by --- perturbing them with affine transformations based on particles --- in a complementary ensemble, in parallel. 
-executeMoves :: (Functor m, PrimMonad m) - => ([Double] -> Double) -- Target to sample - -> Ensemble -- Target sub-ensemble - -> Ensemble -- Complementary sub-ensemble - -> Int -- Size of the sub-ensembles - -> Gen (PrimState m) -- MWC PRNG - -> ViewsOptions m (Ensemble, Int) -- Updated ensemble and # of accepts -executeMoves t e0 e1 n g = do - Options _ _ _ _ csize <- ask - - zs <- replicateM n (lift $ symmetricVariate g) - zcs <- replicateM n (lift $ uniformR (0 :: Double, 1 :: Double) g) - js <- fmap U.fromList (replicateM n (lift $ uniformR (1:: Int, n) g)) - - let w0 k = e0 `V.unsafeIndex` (k - 1) - w1 k ks = e1 `V.unsafeIndex` ((ks `U.unsafeIndex` (k - 1)) - 1) - - result = runPar $ parMapChunk csize - (\(k, z, zc) -> metropolisResult (w0 k) (w1 k js) z zc t) - (zip3 [1..n] zs zcs) - (newstate, nacc) = (V.fromList . map fst &&& sum . map snd) result - - return (newstate, nacc) -{-# INLINE executeMoves #-} - --- | Perform a Metropolis accept/reject step on the ensemble by --- perturbing each element and accepting/rejecting the perturbation in --- parallel. -metropolisStep :: (Functor m, PrimMonad m) - => ([Double] -> Double) -- Target to sample - -> MarkovChain -- State of the Markov chain - -> Gen (PrimState m) -- MWC PRNG - -> ViewsOptions m MarkovChain -- Updated sub-ensemble -metropolisStep t state g = do - Options n _ _ _ _ <- ask - let n0 = truncate (fromIntegral n / (2 :: Double)) :: Int - (e, nacc) = (ensemble &&& accepts) state - (e0, e1) = (V.slice (0 :: Int) n0 &&& V.slice n0 n0) e - - -- Update each sub-ensemble - result0 <- executeMoves t e0 e1 n0 g - result1 <- executeMoves t e1 (fst result0) n0 g - - return $! - MarkovChain (V.concat $ map fst [result0, result1]) - (nacc + snd result0 + snd result1) -{-# INLINE metropolisStep #-} - --- | Diffuse through states. 
-runChain :: ([Double] -> Double) -- ^ Target to sample - -> Options -- ^ Options of the Markov chain - -> MarkovChain -- ^ Initial state of the Markov chain - -> Gen RealWorld -- ^ MWC PRNG - -> IO MarkovChain -- ^ End state of the Markov chain, wrapped in IO -runChain target opts initState g - | l == 0 - = error "runChain: ensemble must contain at least one particle" - | l < (length . V.head) (ensemble initState) - = do hPutStrLn stderr $ "runChain: ensemble should be twice as large as " - ++ "the target's dimension. Continuing anyway." - go opts nepochs thinEvery initState g - | burnIn < 0 || thinEvery < 0 = error "runChain: nonsensical burn-in or thinning input." - | otherwise = go opts nepochs thinEvery initState g - where - Options l nepochs burnIn thinEvery _ = opts - go o n t !c g0 | n == 0 = hPutStrLn stderr - (let nAcc = accepts c - total = nepochs * l * length (V.head $ ensemble c) - in show nAcc ++ " / " ++ show total ++ " (" ++ - show ((fromIntegral nAcc / fromIntegral total) :: Float) ++ - ") proposals accepted") >> return c - | n > (nepochs - burnIn) = do - r <- runReaderT (metropolisStep target c g0) o - go o (n - 1) t r g0 - | n `rem` t /= 0 = do - r <- runReaderT (metropolisStep target c g0) o - go o (n - 1) t r g0 - | otherwise = do - r <- runReaderT (metropolisStep target c g0) o - print r - go o (n - 1) t r g0 -{-# INLINE runChain #-} - --- | A convenience function to read and parse ensemble inits from disk. --- Assumes a text file with one particle per line, where each particle --- element is separated by whitespace. -readInits :: FilePath -> IO Ensemble -readInits p = fmap (V.fromList . map (map read . words) . lines) (readFile p) -{-# INLINE readInits #-} - diff --git a/README.md b/README.md @@ -2,5 +2,3 @@ Painless general-purpose sampling. -See the *Examples* folder for example usage. 
- diff --git a/Setup.hs b/Setup.hs @@ -0,0 +1,7 @@ +module Main (main) where + +import Distribution.Simple + +main :: IO () +main = defaultMain + diff --git a/flat-mcmc.cabal b/flat-mcmc.cabal @@ -1,21 +1,15 @@ --- Initial flat_mcmc.cabal generated by cabal init. For further --- documentation, see http://haskell.org/cabal/users-guide/ - name: flat-mcmc -version: 0.2.0.0 +version: 0.3.0.0 synopsis: Painless general-purpose sampling. --- description: homepage: http://jtobin.github.com/flat-mcmc --- license: license: BSD3 license-file: LICENSE author: Jared Tobin maintainer: jared@jtobin.ca --- copyright: -category: Numeric, Machine Learning, Statistics +category: Math build-type: Simple cabal-version: >=1.8 -Description: +description: Painless general-purpose sampling. @@ -24,8 +18,33 @@ Source-repository head Location: http://github.com/jtobin/flat-mcmc.git library - exposed-modules: Numeric.MCMC.Flat - -- other-modules: - build-depends: base >= 4.3, mtl >= 2.1, primitive >= 0.4, mwc-random >= 0.12, vector >= 0.9, monad-par >= 0.3, monad-par-extras >= 0.3 - ghc-options: -Wall + hs-source-dirs: + src + exposed-modules: + Numeric.MCMC.Flat + build-depends: + base >= 4.3 && < 5 + , mtl >= 2.1 + , primitive >= 0.4 + , mwc-random >= 0.12 + , mwc-probability >= 0.1 + , vector >= 0.9 + , monad-par >= 0.3 + , monad-par-extras >= 0.3 + ghc-options: + -Wall + +Test-Suite rosenbrock + type: exitcode-stdio-1.0 + hs-source-dirs: src, test + main-is: Tests.hs + build-depends: + base >= 4.3 + , mtl >= 2.1 + , primitive >= 0.4 + , mwc-random >= 0.12 + , mwc-probability >= 0.1 + , vector >= 0.9 + , monad-par >= 0.3 + , monad-par-extras >= 0.3 diff --git a/src/Numeric/MCMC/Flat.hs b/src/Numeric/MCMC/Flat.hs @@ -0,0 +1,142 @@ +{-# OPTIONS_GHC -Wall #-} +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Numeric.MCMC.Flat ( + Chain(..) + , Options(..) 
+ , Ensemble + , Density + , flat + -- * System.Random.MWC + , sample + , withSystemRandom + , asGenIO + , asGenST + , create + , initialize + ) where + +import Control.Applicative +import Control.Arrow +import Control.Monad +import Control.Monad.Par (NFData) +import Control.Monad.Par.Scheds.Direct hiding (put, get) +import Control.Monad.Par.Combinator +import Control.Monad.Primitive +import Control.Monad.State.Strict +import Data.Vector (Vector) +import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as U +import System.Random.MWC.Probability + +-- | The state of a Markov chain. +-- +-- The target function - assumed to be proportional to a log density - is +-- itself stateful, allowing for custom annealing schedules if you know what +-- you're doing. +data Chain = Chain { + logObjective :: Density + , ensemble :: Ensemble + , accepts :: !Int + , iterations :: !Int + } + +instance Show Chain where + show c = show . unlines . map (sanitize . show) $ us where + us = V.toList e + e = ensemble c + sanitize = filter (`notElem` "fromList []") + +-- | Parallelism granularity. +data Options = Options { + granularity :: !Int + } deriving (Eq, Show) + +-- | An ensemble of particles. A Markov chain is defined over the entire +-- ensemble, rather than individual particles. +type Ensemble = Vector Particle + +type Particle = U.Vector Double + +type Density = Particle -> Double + +-- | The flat-mcmc transition operator. Run a Markov chain with it by providing +-- an initial location (origin), a generator (gen), and using the usual +-- facilities from 'Control.Monad.State': +-- +-- > let chain = replicateM 5000 flat `evalStateT` origin +-- > trace <- sample chain gen +-- +flat :: PrimMonad m => StateT Chain (Prob m) Chain +flat = flatGranular $ Options 1 + +-- | The flat-mcmc transition operator with custom parallelism granularity. 
+flatGranular :: PrimMonad m => Options -> StateT Chain (Prob m) Chain +flatGranular (Options gran) = do + Chain target e nAccept epochs <- get + let n = truncate (fromIntegral (V.length e) / 2) + (e0, e1) = (V.slice 0 n &&& V.slice n n) e + + (e2, nAccept0) <- lift $ executeMoves target gran e0 e1 + (e3, nAccept1) <- lift $ executeMoves target gran e1 e2 + + put $! Chain { + logObjective = target + , ensemble = V.concat [e2, e3] + , accepts = nAccept + nAccept0 + nAccept1 + , iterations = succ epochs + } + + get + +parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b] +parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where + chunk _ [] = [] + chunk m ys = + let (as, bs) = splitAt m ys + in as : chunk m bs + +symmetric :: PrimMonad m => Prob m Double +symmetric = transform <$> uniform where + transform z = 0.5 * (z + 1) ^ (2 :: Int) + +stretch :: Particle -> Particle -> Double -> Particle +stretch particle altParticle z = + U.zipWith (+) (U.map (* z) particle) (U.map (* (1 - z)) altParticle) + +acceptProb :: Density -> Particle -> Particle -> Double -> Double +acceptProb target particle proposal z = + target proposal + - target particle + + log z * (fromIntegral (U.length particle) - 1) + +move :: Density -> Particle -> Particle -> Double -> Double -> (Particle, Int) +move target particle altParticle z zc = + let proposal = stretch particle altParticle z + pAccept = acceptProb target particle proposal z + in if zc <= min 0 pAccept + then (proposal, 1) -- move and count moves made + else (particle, 0) + +executeMoves + :: PrimMonad m + => Density + -> Int + -> Ensemble + -> Ensemble + -> Prob m (Ensemble, Int) +executeMoves target gran e0 e1 = do + let n = truncate $ fromIntegral (V.length e0 + V.length e1) / 2 + zs <- replicateM n symmetric + zcs <- replicateM n $ log <$> uniform + others <- replicateM n $ uniformR (0, n - 1) + + let particle j = e0 `V.unsafeIndex` j + altParticle j = e1 `V.unsafeIndex` (others !! 
j) + + moves = runPar $ parMapChunk gran + (\(j, z, zc) -> move target (particle j) (altParticle j) z zc) + (zip3 [0..n - 1] zs zcs) + + return $! (V.fromList . map fst &&& sum . map snd) moves + diff --git a/test/Tests.hs b/test/Tests.hs @@ -0,0 +1,37 @@ +{-# OPTIONS_GHC -fno-warn-type-defaults #-} + +module Main where + +import Control.Monad +import Control.Monad.State.Strict +import qualified Data.Vector as V +import qualified Data.Vector.Unboxed as U +import Numeric.MCMC.Flat +import System.Random.MWC.Probability + +lRosenbrock :: Density +lRosenbrock xs = + let [x0, x1] = U.toList xs + in (-1) * (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) + +defaultEnsemble :: Ensemble +defaultEnsemble = V.fromList $ map U.fromList + [[0.1, 0.5], [0.8, 0.1], [1.0, 0.2], [0.9, 0.8], [-0.2, 0.3], [-0.1, 0.9]] + +opts :: Options +opts = Options 10 + +origin :: Chain +origin = Chain { + logObjective = lRosenbrock + , ensemble = defaultEnsemble + , iterations = 0 + , accepts = 0 + } + +-- cabal test --show-details=streaming +main :: IO () +main = withSystemRandom . asGenIO $ \g -> do + trace <- sample (replicateM 5000 (flat opts) `evalStateT` origin) g + print trace +