mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit 48906cfeb548ea85022f456237b01c356fdc93fb
parent 1430771f0359a7f1be709069f0ff6a3e06997042
Author: Marco Zocca <ocramz>
Date:   Tue, 21 Jan 2020 10:30:39 +0100

PY -> crp

Diffstat:
Msrc/System/Random/MWC/Probability.hs | 77+++++++++++++++++++++++++++++++++++++----------------------------------------
1 file changed, 37 insertions(+), 40 deletions(-)

diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -87,7 +87,7 @@ module System.Random.MWC.Probability ( , negativeBinomial , multinomial , poisson - , pitmanYor + , chineseRestaurantProcess ) where import Control.Applicative @@ -427,14 +427,14 @@ categorical ps = do -- | The Zipf-Mandelbrot distribution. -- --- Note that `a` should be positive, and that values close to 1 should be --- avoided as they are very computationally intensive. +-- Note that `a` should be positive, and that values close to 1 should be +-- avoided as they are very computationally intensive. -- --- >>> samples 10 (zipf 1.1) gen --- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50] +-- >>> samples 10 (zipf 1.1) gen +-- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50] -- --- >>> samples 10 (zipf 1.5) gen --- [19,3,3,1,1,2,1,191,2,1] +-- >>> samples 10 (zipf 1.5) gen +-- [19,3,3,1,1,2,1,191,2,1] zipf :: (PrimMonad m, Integral b) => Double -> Prob m b zipf a = do let @@ -452,41 +452,42 @@ zipf a = do {-# INLINABLE zipf #-} --- * Chinese Restaurant process --- | Pitman-Yor process +-- | Chinese Restaurant Process -- --- This implementation is given in terms of the Chinese Restaurant process -pitmanYor :: (PrimMonad f) => - Double -- ^ a \in [0, 1] - -> Double -- ^ b > 0 - -> Int -- ^ Total number of customers entering the Chinese Restaurant - -> Prob f [Integer] -pitmanYor a b n = do - ts <- go crpInitial (n - 1) - pure $ map getSum $ customers ts +-- Implementation based on Griffiths, Ghahramani 2011 +-- +-- >>> sample (chineseRestaurantProcess 1.8 50) gen +-- [22,10,7,1,2,2,4,1,1] +chineseRestaurantProcess :: PrimMonad m => + Double -- ^ Concentration parameter (> 1) + -> Int -- ^ Total number of customers + -> Prob m [Integer] +chineseRestaurantProcess a n = do + ts <- go crpInitial 1 + pure $ map getSum $ crpCustomers ts where - go acc 0 = pure acc - go acc i = do - acc' <- pitmanYorSingle a b acc - go acc' (i - 1) - -pitmanYorSingle :: (PrimMonad m, Integral a) => - Double -- a \in [0, 1] - -> Double -- b > 0 - -> CRPTables (Sum a) - -> Prob m (CRPTables (Sum a)) -pitmanYorSingle a b zs = do + go acc i + | i == n = pure acc + | otherwise = do + acc' <- crpSingle i acc a + go acc' (i + 1) + +-- | Update step of the CRP +crpSingle :: (PrimMonad m, Integral b) => + Int + -> CRPTables (Sum b) + -> Double + -> Prob m (CRPTables (Sum b)) +crpSingle i zs a = do zn1 <- categorical probs pure $ crpInsert zn1 zs where - m = fromIntegral $ uniques zs - n = fromIntegral $ sum (getSum <$> zs) - d = n + b probs = pms <> [pm1] - pm1 = (m * a + b) / d - pms = map (\x -> (fromIntegral (getSum x) - a) / d) $ customers zs + mks = getSum <$> crpCustomers zs -- # of customers sitting at each table + pms = map (\m -> fromIntegral m / (fromIntegral i - 1 + a)) mks + pm1 = a / (fromIntegral i - 1 + a) -- | Tables at the Chinese Restaurant newtype CRPTables c = CRP { @@ -500,11 +501,7 @@ crpInitial = crpInsert 0 $ CRP IM.empty crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a) crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts --- | Number of tables -uniques :: CRPTables a -> Int -uniques (CRP ts) = length ts - -- | Number of customers sitting at each table -customers :: CRPTables c -> [c] -customers = map snd . IM.toList . getCRPTables +crpCustomers :: CRPTables c -> [c] +crpCustomers = map snd . IM.toList . getCRPTables