commit 9a0f99cfe7afbc6341cd1095ea54b38336b64fec
parent 82d6854e02c073cb2f9945bddff1e343a66f662f
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 23 Jul 2019 14:23:54 -0230
Merge pull request #16 from Boarders/weighted-probability
Allow weighted probability distribution
Diffstat:
1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -340,15 +340,21 @@ negativeBinomial n p = do
-- | The multinomial distribution of `n` trials and category probabilities
-- `ps`.
--
--- Note that `ps` is a vector of probabilities and should sum to one.
+-- Note that the supplied probability container should consist of non-negative
+-- values but is not required to sum to one.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = do
- let cumulative = scanl1 (+) (F.toList ps)
- replicateM n $ do
- z <- uniform
- case findIndex (> z) cumulative of
- Just g -> return g
- Nothing -> error "mwc-probability: invalid probability vector"
+ let (cumulative, total) = runningTotals (F.toList ps)
+ replicateM n $ do
+ z <- uniformR (0, total)
+ case findIndex (> z) cumulative of
+ Just g -> return g
+ Nothing -> error "mwc-probability: invalid probability vector"
+ where
+ -- Note: this is significantly faster than any
+ -- of the recursions one might write by hand.
+ runningTotals :: Num a => [a] -> ([a], a)
+ runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
{-# INLINABLE multinomial #-}
-- | Generalized Student's t distribution with location parameter `m`, scale
@@ -404,7 +410,8 @@ poisson l = Prob $ genFromTable table where
-- | A categorical distribution defined by the supplied probabilities.
--
--- Note that the supplied container of probabilities must sum to 1.
+-- Note that the supplied probability container should consist of non-negative
+-- values but is not required to sum to one.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = do
xs <- multinomial 1 ps