commit 6a3bcf6bbd457fce731445aa4e885b5d78064527
parent 58e929aca44b596a04142d59369d80256f8c159e
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 2 May 2020 17:23:43 +0400
Add discrete distribution.
Resolves #18.
Diffstat:
1 file changed, 15 insertions(+), 0 deletions(-)
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -83,6 +83,7 @@ module System.Random.MWC.Probability (
, discreteUniform
, zipf
, categorical
+ , discrete
, bernoulli
, binomial
, negativeBinomial
@@ -426,6 +427,20 @@ categorical ps = do
_ -> error "mwc-probability: invalid probability vector"
{-# INLINABLE categorical #-}
+-- | A categorical distribution defined by the supplied support.
+--
+-- Note that the supplied probabilities should be non-negative, but are not
+-- required to sum to one.
+--
+-- >>> samples 10 (discrete [(0.1, "yeah"), (0.9, "nah")]) gen
+-- ["yeah","nah","nah","nah","nah","yeah","nah","nah","nah","nah"]
+discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
+discrete d = do
+ let (ps, xs) = unzip (F.toList d)
+ idx <- categorical ps
+ pure (xs !! idx)
+{-# INLINABLE discrete #-}
+
-- | The Zipf-Mandelbrot distribution.
--
-- Note that `a` should be positive, and that values close to 1 should be