mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit 9a0f99cfe7afbc6341cd1095ea54b38336b64fec
parent 82d6854e02c073cb2f9945bddff1e343a66f662f
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue, 23 Jul 2019 14:23:54 -0230

Merge pull request #16 from Boarders/weighted-probability

Allow weighted probability distribution
Diffstat:
Msrc/System/Random/MWC/Probability.hs | 23+++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -340,15 +340,21 @@ negativeBinomial n p = do -- | The multinomial distribution of `n` trials and category probabilities -- `ps`. -- --- Note that `ps` is a vector of probabilities and should sum to one. +-- Note that the supplied probability container should consist of non-negative +-- values but is not required to sum to one. multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int] multinomial n ps = do - let cumulative = scanl1 (+) (F.toList ps) - replicateM n $ do - z <- uniform - case findIndex (> z) cumulative of - Just g -> return g - Nothing -> error "mwc-probability: invalid probability vector" + let (cumulative, total) = runningTotals (F.toList ps) + replicateM n $ do + z <- uniformR (0, total) + case findIndex (> z) cumulative of + Just g -> return g + Nothing -> error "mwc-probability: invalid probability vector" + where + -- Note: this is significantly faster than any + -- of the recursions one might write by hand. + runningTotals :: Num a => [a] -> ([a], a) + runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs) {-# INLINABLE multinomial #-} -- | Generalized Student's t distribution with location parameter `m`, scale @@ -404,7 +410,8 @@ poisson l = Prob $ genFromTable table where -- | A categorical distribution defined by the supplied probabilities. -- --- Note that the supplied container of probabilities must sum to 1. +-- Note that the supplied probability container should consist of non-negative +-- values but is not required to sum to one. categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int categorical ps = do xs <- multinomial 1 ps