commit 23609d9b1d808d9009c5051d41a88ec002b7d3d0
parent 48906cfeb548ea85022f456237b01c356fdc93fb
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 29 Jan 2020 19:25:21 +0400
Tweak CRP implementation.
Fix style nits, s/chineseRestaurantProcess/crp, use more Foldable stuff.
Diffstat:
1 file changed, 28 insertions(+), 30 deletions(-)
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
@@ -87,7 +88,7 @@ module System.Random.MWC.Probability (
, negativeBinomial
, multinomial
, poisson
- , chineseRestaurantProcess
+ , crp
) where
import Control.Applicative
@@ -451,28 +452,28 @@ zipf a = do
go
{-# INLINABLE zipf #-}
-
-
-
--- | Chinese Restaurant Process
---
--- Implementation based on Griffiths, Ghahramani 2011
---
--- >>> sample (chineseRestaurantProcess 1.8 50) gen
--- [22,10,7,1,2,2,4,1,1]
-chineseRestaurantProcess :: PrimMonad m =>
- Double -- ^ Concentration parameter (> 1)
- -> Int -- ^ Total number of customers
- -> Prob m [Integer]
-chineseRestaurantProcess a n = do
- ts <- go crpInitial 1
- pure $ map getSum $ crpCustomers ts
+-- | The Chinese Restaurant Process with concentration parameter `a` and number
+-- of customers `n`.
+--
+-- See Griffiths and Ghahramani, 2011 for details.
+--
+-- >>> sample (crp 1.8 50) gen
+-- [22,10,7,1,2,2,4,1,1]
+crp
+ :: PrimMonad m
+ => Double -- ^ concentration parameter (> 1)
+ -> Int -- ^ number of customers
+ -> Prob m [Integer]
+crp a n = do
+ ts <- go crpInitial 1
+ pure $ F.toList (fmap getSum ts)
where
go acc i
| i == n = pure acc
| otherwise = do
acc' <- crpSingle i acc a
go acc' (i + 1)
+{-# INLINABLE crp #-}
-- | Update step of the CRP
crpSingle :: (PrimMonad m, Integral b) =>
@@ -481,27 +482,24 @@ crpSingle :: (PrimMonad m, Integral b) =>
-> Double
-> Prob m (CRPTables (Sum b))
crpSingle i zs a = do
- zn1 <- categorical probs
- pure $ crpInsert zn1 zs
+ zn1 <- categorical probs
+ pure $ crpInsert zn1 zs
where
probs = pms <> [pm1]
- mks = getSum <$> crpCustomers zs -- # of customers sitting at each table
- pms = map (\m -> fromIntegral m / (fromIntegral i - 1 + a)) mks
+ acc m = fromIntegral m / (fromIntegral i - 1 + a)
+ pms = F.toList $ fmap (acc . getSum) zs
pm1 = a / (fromIntegral i - 1 + a)
--- | Tables at the Chinese Restaurant
+-- Tables at the Chinese Restaurant
newtype CRPTables c = CRP {
- getCRPTables :: IM.IntMap c } deriving (Eq, Show, Functor, Foldable)
+ getCRPTables :: IM.IntMap c
+ } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)
--- | Initial state of the CRP : one customer sitting at table #0
+-- Initial state of the CRP : one customer sitting at table #0
crpInitial :: CRPTables (Sum Integer)
-crpInitial = crpInsert 0 $ CRP IM.empty
+crpInitial = crpInsert 0 mempty
--- | Seat one customer at table 'k'
+-- Seat one customer at table 'k'
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts
--- | Number of customers sitting at each table
-crpCustomers :: CRPTables c -> [c]
-crpCustomers = map snd . IM.toList . getCRPTables
-