flat-mcmc

Painless, efficient, general-purpose sampling from continuous distributions.
Log | Files | Refs | README | LICENSE

commit 63fd2ebabc512a3ef5a3b13e3f25963864901de9
parent e763573b01a8bc4bc2d4aa42a28379a01d452eae
Author: Jared Tobin <jared@jtobin.ca>
Date:   Mon,  7 Nov 2016 14:21:55 +1300

Misc performance improvements.

Diffstat:
Mflat-mcmc.cabal | 6++++--
Mlib/Numeric/MCMC/Flat.hs | 54+++++++++++++++++++++++++++++++++++++++++-------------
2 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/flat-mcmc.cabal b/flat-mcmc.cabal @@ -1,5 +1,5 @@ name: flat-mcmc -version: 1.1.1 +version: 1.2.1 synopsis: Painless general-purpose sampling. homepage: http://jtobin.github.com/flat-mcmc license: MIT @@ -53,14 +53,16 @@ library exposed-modules: Numeric.MCMC.Flat build-depends: base > 4 && < 6 + , formatting >= 6 && < 7 , mcmc-types >= 1.0.1 && < 2 , monad-par >= 0.3.4.7 && < 1 , monad-par-extras >= 0.3.3 && < 1 , mwc-probability >= 1.0.1 && < 2 , pipes > 4 && < 5 , primitive + , text , transformers - , vector + , vector >= 0.10 && < 1 Test-suite rosenbrock type: exitcode-stdio-1.0 diff --git a/lib/Numeric/MCMC/Flat.hs b/lib/Numeric/MCMC/Flat.hs @@ -1,4 +1,5 @@ {-# OPTIONS_GHC -fno-warn-type-defaults #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} -- | @@ -42,10 +43,15 @@ import Control.Monad.Par.Scheds.Direct hiding (put, get) import Control.Monad.Par.Combinator (parMap) import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld) import Control.Monad.Trans.State.Strict (get, put, execStateT) +import Data.Monoid import Data.Sampling.Types as Sampling.Types hiding (Chain(..)) +import qualified Data.Text as T +import qualified Data.Text.IO as T (putStrLn) import Data.Vector (Vector) import qualified Data.Vector as V import qualified Data.Vector.Unboxed as U +import Formatting ((%)) +import qualified Formatting as F import Pipes (Producer, lift, yield, runEffect, (>->)) import qualified Pipes.Prelude as Pipes import System.Random.MWC.Probability as MWC @@ -55,14 +61,26 @@ data Chain = Chain { , chainPosition :: !Ensemble } -instance Show Chain where - show Chain {..} = - init - . filter (`notElem` "[]") - . unlines - . V.toList - . V.map show - $ chainPosition +-- | Render a Chain as a text value. +render :: Chain -> T.Text +render Chain {..} = renderEnsemble chainPosition +{-# INLINE render #-} + +renderParticle :: Particle -> T.Text +renderParticle = + T.drop 1 + . U.foldl' glue mempty + where + glue = F.sformat (F.stext % "," % F.float) +{-# INLINE renderParticle #-} + +renderEnsemble :: Ensemble -> T.Text +renderEnsemble = + T.drop 1 + . V.foldl' glue mempty + where + glue a b = a <> "\n" <> renderParticle b +{-# INLINE renderEnsemble #-} type Particle = U.Vector Double @@ -71,15 +89,19 @@ type Ensemble = Vector Particle symmetric :: PrimMonad m => Prob m Double symmetric = fmap transform uniform where transform z = 0.5 * (z + 1) ^ (2 :: Int) +{-# INLINE symmetric #-} stretch :: Particle -> Particle -> Double -> Particle -stretch p0 p1 z = U.zipWith (+) (U.map (* z) p0) (U.map (* (1 - z)) p1) +stretch p0 p1 z = U.zipWith str p0 p1 where + str x y = z * x + (1 - z) * y +{-# INLINE stretch #-} acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double acceptProb target particle proposal z = lTarget target proposal - lTarget target particle + log z * (fromIntegral (U.length particle) - 1) +{-# INLINE acceptProb #-} move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle move target p0 p1 z zc = @@ -88,6 +110,7 @@ move target p0 p1 z zc = in if zc <= min 1 (exp pAccept) then proposal else p0 +{-# INLINE move #-} execute :: PrimMonad m @@ -112,7 +135,8 @@ execute target e0 e1 n = do result = runPar $ parMapChunk granularity worker (zip3 [1..n] zs zcs) - return $ V.fromList result + return $! V.fromList result +{-# INLINE execute #-} -- | The 'flat' transition operator for driving a Markov chain over a space -- of ensembles. @@ -123,12 +147,13 @@ flat = do Chain {..} <- get let size = V.length chainPosition n = truncate (fromIntegral size / 2) - e0 = V.slice 0 n chainPosition - e1 = V.slice n n chainPosition + e0 = V.unsafeSlice 0 n chainPosition + e1 = V.unsafeSlice n n chainPosition result0 <- lift (execute chainTarget e0 e1 n) result1 <- lift (execute chainTarget e1 result0 n) let ensemble = V.concat [result0, result1] put (Chain chainTarget ensemble) +{-# INLINE flat #-} chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m () chain = loop where @@ -136,6 +161,7 @@ chain = loop where next <- lift (MWC.sample (execStateT flat state) prng) yield next loop next prng +{-# INLINE chain #-} -- | Trace 'n' iterations of a Markov chain and stream them to stdout. -- @@ -169,9 +195,10 @@ mcmc :: Int -> Ensemble -> (Particle -> Double) -> Gen RealWorld -> IO () mcmc n chainPosition target gen = runEffect $ chain Chain {..} gen >-> Pipes.take n - >-> Pipes.mapM_ print + >-> Pipes.mapM_ (T.putStrLn . render) where chainTarget = Target target Nothing +{-# INLINE mcmc #-} -- A parallel map with the specified granularity. parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b] @@ -180,4 +207,5 @@ parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where chunk m ys = let (as, bs) = splitAt m ys in as : chunk m bs +{-# INLINE parMapChunk #-}