mighty-metropolis

The classic Metropolis algorithm.
Log | Files | Refs | README | LICENSE

commit e76a1a8aae4cee5261c96289227bfef48ffdaec5
parent 13d1a785c773373c32ca01f11e27e210b4c6a401
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue,  6 Oct 2015 15:48:39 +1300

Generalize to any Foldable.

Diffstat:
MNumeric/MCMC/Metropolis.hs | 46+++++++++++++++++++++++-----------------------
Mtest/Rosenbrock.hs | 12++++++------
2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs @@ -1,6 +1,5 @@ {-# OPTIONS_GHC -Wall #-} {-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ViewPatterns #-} module Numeric.MCMC.Metropolis (mcmc, metropolis) where @@ -8,8 +7,6 @@ import Control.Monad (when) import Control.Monad.Primitive (PrimMonad, PrimState) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict (StateT, execStateT, get, put) -import Data.Vector.Unboxed (Vector) -import qualified Data.Vector.Unboxed as Vector (mapM, fromList) import GHC.Prim (RealWorld) import Pipes (Producer, yield, (>->), runEffect) import qualified Pipes.Prelude as Pipes (mapM_, take) @@ -17,34 +14,35 @@ import System.Random.MWC.Probability (Gen, Prob) import qualified System.Random.MWC.Probability as MWC -- | A transition operator. -type Transition m = StateT Chain (Prob m) () +type Transition m a = StateT a (Prob m) () -- | The @Chain@ type specifies the state of a Markov chain at any given -- iteration. -data Chain = Chain { - chainTarget :: Vector Double -> Double +data Chain a b = Chain { + chainTarget :: a -> Double , chainScore :: !Double - , chainPosition :: !(Vector Double) + , chainPosition :: a + , chainTunables :: Maybe b } -instance Show Chain where +instance Show a => Show (Chain a b) where show Chain {..} = filter (`notElem` "fromList []") (show chainPosition) -- | Propose a state transition according to a Gaussian proposal distribution -- with the specified standard deviation. propose - :: PrimMonad m + :: (PrimMonad m, Traversable f) => Double - -> Vector Double - -> Prob m (Vector Double) -propose radial = Vector.mapM perturb where + -> f Double + -> Prob m (f Double) +propose radial = traverse perturb where perturb m = MWC.normal m radial -- | A Metropolis transition operator. metropolis - :: PrimMonad m + :: (Traversable f, PrimMonad m) => Double - -> Transition m + -> Transition m (Chain (f Double) b) metropolis radial = do Chain {..} <- get proposal <- lift (propose radial chainPosition) @@ -52,15 +50,15 @@ metropolis radial = do acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore))) accept <- lift (MWC.bernoulli acceptProb) - when accept (put (Chain chainTarget proposalScore proposal)) + when accept (put (Chain chainTarget proposalScore proposal chainTunables)) -- | A Markov chain. chain - :: PrimMonad m + :: (Traversable f, PrimMonad m) => Double - -> Chain + -> Chain (f Double) b -> Gen (PrimState m) - -> Producer Chain m () + -> Producer (Chain (f Double) b) m () chain radial = loop where loop state prng = do next <- lift (MWC.sample (execStateT (metropolis radial) state) prng) @@ -69,18 +67,20 @@ chain radial = loop where -- | Trace 'n' iterations of a Markov chain. mcmc - :: Int + :: (Traversable f, Show (f Double)) + => Int -> Double - -> [Double] - -> (Vector Double -> Double) + -> f Double + -> (f Double -> Double) -> Gen RealWorld -> IO () -mcmc n radial (Vector.fromList -> chainPosition) chainTarget gen = runEffect $ +mcmc n radial chainPosition chainTarget gen = runEffect $ chain radial Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ print where - chainScore = chainTarget chainPosition + chainScore = chainTarget chainPosition + chainTunables = Nothing -- | Use a provided default value when the argument is NaN. whenNaN :: RealFloat a => a -> a -> a diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs @@ -1,15 +1,15 @@ -module Rosenbrock where +module Main where import Numeric.MCMC.Metropolis import qualified System.Random.MWC.Probability as MWC -import Data.Vector.Unboxed (Vector, unsafeIndex) -rosenbrock :: Vector Double -> Double +rosenbrock :: [Double] -> Double rosenbrock xs = (-1)*(5*(x1 - x0^2)^2 + 0.05*(1 - x0)^2) where - x0 = unsafeIndex xs 0 - x1 = unsafeIndex xs 1 + x0 = head xs + x1 = xs !! 1 main :: IO () -main = MWC.withSystemRandom . MWC.asGenIO $ mcmc 100000 1 [0, 0] rosenbrock +main = MWC.withSystemRandom . MWC.asGenIO $ + mcmc 10000 1 [0, 0] rosenbrock