commit e76a1a8aae4cee5261c96289227bfef48ffdaec5
parent 13d1a785c773373c32ca01f11e27e210b4c6a401
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 6 Oct 2015 15:48:39 +1300
Generalize to any Foldable.
Diffstat:
2 files changed, 29 insertions(+), 29 deletions(-)
diff --git a/Numeric/MCMC/Metropolis.hs b/Numeric/MCMC/Metropolis.hs
@@ -1,6 +1,5 @@
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE RecordWildCards #-}
-{-# LANGUAGE ViewPatterns #-}
module Numeric.MCMC.Metropolis (mcmc, metropolis) where
@@ -8,8 +7,6 @@ import Control.Monad (when)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, execStateT, get, put)
-import Data.Vector.Unboxed (Vector)
-import qualified Data.Vector.Unboxed as Vector (mapM, fromList)
import GHC.Prim (RealWorld)
import Pipes (Producer, yield, (>->), runEffect)
import qualified Pipes.Prelude as Pipes (mapM_, take)
@@ -17,34 +14,35 @@ import System.Random.MWC.Probability (Gen, Prob)
import qualified System.Random.MWC.Probability as MWC
-- | A transition operator.
-type Transition m = StateT Chain (Prob m) ()
+type Transition m a = StateT a (Prob m) ()
-- | The @Chain@ type specifies the state of a Markov chain at any given
-- iteration.
-data Chain = Chain {
- chainTarget :: Vector Double -> Double
+data Chain a b = Chain {
+ chainTarget :: a -> Double
, chainScore :: !Double
- , chainPosition :: !(Vector Double)
+ , chainPosition :: a
+ , chainTunables :: Maybe b
}
-instance Show Chain where
+instance Show a => Show (Chain a b) where
show Chain {..} = filter (`notElem` "fromList []") (show chainPosition)
-- | Propose a state transition according to a Gaussian proposal distribution
-- with the specified standard deviation.
propose
- :: PrimMonad m
+ :: (PrimMonad m, Traversable f)
=> Double
- -> Vector Double
- -> Prob m (Vector Double)
-propose radial = Vector.mapM perturb where
+ -> f Double
+ -> Prob m (f Double)
+propose radial = traverse perturb where
perturb m = MWC.normal m radial
-- | A Metropolis transition operator.
metropolis
- :: PrimMonad m
+ :: (Traversable f, PrimMonad m)
=> Double
- -> Transition m
+ -> Transition m (Chain (f Double) b)
metropolis radial = do
Chain {..} <- get
proposal <- lift (propose radial chainPosition)
@@ -52,15 +50,15 @@ metropolis radial = do
acceptProb = whenNaN 0 (exp (min 0 (proposalScore - chainScore)))
accept <- lift (MWC.bernoulli acceptProb)
- when accept (put (Chain chainTarget proposalScore proposal))
+ when accept (put (Chain chainTarget proposalScore proposal chainTunables))
-- | A Markov chain.
chain
- :: PrimMonad m
+ :: (Traversable f, PrimMonad m)
=> Double
- -> Chain
+ -> Chain (f Double) b
-> Gen (PrimState m)
- -> Producer Chain m ()
+ -> Producer (Chain (f Double) b) m ()
chain radial = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT (metropolis radial) state) prng)
@@ -69,18 +67,20 @@ chain radial = loop where
-- | Trace 'n' iterations of a Markov chain.
mcmc
- :: Int
+ :: (Traversable f, Show (f Double))
+ => Int
-> Double
- -> [Double]
- -> (Vector Double -> Double)
+ -> f Double
+ -> (f Double -> Double)
-> Gen RealWorld
-> IO ()
-mcmc n radial (Vector.fromList -> chainPosition) chainTarget gen = runEffect $
+mcmc n radial chainPosition chainTarget gen = runEffect $
chain radial Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ print
where
- chainScore = chainTarget chainPosition
+ chainScore = chainTarget chainPosition
+ chainTunables = Nothing
-- | Use a provided default value when the argument is NaN.
whenNaN :: RealFloat a => a -> a -> a
diff --git a/test/Rosenbrock.hs b/test/Rosenbrock.hs
@@ -1,15 +1,15 @@
-module Rosenbrock where
+module Main where
import Numeric.MCMC.Metropolis
import qualified System.Random.MWC.Probability as MWC
-import Data.Vector.Unboxed (Vector, unsafeIndex)
-rosenbrock :: Vector Double -> Double
+rosenbrock :: [Double] -> Double
rosenbrock xs = (-1)*(5*(x1 - x0^2)^2 + 0.05*(1 - x0)^2) where
- x0 = unsafeIndex xs 0
- x1 = unsafeIndex xs 1
+ x0 = head xs
+ x1 = xs !! 1
main :: IO ()
-main = MWC.withSystemRandom . MWC.asGenIO $ mcmc 100000 1 [0, 0] rosenbrock
+main = MWC.withSystemRandom . MWC.asGenIO $
+ mcmc 10000 1 [0, 0] rosenbrock