**commit** 56a805b2ee6169a948d53e3cad0009c1de9baf5f
**parent** 82d6854e02c073cb2f9945bddff1e343a66f662f
**Author:** Boarders <callan.mcgill@gmail.com>
**Date:** Tue, 23 Jul 2019 12:17:38 -0400
Allow multinomial to accept a weighted probability distribution as argument
**Diffstat:**

1 file changed, 15 insertions(+), 8 deletions(-)

**diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs**
@@ -340,15 +340,21 @@ negativeBinomial n p = do
-- | The multinomial distribution of `n` trials and category probabilities
-- `ps`.
--
--- Note that `ps` is a vector of probabilities and should sum to one.
+-- Note that the supplied probability container should consist of non-negative
+-- values but is not required to sum to one.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = do
- let cumulative = scanl1 (+) (F.toList ps)
- replicateM n $ do
- z <- uniform
- case findIndex (> z) cumulative of
- Just g -> return g
- Nothing -> error "mwc-probability: invalid probability vector"
+ let (cumulative, total) = runningTotals (F.toList ps)
+ replicateM n $ do
+ z <- uniformR (0, total)
+ case findIndex (> z) cumulative of
+ Just g -> return g
+ Nothing -> error "mwc-probability: invalid probability vector"
+ where
+ -- Note: this is significantly faster than any
+ -- of the recursions one might write by hand.
+ runningTotals :: Num a => [a] -> ([a], a)
+ runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
{-# INLINABLE multinomial #-}
-- | Generalized Student's t distribution with location parameter `m`, scale
@@ -404,7 +410,8 @@ poisson l = Prob $ genFromTable table where
-- | A categorical distribution defined by the supplied probabilities.
--
--- Note that the supplied container of probabilities must sum to 1.
+-- Note that the supplied probability container should consist of non-negative
+-- values but is not required to sum to one.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = do
xs <- multinomial 1 ps