commit 5fca576e31eb45333966ce0aaa581d9c28b95b1b
parent 2a63090a01cdd2c78b512183d07a152c9d8b19d4
Author: Jared Tobin <jared@jtobin.ca>
Date: Fri, 3 Apr 2015 20:21:27 +1000
Add multinomial, symmetric dirichlet samplers.
Diffstat:
2 files changed, 28 insertions(+), 13 deletions(-)
diff --git a/mwc-probability.cabal b/mwc-probability.cabal
@@ -1,5 +1,13 @@
name: mwc-probability
-version: 0.1.0.0
+version: 0.2.0.1
+homepage: http://github.com/jtobin/mwc-probability
+license: MIT
+license-file: LICENSE
+author: Jared Tobin
+maintainer: jared@jtobin.ca
+category: Math
+build-type: Simple
+cabal-version: >= 1.18
synopsis: Sampling function-based probability distributions.
description:
@@ -36,22 +44,13 @@ description:
> n <- uniformR (5, 10)
> binomial n p
-homepage: http://github.com/jtobin/mwc-probability
-license: MIT
-license-file: LICENSE
-author: Jared Tobin
-maintainer: jared@jtobin.ca
-category: Math
-build-type: Simple
-cabal-version: >=1.10
-
library
exposed-modules: System.Random.MWC.Probability
default-language: Haskell2010
hs-source-dirs: src
build-depends:
base >= 4.7 && < 4.8
- , mwc-random >= 0.13.3.0
- , primitive >= 0.5.4.0
- , transformers >= 0.4.2.0
+ , mwc-random
+ , primitive
+ , transformers
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -15,14 +15,17 @@ module System.Random.MWC.Probability (
, chiSquare
, beta
, dirichlet
+ , symmetricDirichlet
, bernoulli
, binomial
+ , multinomial
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans.Class
+import Data.List (findIndex)
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
@@ -106,6 +109,10 @@ dirichlet as = do
return $ map (/ sum zs) zs
{-# INLINE dirichlet #-}
+symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
+symmetricDirichlet n a = dirichlet (replicate n a)
+{-# INLINE symmetricDirichlet #-}
+
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (< p) <$> uniform
{-# INLINE bernoulli #-}
@@ -114,3 +121,12 @@ binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = liftM (length . filter id) $ replicateM n (bernoulli p)
{-# INLINE binomial #-}
+multinomial :: PrimMonad m => Int -> [Double] -> Prob m [Int]
+multinomial n ps = do
+ let cumulative = scanl1 (+) ps
+ replicateM n $ do
+ z <- uniform
+ let Just g = findIndex (> z) cumulative
+ return g
+{-# INLINE multinomial #-}
+