mighty-metropolis

The classic Metropolis algorithm.
Log | Files | Refs | README | LICENSE

commit c7d11968abdf85bcc01be17aaf2902ee19dd5e10
parent e76a1a8aae4cee5261c96289227bfef48ffdaec5
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue,  6 Oct 2015 21:08:26 +1300

Add examples.

Diffstat:
MLICENSE | 2--
MNumeric/MCMC/Metropolis.hs | 37+++++++++++++++++--------------------
Mmighty-metropolis.cabal | 3+--
Mstack.yaml | 1+
Atest/BNN.hs | 11+++++++++++
Mtest/Rosenbrock.hs | 8++------
6 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/LICENSE b/LICENSE @@ -1,5 +1,3 @@ -The MIT License (MIT) - Copyright (c) 2012-2015 Jared Tobin Permission is hereby granted, free of charge, to any person obtaining a copy diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs @@ -1,33 +1,29 @@ {-# OPTIONS_GHC -Wall #-} {-# LANGUAGE RecordWildCards #-} -module Numeric.MCMC.Metropolis (mcmc, metropolis) where +module Numeric.MCMC.Metropolis ( + mcmc + , metropolis + + -- * re-export + , module Data.Sampling.Types + , MWC.create + , MWC.createSystemRandom + , MWC.withSystemRandom + , MWC.asGenIO + ) where import Control.Monad (when) import Control.Monad.Primitive (PrimMonad, PrimState) import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.State.Strict (StateT, execStateT, get, put) +import Control.Monad.Trans.State.Strict (execStateT, get, put) +import Data.Sampling.Types (Target(..), Chain(..), Transition) import GHC.Prim (RealWorld) import Pipes (Producer, yield, (>->), runEffect) import qualified Pipes.Prelude as Pipes (mapM_, take) import System.Random.MWC.Probability (Gen, Prob) import qualified System.Random.MWC.Probability as MWC --- | A transition operator. -type Transition m a = StateT a (Prob m) () - --- | The @Chain@ type specifies the state of a Markov chain at any given --- iteration. -data Chain a b = Chain { - chainTarget :: a -> Double - , chainScore :: !Double - , chainPosition :: a - , chainTunables :: Maybe b - } - -instance Show a => Show (Chain a b) where - show Chain {..} = filter (`notElem` "fromList []") (show chainPosition) - -- | Propose a state transition according to a Gaussian proposal distribution -- with the specified standard deviation. propose @@ -46,7 +42,7 @@ metropolis metropolis radial = do Chain {..} <- get proposal <- lift (propose radial chainPosition) - let proposalScore = chainTarget proposal + let proposalScore = lTarget chainTarget proposal acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore))) accept <- lift (MWC.bernoulli acceptProb) @@ -74,13 +70,14 @@ mcmc -> (f Double -> Double) -> Gen RealWorld -> IO () -mcmc n radial chainPosition chainTarget gen = runEffect $ +mcmc n radial chainPosition target gen = runEffect $ chain radial Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ print where - chainScore = chainTarget chainPosition + chainScore = lTarget chainTarget chainPosition chainTunables = Nothing + chainTarget = Target target Nothing -- | Use a provided default value when the argument is NaN. whenNaN :: RealFloat a => a -> a -> a diff --git a/mighty-metropolis.cabal b/mighty-metropolis.cabal @@ -25,8 +25,8 @@ library base , pipes , primitive + , mcmc-types , mwc-probability - , vector Test-suite rosenbrock type: exitcode-stdio-1.0 @@ -38,5 +38,4 @@ Test-suite rosenbrock build-depends: base , mwc-probability - , vector diff --git a/stack.yaml b/stack.yaml @@ -1,5 +1,6 @@ flags: {} packages: - ../mwc-probability + - ../mcmc-types extra-deps: [] resolver: lts-3.3 diff --git a/test/BNN.hs b/test/BNN.hs @@ -0,0 +1,11 @@ + +module Main where + +import Numeric.MCMC.Metropolis + +bnn :: [Double] -> Double +bnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) + +main :: IO () +main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] bnn + diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs @@ -2,14 +2,10 @@ module Main where import Numeric.MCMC.Metropolis -import qualified System.Random.MWC.Probability as MWC rosenbrock :: [Double] -> Double -rosenbrock xs = (-1)*(5*(x1 - x0^2)^2 + 0.05*(1 - x0)^2) where - x0 = head xs - x1 = xs !! 1 +rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) main :: IO () -main = MWC.withSystemRandom . MWC.asGenIO $ - mcmc 10000 1 [0, 0] rosenbrock +main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] rosenbrock