commit fc999b22c23d28f841ae8bee26a01d077fb52d74
parent f820f3347723d2fca5266a49d446a65db656f5bd
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 30 Mar 2016 17:53:16 +0700
Misc updates.
Diffstat:
3 files changed, 86 insertions(+), 120 deletions(-)
diff --git a/flat-mcmc.cabal b/flat-mcmc.cabal
@@ -24,11 +24,12 @@ library
exposed-modules: Numeric.MCMC.Flat
build-depends:
base < 5
- , lens >= 4 && < 5
- , mcmc-types >= 1.0.1
+ , mcmc-types >= 1.0.1 && < 2
, monad-par
- , mwc-probability >= 1.0.1
+ , monad-par-extras
+ , mwc-probability >= 1.0.1 && < 2
, pipes >= 4 && < 5
, primitive
, transformers
+ , vector
diff --git a/src/Numeric/MCMC/Flat.hs b/src/Numeric/MCMC/Flat.hs
@@ -1,143 +1,110 @@
-{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+{-# LANGUAGE RecordWildCards #-}
-module Numeric.MCMC.Flat (
- Chain(..)
- , Options(..)
- , Ensemble
- , Density
- , flat
- , flatGranular
- -- * System.Random.MWC
- , sample
- , withSystemRandom
- , asGenIO
- , asGenST
- , create
- , initialize
- ) where
-
-import Control.Applicative
-import Control.Arrow
-import Control.Monad
+module Numeric.MCMC.Flat (mcmc) where
+
+import Control.Monad (replicateM)
import Control.Monad.Par (NFData)
import Control.Monad.Par.Scheds.Direct hiding (put, get)
-import Control.Monad.Par.Combinator
-import Control.Monad.Primitive
-import Control.Monad.State.Strict
+import Control.Monad.Par.Combinator (parMap)
+import Control.Monad.Primitive (PrimMonad, PrimState, RealWorld)
+import Control.Monad.Trans.State.Strict (get, put, execStateT)
+import Data.Sampling.Types hiding (Chain(..))
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
-import System.Random.MWC.Probability
+import Pipes (Producer, lift, yield, runEffect, (>->))
+import qualified Pipes.Prelude as Pipes
+import System.Random.MWC.Probability as MWC
--- | The state of a Markov chain.
---
--- The target function - assumed to be proportional to a log density - is
--- itself stateful, allowing for custom annealing schedules if you know what
--- you're doing.
data Chain = Chain {
- logObjective :: Density
- , ensemble :: Ensemble
- , accepts :: !Int
- , iterations :: !Int
+ chainTarget :: Target Particle
+ , chainPosition :: !Ensemble
}
instance Show Chain where
- show c = show . unlines . map (sanitize . show) $ us where
- us = V.toList e
- e = ensemble c
- sanitize = filter (`notElem` "fromList []")
-
--- | Parallelism granularity.
-data Options = Options {
- granularity :: !Int
- } deriving (Eq, Show)
-
--- | An ensemble of particles. A Markov chain is defined over the entire
--- ensemble, rather than individual particles.
-type Ensemble = Vector Particle
-
-type Particle = U.Vector Double
-
-type Density = Particle -> Double
-
--- | The flat-mcmc transition operator. Run a Markov chain with it by providing
--- an initial location (origin), a generator (gen), and using the usual
--- facilities from 'Control.Monad.State':
---
--- > let chain = replicateM 5000 flat `evalStateT` origin
--- > trace <- sample chain gen
---
-flat :: PrimMonad m => StateT Chain (Prob m) Chain
-flat = flatGranular $ Options 1
-
--- | The flat-mcmc transition operator with custom parallelism granularity.
-flatGranular :: PrimMonad m => Options -> StateT Chain (Prob m) Chain
-flatGranular (Options gran) = do
- Chain target e nAccept epochs <- get
- let n = truncate (fromIntegral (V.length e) / 2)
- (e0, e1) = (V.slice 0 n &&& V.slice n n) e
-
- (e2, nAccept0) <- lift $ executeMoves target gran e0 e1
- (e3, nAccept1) <- lift $ executeMoves target gran e1 e2
+ show Chain {..} = show chainPosition -- FIXME better?
- put $! Chain {
- logObjective = target
- , ensemble = V.concat [e2, e3]
- , accepts = nAccept + nAccept0 + nAccept1
- , iterations = succ epochs
- }
+type Particle = Vector Double
- get
-
-parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
-parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where
- chunk _ [] = []
- chunk m ys =
- let (as, bs) = splitAt m ys
- in as : chunk m bs
+type Ensemble = Vector Particle
symmetric :: PrimMonad m => Prob m Double
-symmetric = transform <$> uniform where
+symmetric = fmap transform uniform where
transform z = 0.5 * (z + 1) ^ (2 :: Int)
stretch :: Particle -> Particle -> Double -> Particle
-stretch particle altParticle z =
- U.zipWith (+) (U.map (* z) particle) (U.map (* (1 - z)) altParticle)
+stretch p0 p1 z = V.zipWith (+) (V.map (* z) p0) (V.map (* (1 - z)) p1)
-acceptProb :: Density -> Particle -> Particle -> Double -> Double
+acceptProb :: Target Particle -> Particle -> Particle -> Double -> Double
acceptProb target particle proposal z =
- target proposal
- - target particle
- + log z * (fromIntegral (U.length particle) - 1)
-
-move :: Density -> Particle -> Particle -> Double -> Double -> (Particle, Int)
-move target particle altParticle z zc =
- let proposal = stretch particle altParticle z
- pAccept = acceptProb target particle proposal z
+ lTarget target proposal
+ - lTarget target particle
+ + log z * (fromIntegral (V.length particle) - 1)
+
+move :: Target Particle -> Particle -> Particle -> Double -> Double -> Particle
+move target p0 p1 z zc =
+ let proposal = stretch p0 p1 z
+ pAccept = acceptProb target p0 proposal z
in if zc <= min 0 pAccept
- then (proposal, 1) -- move and count moves made
- else (particle, 0)
+ then proposal
+ else p0
-executeMoves
+execute
:: PrimMonad m
- => Density
- -> Int
+ => Target Particle
-> Ensemble
-> Ensemble
- -> Prob m (Ensemble, Int)
-executeMoves target gran e0 e1 = do
- let n = truncate $ fromIntegral (V.length e0 + V.length e1) / 2
- zs <- replicateM n symmetric
- zcs <- replicateM n $ log <$> uniform
- others <- replicateM n $ uniformR (0, n - 1)
+ -> Int
+ -> Prob m Ensemble
+execute target e0 e1 n = do
+ zs <- replicateM n symmetric
+ zcs <- replicateM n uniform
+ vjs <- replicateM n uniform
- let particle j = e0 `V.unsafeIndex` j
- altParticle j = e1 `V.unsafeIndex` (others !! j)
+ let js = U.fromList vjs
+ w0 k = e0 `V.unsafeIndex` pred k
+ w1 k ks = e1 `V.unsafeIndex` pred (ks `U.unsafeIndex` pred k)
- moves = runPar $ parMapChunk gran
- (\(j, z, zc) -> move target (particle j) (altParticle j) z zc)
- (zip3 [0..n - 1] zs zcs)
+ worker (k, z, zc) = move target (w0 k) (w1 k js) z zc
+ result = runPar $
+ parMapChunk 2 worker (zip3 [1..n] zs zcs) -- FIXME granularity option
+
+ return $ V.fromList result
+
+flat
+ :: PrimMonad m
+ => Transition m Chain
+flat = do
+ Chain {..} <- get
+ let size = V.length chainPosition
+ n = truncate (fromIntegral size / 2)
+ e0 = V.slice 0 n chainPosition
+ e1 = V.slice n n chainPosition
+ result0 <- lift (execute chainTarget e0 e1 n)
+ result1 <- lift (execute chainTarget e1 result0 n)
+ let ensemble = V.concat [result0, result1]
+ put (Chain chainTarget ensemble)
+
+chain :: PrimMonad m => Chain -> Gen (PrimState m) -> Producer Chain m ()
+chain = loop where
+ loop state prng = do
+ next <- lift (MWC.sample (execStateT flat state) prng)
+ yield next
+ loop next prng
+
+mcmc :: Int -> Ensemble -> (Particle -> Double) -> Gen RealWorld -> IO ()
+mcmc n chainPosition target gen = runEffect $
+ chain Chain {..} gen
+ >-> Pipes.take n
+ >-> Pipes.mapM_ print
+ where
+ chainTarget = Target target Nothing
- return $! (V.fromList . map fst &&& sum . map snd) moves
+parMapChunk :: NFData b => Int -> (a -> b) -> [a] -> Par [b]
+parMapChunk n f xs = concat <$> parMap (map f) (chunk n xs) where
+ chunk _ [] = []
+ chunk m ys =
+ let (as, bs) = splitAt m ys
+ in as : chunk m bs
diff --git a/stack.yaml b/stack.yaml
@@ -1,7 +1,5 @@
flags: {}
packages:
- '.'
-extra-deps:
- - mwc-probability-1.0.2
- - mcmc-types-1.0.1
-resolver: lts-3.8
+extra-deps: []
+resolver: lts-5.2