commit 48906cfeb548ea85022f456237b01c356fdc93fb
parent 1430771f0359a7f1be709069f0ff6a3e06997042
Author: Marco Zocca <ocramz>
Date: Tue, 21 Jan 2020 10:30:39 +0100
PY -> crp
Diffstat:
1 file changed, 37 insertions(+), 40 deletions(-)
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -87,7 +87,7 @@ module System.Random.MWC.Probability (
, negativeBinomial
, multinomial
, poisson
- , pitmanYor
+ , chineseRestaurantProcess
) where
import Control.Applicative
@@ -427,14 +427,14 @@ categorical ps = do
-- | The Zipf-Mandelbrot distribution.
--
--- Note that `a` should be positive, and that values close to 1 should be
--- avoided as they are very computationally intensive.
+-- Note that `a` should be positive, and that values close to 1 should be
+-- avoided as they are very computationally intensive.
--
--- >>> samples 10 (zipf 1.1) gen
--- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
+-- >>> samples 10 (zipf 1.1) gen
+-- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
--
--- >>> samples 10 (zipf 1.5) gen
--- [19,3,3,1,1,2,1,191,2,1]
+-- >>> samples 10 (zipf 1.5) gen
+-- [19,3,3,1,1,2,1,191,2,1]
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a = do
let
@@ -452,41 +452,42 @@ zipf a = do
{-# INLINABLE zipf #-}
--- * Chinese Restaurant process
--- | Pitman-Yor process
+-- | Chinese Restaurant Process
--
--- This implementation is given in terms of the Chinese Restaurant process
-pitmanYor :: (PrimMonad f) =>
- Double -- ^ a \in [0, 1]
- -> Double -- ^ b > 0
- -> Int -- ^ Total number of customers entering the Chinese Restaurant
- -> Prob f [Integer]
-pitmanYor a b n = do
- ts <- go crpInitial (n - 1)
- pure $ map getSum $ customers ts
+-- Implementation based on Griffiths, Ghahramani 2011
+--
+-- >>> sample (chineseRestaurantProcess 1.8 50) gen
+-- [22,10,7,1,2,2,4,1,1]
+chineseRestaurantProcess :: PrimMonad m =>
+     Double -- ^ Concentration parameter (> 0)
+  -> Int -- ^ Total number of customers; values below 1 are treated as 1
+  -> Prob m [Integer]
+chineseRestaurantProcess a n = do
+  ts <- go crpInitial 1 -- 'crpInitial' already seats the first customer
+  pure $ map getSum $ crpCustomers ts
where
- go acc 0 = pure acc
- go acc i = do
- acc' <- pitmanYorSingle a b acc
- go acc' (i - 1)
-
-pitmanYorSingle :: (PrimMonad m, Integral a) =>
- Double -- a \in [0, 1]
- -> Double -- b > 0
- -> CRPTables (Sum a)
- -> Prob m (CRPTables (Sum a))
-pitmanYorSingle a b zs = do
+    go acc i -- i = number of customers seated so far
+      | i >= n = pure acc -- '>=' rather than '==' so that n < 1 terminates
+      | otherwise = do
+          acc' <- crpSingle i acc a
+          go acc' (i + 1)
+
+-- | Update step of the CRP: draws a table index and adds one customer to it
+crpSingle :: (PrimMonad m, Integral b) =>
+     Int -- ^ Number of customers already seated (as counted by the caller)
+  -> CRPTables (Sum b) -- ^ Current table occupancy
+  -> Double -- ^ Concentration parameter
+  -> Prob m (CRPTables (Sum b))
+crpSingle i zs a = do
zn1 <- categorical probs
pure $ crpInsert zn1 zs
where
- m = fromIntegral $ uniques zs
- n = fromIntegral $ sum (getSum <$> zs)
- d = n + b
probs = pms <> [pm1]
- pm1 = (m * a + b) / d
- pms = map (\x -> (fromIntegral (getSum x) - a) / d) $ customers zs
+    mks = getSum <$> crpCustomers zs -- # of customers sitting at each table
+    pms = map (\m -> fromIntegral m / (fromIntegral i + a)) mks -- i seated, so weights sum to 1
+    pm1 = a / (fromIntegral i + a) -- weight of opening a new table
-- | Tables at the Chinese Restaurant
newtype CRPTables c = CRP {
@@ -500,11 +501,7 @@ crpInitial = crpInsert 0 $ CRP IM.empty
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts
--- | Number of tables
-uniques :: CRPTables a -> Int
-uniques (CRP ts) = length ts
-
-- | Number of customers sitting at each table
-customers :: CRPTables c -> [c]
-customers = map snd . IM.toList . getCRPTables
+crpCustomers :: CRPTables c -> [c]
+crpCustomers = IM.elems . getCRPTables -- occupancies in ascending table-key order