mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit 56a805b2ee6169a948d53e3cad0009c1de9baf5f
parent 82d6854e02c073cb2f9945bddff1e343a66f662f
Author: Boarders <callan.mcgill@gmail.com>
Date:   Tue, 23 Jul 2019 12:17:38 -0400

Allow multinomial to accept a weighted probability distribution as argument

Diffstat:
Msrc/System/Random/MWC/Probability.hs | 23+++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -340,15 +340,21 @@ negativeBinomial n p = do -- | The multinomial distribution of `n` trials and category probabilities -- `ps`. -- --- Note that `ps` is a vector of probabilities and should sum to one. +-- Note that the supplied probability container should consist of non-negative +-- values but is not required to sum to one. multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int] multinomial n ps = do - let cumulative = scanl1 (+) (F.toList ps) - replicateM n $ do - z <- uniform - case findIndex (> z) cumulative of - Just g -> return g - Nothing -> error "mwc-probability: invalid probability vector" + let (cumulative, total) = runningTotals (F.toList ps) + replicateM n $ do + z <- uniformR (0, total) + case findIndex (> z) cumulative of + Just g -> return g + Nothing -> error "mwc-probability: invalid probability vector" + where + -- Note: this is significantly faster than any + -- of the recursions one might write by hand. + runningTotals :: Num a => [a] -> ([a], a) + runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs) {-# INLINABLE multinomial #-} -- | Generalized Student's t distribution with location parameter `m`, scale @@ -404,7 +410,8 @@ poisson l = Prob $ genFromTable table where -- | A categorical distribution defined by the supplied probabilities. -- --- Note that the supplied container of probabilities must sum to 1. +-- Note that the supplied probability container should consist of non-negative +-- values but is not required to sum to one. categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int categorical ps = do xs <- multinomial 1 ps