flat-mcmc

Painless, efficient, general-purpose sampling from continuous distributions.
Log | Files | Refs | README | LICENSE

commit fc999b22c23d28f841ae8bee26a01d077fb52d74
parent f820f3347723d2fca5266a49d446a65db656f5bd
Author: Jared Tobin <jared@jtobin.ca>
Date:   Wed, 30 Mar 2016 17:53:16 +0700

Misc updates.

Diffstat:
Mflat-mcmc.cabal | 7++++---
Msrc/Numeric/MCMC/Flat.hs | 193+++++++++++++++++++++++++++++++++----------------------------------------------
Mstack.yaml | 6++----
3 files changed, 86 insertions(+), 120 deletions(-)

diff --git a/flat-mcmc.cabal b/flat-mcmc.cabal @@ -24,11 +24,12 @@ library exposed-modules: Numeric.MCMC.Flat build-depends: base < 5 - , lens >= 4 && < 5 - , mcmc-types >= 1.0.1 + , mcmc-types >= 1.0.1 && < 2 , monad-par - , mwc-probability >= 1.0.1 + , monad-par-extras + , mwc-probability >= 1.0.1 && < 2 , pipes >= 4 && < 5 , primitive , transformers + , vector diff --git a/src/Numeric/MCMC/Flat.hs b/src/Numeric/MCMC/Flat.hs @@ -1,143 +1,110 @@ -{-# OPTIONS_GHC -Wall #-} {-# OPTIONS_GHC -fno-warn-type-defaults #-} +{-# LANGUAGE RecordWildCards #-} -module Numeric.MCMC.Flat ( - Chain(..) - , Options(..) - , Ensemble - , Density - , flat - , flatGranular - -- * System.Random.MWC - , sample - , withSystemRandom - , asGenIO - , asGenST - , create - , initialize - ) where - -import Control.Applicative -import Control.Arrow -import Control.Monad +module Numeric.MCMC.Flat (mcmc) where + +import Control.Monad (replicateM) import Control.Monad.Par (NFData) import Control.Monad.Par.Scheds.Direct hiding (put, get) -import Control.Monad.Par.Combinator -import Control.Monad.Primitive -import Control.Monad.State.Strict +import Control.Monad.Par.Combinator (parMap) +import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld) +import Control.Monad.Trans.State.Strict (get, put, execStateT) +import Data.Sampling.Types hiding (Chain(..)) import Data.Vector (Vector) import qualified Data.Vector as V import qualified Data.Vector.Unboxed as U -import System.Random.MWC.Probability +import Pipes (Producer, lift, yield, runEffect, (>->)) +import qualified Pipes.Prelude as Pipes +import System.Random.MWC.Probability as MWC --- | The state of a Markov chain. --- --- The target function - assumed to be proportional to a log density - is --- itself stateful, allowing for custom annealing schedules if you know what --- you're doing. data Chain = Chain { - logObjective :: Density - , ensemble :: Ensemble - , accepts :: !Int - , iterations :: !Int + chainTarget :: Target Particle + , chainPosition :: !Ensemble } instance Show Chain where - show c = show . unlines . map (sanitize . show) $ us where - us = V.toList e - e = ensemble c - sanitize = filter (`notElem` "fromList []") - --- | Parallelism granularity. -data Options = Options { - granularity :: !Int - } deriving (Eq, Show) - --- | An ensemble of particles. A Markov chain is defined over the entire --- ensemble, rather than individual particles. -type Ensemble = Vector Particle - -type Particle = U.Vector Double - -type Density = Particle -> Double - --- | The flat-mcmc transition operator. Run a Markov chain with it by providing --- an initial location (origin), a generator (gen), and using the usual --- facilities from 'Control.Monad.State': --- --- > let chain = replicateM 5000 flat `evalStateT` origin --- > trace <- sample chain gen --- -flat :: PrimMonad m => StateT Chain (Prob m) Chain -flat = flatGranular $ Options 1 - --- | The flat-mcmc transition operator with custom parallelism granularity. -flatGranular :: PrimMonad m => Options -> StateT Chain (Prob m) Chain -flatGranular (Options gran) = do - Chain target e nAccept epochs <- get - let n = truncate (fromIntegral (V.length e) / 2) - (e0, e1) = (V.slice 0 n &&& V.slice n n) e - - (e2, nAccept0) <- lift $ executeMoves target gran e0 e1 - (e3, nAccept1) <- lift $ executeMoves target gran e1 e2 + show Chain {..} = show chainPosition -- FIXME better? - put $! Chain { - logObjective = target - , ensemble = V.concat [e2, e3] - , accepts = nAccept + nAccept0 + nAccept1 - , iterations = succ epochs - } +type Particle = Vector Double - get - -parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b] -parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where - chunk _ [] = [] - chunk m ys = - let (as, bs) = splitAt m ys - in as : chunk m bs +type Ensemble = Vector Particle symmetric :: PrimMonad m => Prob m Double -symmetric = transform <$> uniform where +symmetric = fmap transform uniform where transform z = 0.5 * (z + 1) ^ (2 :: Int) stretch :: Particle -> Particle -> Double -> Particle -stretch particle altParticle z = - U.zipWith (+) (U.map (* z) particle) (U.map (* (1 - z)) altParticle) +stretch p0 p1 z = V.zipWith (+) (V.map (* z) p0) (V.map (* (1 - z)) p1) -acceptProb :: Density -> Particle -> Particle -> Double -> Double +acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double acceptProb target particle proposal z = - target proposal - - target particle - + log z * (fromIntegral (U.length particle) - 1) - -move :: Density -> Particle -> Particle -> Double -> Double -> (Particle, Int) -move target particle altParticle z zc = - let proposal = stretch particle altParticle z - pAccept = acceptProb target particle proposal z + lTarget target proposal + - lTarget target particle + + log z * (fromIntegral (V.length particle) - 1) + +move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle +move target p0 p1 z zc = + let proposal = stretch p0 p1 z + pAccept = acceptProb target p0 proposal z in if zc <= min 0 pAccept - then (proposal, 1) -- move and count moves made - else (particle, 0) + then proposal + else p0 -executeMoves +execute :: PrimMonad m - => Density - -> Int + => Target Particle -> Ensemble -> Ensemble - -> Prob m (Ensemble, Int) -executeMoves target gran e0 e1 = do - let n = truncate $ fromIntegral (V.length e0 + V.length e1) / 2 - zs <- replicateM n symmetric - zcs <- replicateM n $ log <$> uniform - others <- replicateM n $ uniformR (0, n - 1) + -> Int + -> Prob m Ensemble +execute target e0 e1 n = do + zs <- replicateM n symmetric + zcs <- replicateM n uniform + vjs <- replicateM n uniform - let particle j = e0 `V.unsafeIndex` j - altParticle j = e1 `V.unsafeIndex` (others !! j) + let js = U.fromList vjs + w0 k = e0 `V.unsafeIndex` pred k + w1 k ks = e1 `V.unsafeIndex` pred (ks `U.unsafeIndex` pred k) - moves = runPar $ parMapChunk gran - (\(j, z, zc) -> move target (particle j) (altParticle j) z zc) - (zip3 [0..n - 1] zs zcs) + worker (k, z, zc) = move target (w0 k) (w1 k js) z zc + result = runPar $ + parMapChunk 2 worker (zip3 [1..n] zs zcs) -- FIXME granularity option + + return $ V.fromList result + +flat + :: PrimMonad m + => Transition m Chain +flat = do + Chain {..} <- get + let size = V.length chainPosition + n = truncate (fromIntegral size / 2) + e0 = V.slice 0 n chainPosition + e1 = V.slice n n chainPosition + result0 <- lift (execute chainTarget e0 e1 n) + result1 <- lift (execute chainTarget e1 result0 n) + let ensemble = V.concat [result0, result1] + put (Chain chainTarget ensemble) + +chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m () +chain = loop where + loop state prng = do + next <- lift (MWC.sample (execStateT flat state) prng) + yield next + loop next prng + +mcmc :: Int -> Ensemble -> (Particle -> Double) -> Gen RealWorld -> IO () +mcmc n chainPosition target gen = runEffect $ + chain Chain {..} gen + >-> Pipes.take n + >-> Pipes.mapM_ print + where + chainTarget = Target target Nothing - return $! (V.fromList . map fst &&& sum . map snd) moves +parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b] +parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where + chunk _ [] = [] + chunk m ys = + let (as, bs) = splitAt m ys + in as : chunk m bs diff --git a/stack.yaml b/stack.yaml @@ -1,7 +1,5 @@ flags: {} packages: - '.' -extra-deps: - - mwc-probability-1.0.2 - - mcmc-types-1.0.1 -resolver: lts-3.8 +extra-deps: [] +resolver: lts-5.2