mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit 23609d9b1d808d9009c5051d41a88ec002b7d3d0
parent 48906cfeb548ea85022f456237b01c356fdc93fb
Author: Jared Tobin <jared@jtobin.io>
Date:   Wed, 29 Jan 2020 19:25:21 +0400

Tweak CRP implementation.

Fix style nits, s/chineseRestaurantProcess/crp, use more Foldable stuff.

Diffstat:
Msrc/System/Random/MWC/Probability.hs | 58++++++++++++++++++++++++++++------------------------------
1 file changed, 28 insertions(+), 30 deletions(-)

diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE CPP #-} {-# OPTIONS_GHC -Wall #-} @@ -87,7 +88,7 @@ module System.Random.MWC.Probability ( , negativeBinomial , multinomial , poisson - , chineseRestaurantProcess + , crp ) where import Control.Applicative @@ -451,28 +452,28 @@ zipf a = do go {-# INLINABLE zipf #-} - - - --- | Chinese Restaurant Process --- --- Implementation based on Griffiths, Ghahramani 2011 --- --- >>> sample (chineseRestaurantProcess 1.8 50) gen --- [22,10,7,1,2,2,4,1,1] -chineseRestaurantProcess :: PrimMonad m => - Double -- ^ Concentration parameter (> 1) - -> Int -- ^ Total number of customers - -> Prob m [Integer] -chineseRestaurantProcess a n = do - ts <- go crpInitial 1 - pure $ map getSum $ crpCustomers ts +-- | The Chinese Restaurant Process with concentration parameter `a` and number +-- of customers `n`. +-- +-- See Griffiths and Ghahramani, 2011 for details. +-- +-- >>> sample (crp 1.8 50) gen +-- [22,10,7,1,2,2,4,1,1] +crp + :: PrimMonad m + => Double -- ^ concentration parameter (> 1) + -> Int -- ^ number of customers + -> Prob m [Integer] +crp a n = do + ts <- go crpInitial 1 + pure $ F.toList (fmap getSum ts) where go acc i | i == n = pure acc | otherwise = do acc' <- crpSingle i acc a go acc' (i + 1) +{-# INLINABLE crp #-} -- | Update step of the CRP crpSingle :: (PrimMonad m, Integral b) => @@ -481,27 +482,24 @@ crpSingle :: (PrimMonad m, Integral b) => -> Double -> Prob m (CRPTables (Sum b)) crpSingle i zs a = do - zn1 <- categorical probs - pure $ crpInsert zn1 zs + zn1 <- categorical probs + pure $ crpInsert zn1 zs where probs = pms <> [pm1] - mks = getSum <$> crpCustomers zs -- # of customers sitting at each table - pms = map (\m -> fromIntegral m / (fromIntegral i - 1 + a)) mks + acc m = fromIntegral m / (fromIntegral i - 1 + a) + pms = F.toList $ fmap (acc . getSum) zs pm1 = a / (fromIntegral i - 1 + a) --- | Tables at the Chinese Restaurant +-- Tables at the Chinese Restaurant newtype CRPTables c = CRP { - getCRPTables :: IM.IntMap c } deriving (Eq, Show, Functor, Foldable) + getCRPTables :: IM.IntMap c + } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid) --- | Initial state of the CRP : one customer sitting at table #0 +-- Initial state of the CRP : one customer sitting at table #0 crpInitial :: CRPTables (Sum Integer) -crpInitial = crpInsert 0 $ CRP IM.empty +crpInitial = crpInsert 0 mempty --- | Seat one customer at table 'k' +-- Seat one customer at table 'k' crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a) crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts --- | Number of customers sitting at each table -crpCustomers :: CRPTables c -> [c] -crpCustomers = map snd . IM.toList . getCRPTables -