commit 63fd2ebabc512a3ef5a3b13e3f25963864901de9
parent e763573b01a8bc4bc2d4aa42a28379a01d452eae
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 7 Nov 2016 14:21:55 +1300
Misc performance improvements.
Diffstat:
2 files changed, 45 insertions(+), 15 deletions(-)
diff --git a/flat-mcmc.cabal b/flat-mcmc.cabal
@@ -1,5 +1,5 @@
name: flat-mcmc
-version: 1.1.1
+version: 1.2.1
synopsis: Painless general-purpose sampling.
homepage: http://jtobin.github.com/flat-mcmc
license: MIT
@@ -53,14 +53,16 @@ library
exposed-modules: Numeric.MCMC.Flat
build-depends:
base > 4 && < 6
+ , formatting >= 6 && < 7
, mcmc-types >= 1.0.1 && < 2
, monad-par >= 0.3.4.7 && < 1
, monad-par-extras >= 0.3.3 && < 1
, mwc-probability >= 1.0.1 && < 2
, pipes > 4 && < 5
, primitive
+ , text
, transformers
- , vector
+ , vector >= 0.10 && < 1
Test-suite rosenbrock
type: exitcode-stdio-1.0
diff --git a/lib/Numeric/MCMC/Flat.hs b/lib/Numeric/MCMC/Flat.hs
@@ -1,4 +1,5 @@
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
-- |
@@ -42,10 +43,15 @@ import Control.Monad.Par.Scheds.Direct hiding (put, get)
import Control.Monad.Par.Combinator (parMap)
import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
import Control.Monad.Trans.State.Strict (get, put, execStateT)
+import Data.Monoid
import Data.Sampling.Types as Sampling.Types hiding (Chain(..))
+import qualified Data.Text as T
+import qualified Data.Text.IO as T (putStrLn)
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
+import Formatting ((%))
+import qualified Formatting as F
import Pipes (Producer, lift, yield, runEffect, (>->))
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability as MWC
@@ -55,14 +61,26 @@ data Chain = Chain {
, chainPosition :: !Ensemble
}
-instance Show Chain where
- show Chain {..} =
- init
- . filter (`notElem` "[]")
- . unlines
- . V.toList
- . V.map show
- $ chainPosition
+-- | Render a Chain as a text value.
+render :: Chain -> T.Text
+render Chain {..} = renderEnsemble chainPosition
+{-# INLINE render #-}
+
+renderParticle :: Particle -> T.Text
+renderParticle =
+ T.drop 1
+ . U.foldl' glue mempty
+ where
+ glue = F.sformat (F.stext % "," % F.float)
+{-# INLINE renderParticle #-}
+
+renderEnsemble :: Ensemble -> T.Text
+renderEnsemble =
+ T.drop 1
+ . V.foldl' glue mempty
+ where
+ glue a b = a <> "\n" <> renderParticle b
+{-# INLINE renderEnsemble #-}
type Particle = U.Vector Double
@@ -71,15 +89,19 @@ type Ensemble = Vector Particle
symmetric :: PrimMonad m => Prob m Double
symmetric = fmap transform uniform where
transform z = 0.5 * (z + 1) ^ (2 :: Int)
+{-# INLINE symmetric #-}
stretch :: Particle -> Particle -> Double -> Particle
-stretch p0 p1 z = U.zipWith (+) (U.map (* z) p0) (U.map (* (1 - z)) p1)
+stretch p0 p1 z = U.zipWith str p0 p1 where
+ str x y = z * x + (1 - z) * y
+{-# INLINE stretch #-}
acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double
acceptProb target particle proposal z =
lTarget target proposal
- lTarget target particle
+ log z * (fromIntegral (U.length particle) - 1)
+{-# INLINE acceptProb #-}
move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle
move target p0 p1 z zc =
@@ -88,6 +110,7 @@ move target p0 p1 z zc =
in if zc <= min 1 (exp pAccept)
then proposal
else p0
+{-# INLINE move #-}
execute
:: PrimMonad m
@@ -112,7 +135,8 @@ execute target e0 e1 n = do
result = runPar $
parMapChunk granularity worker (zip3 [1..n] zs zcs)
- return $ V.fromList result
+ return $! V.fromList result
+{-# INLINE execute #-}
-- | The 'flat' transition operator for driving a Markov chain over a space
-- of ensembles.
@@ -123,12 +147,13 @@ flat = do
Chain {..} <- get
let size = V.length chainPosition
n = truncate (fromIntegral size / 2)
- e0 = V.slice 0 n chainPosition
- e1 = V.slice n n chainPosition
+ e0 = V.unsafeSlice 0 n chainPosition
+ e1 = V.unsafeSlice n n chainPosition
result0 <- lift (execute chainTarget e0 e1 n)
result1 <- lift (execute chainTarget e1 result0 n)
let ensemble = V.concat [result0, result1]
put (Chain chainTarget ensemble)
+{-# INLINE flat #-}
chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m ()
chain = loop where
@@ -136,6 +161,7 @@ chain = loop where
next <- lift (MWC.sample (execStateT flat state) prng)
yield next
loop next prng
+{-# INLINE chain #-}
-- | Trace 'n' iterations of a Markov chain and stream them to stdout.
--
@@ -169,9 +195,10 @@ mcmc :: Int -> Ensemble -> (Particle -> Double) -> Gen RealWorld -> IO ()
mcmc n chainPosition target gen = runEffect $
chain Chain {..} gen
>-> Pipes.take n
- >-> Pipes.mapM_ print
+ >-> Pipes.mapM_ (T.putStrLn . render)
where
chainTarget = Target target Nothing
+{-# INLINE mcmc #-}
-- A parallel map with the specified granularity.
parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
@@ -180,4 +207,5 @@ parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where
chunk m ys =
let (as, bs) = splitAt m ys
in as : chunk m bs
+{-# INLINE parMapChunk #-}