commit c7d11968abdf85bcc01be17aaf2902ee19dd5e10
parent e76a1a8aae4cee5261c96289227bfef48ffdaec5
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 6 Oct 2015 21:08:26 +1300
Add examples.
Diffstat:
6 files changed, 32 insertions(+), 30 deletions(-)
diff --git a/LICENSE b/LICENSE
@@ -1,5 +1,3 @@
-The MIT License (MIT)
-
Copyright (c) 2012-2015 Jared Tobin
Permission is hereby granted, free of charge, to any person obtaining a copy
diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs
@@ -1,33 +1,29 @@
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
-module Numeric.MCMC.Metropolis (mcmc, metropolis) where
+module Numeric.MCMC.Metropolis (
+ mcmc
+ , metropolis
+
+ -- * re-export
+ , module Data.Sampling.Types
+ , MWC.create
+ , MWC.createSystemRandom
+ , MWC.withSystemRandom
+ , MWC.asGenIO
+ ) where
import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.State.Strict (StateT, execStateT, get, put)
+import Control.Monad.Trans.State.Strict (execStateT, get, put)
+import Data.Sampling.Types (Target(..), Chain(..), Transition)
import GHC.Prim (RealWorld)
import Pipes (Producer, yield, (>->), runEffect)
import qualified Pipes.Prelude as Pipes (mapM_, take)
import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
--- | A transition operator.
-type Transition m a = StateT a (Prob m) ()
-
--- | The @Chain@ type specifies the state of a Markov chain at any given
--- iteration.
-data Chain a b = Chain {
- chainTarget :: a -> Double
- , chainScore :: !Double
- , chainPosition :: a
- , chainTunables :: Maybe b
- }
-
-instance Show a => Show (Chain a b) where
- show Chain {..} = filter (`notElem` "fromList []") (show chainPosition)
-
-- | Propose a state transition according to a Gaussian proposal distribution
-- with the specified standard deviation.
propose
@@ -46,7 +42,7 @@ metropolis
metropolis radial = do
Chain {..} <- get
proposal <- lift (propose radial chainPosition)
- let proposalScore = chainTarget proposal
+ let proposalScore = lTarget chainTarget proposal
acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
accept <- lift (MWC.bernoulli acceptProb)
@@ -74,13 +70,14 @@ mcmc
-> (f Double -> Double)
-> Gen RealWorld
-> IO ()
-mcmc n radial chainPosition chainTarget gen = runEffect $
+mcmc n radial chainPosition target gen = runEffect $
chain radial Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ print
where
- chainScore = chainTarget chainPosition
+ chainScore = lTarget chainTarget chainPosition
chainTunables = Nothing
+ chainTarget = Target target Nothing
-- | Use a provided default value when the argument is NaN.
whenNaN :: RealFloat a => a -> a -> a
diff --git a/mighty-metropolis.cabal b/mighty-metropolis.cabal
@@ -25,8 +25,8 @@ library
base
, pipes
, primitive
+ , mcmc-types
, mwc-probability
- , vector
Test-suite rosenbrock
type: exitcode-stdio-1.0
@@ -38,5 +38,4 @@ Test-suite rosenbrock
build-depends:
base
, mwc-probability
- , vector
diff --git a/stack.yaml b/stack.yaml
@@ -1,5 +1,6 @@
flags: {}
packages:
- ../mwc-probability
+ - ../mcmc-types
extra-deps: []
resolver: lts-3.3
diff --git a/test/BNN.hs b/test/BNN.hs
@@ -0,0 +1,11 @@
+
+module Main where
+
+import Numeric.MCMC.Metropolis
+
+bnn :: [Double] -> Double
+bnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1)
+
+main :: IO ()
+main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] bnn
+
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -2,14 +2,10 @@
module Main where
import Numeric.MCMC.Metropolis
-import qualified System.Random.MWC.Probability as MWC
rosenbrock :: [Double] -> Double
-rosenbrock xs = (-1)*(5*(x1 - x0^2)^2 + 0.05*(1 - x0)^2) where
- x0 = head xs
- x1 = xs !! 1
+rosenbrock [x0, x1] = negate (5 *(x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
main :: IO ()
-main = MWC.withSystemRandom . MWC.asGenIO $
- mcmc 10000 1 [0, 0] rosenbrock
+main = withSystemRandom . asGenIO $ mcmc 10000 1 [0, 0] rosenbrock