commit 90a2c401d9d8d339adc0bbc6d439291a05378698
parent bbe26ba96a8917c526438b9031ac951cf547a3d1
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 29 Jan 2020 19:30:51 +0400
Merge pull request #17 from jtobin/crp
Add implementation of Chinese Restaurant process
Diffstat:
2 files changed, 65 insertions(+), 6 deletions(-)
diff --git a/mwc-probability.cabal b/mwc-probability.cabal
@@ -55,6 +55,7 @@ library
hs-source-dirs: src
build-depends:
base >= 4.8 && < 6
+ , containers >= 0.6
, mwc-random > 0.13 && < 0.15
, primitive >= 0.6 && < 1.0
, transformers >= 0.5 && < 1.0
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -1,3 +1,6 @@
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
@@ -85,6 +88,7 @@ module System.Random.MWC.Probability (
, negativeBinomial
, multinomial
, poisson
+ , crp
) where
import Control.Applicative
@@ -92,11 +96,13 @@ import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
+import Data.Monoid (Sum(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
+import qualified Data.IntMap as IM
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
@@ -422,14 +428,14 @@ categorical ps = do
-- | The Zipf-Mandelbrot distribution.
--
--- Note that `a` should be positive, and that values close to 1 should be
--- avoided as they are very computationally intensive.
+-- Note that `a` should be positive, and that values close to 1 should be
+-- avoided as they are very computationally intensive.
--
--- >>> samples 10 (zipf 1.1) gen
--- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
+-- >>> samples 10 (zipf 1.1) gen
+-- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
--
--- >>> samples 10 (zipf 1.5) gen
--- [19,3,3,1,1,2,1,191,2,1]
+-- >>> samples 10 (zipf 1.5) gen
+-- [19,3,3,1,1,2,1,191,2,1]
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a = do
let
@@ -445,3 +451,55 @@ zipf a = do
else go
go
{-# INLINABLE zipf #-}
+
+-- | The Chinese Restaurant Process with concentration parameter `a` and number
+-- of customers `n`.
+--
+-- See Griffiths and Ghahramani, 2011 for details.
+--
+-- >>> sample (crp 1.8 50) gen
+-- [22,10,7,1,2,2,4,1,1]
+crp
+ :: PrimMonad m
+ => Double -- ^ concentration parameter (> 1)
+ -> Int -- ^ number of customers
+ -> Prob m [Integer]
+crp a n = do
+ ts <- go crpInitial 1
+ pure $ F.toList (fmap getSum ts)
+ where
+ go acc i
+ | i == n = pure acc
+ | otherwise = do
+ acc' <- crpSingle i acc a
+ go acc' (i + 1)
+{-# INLINABLE crp #-}
+
+-- | Update step of the CRP
+crpSingle :: (PrimMonad m, Integral b) =>
+ Int
+ -> CRPTables (Sum b)
+ -> Double
+ -> Prob m (CRPTables (Sum b))
+crpSingle i zs a = do
+ zn1 <- categorical probs
+ pure $ crpInsert zn1 zs
+ where
+ probs = pms <> [pm1]
+ acc m = fromIntegral m / (fromIntegral i - 1 + a)
+ pms = F.toList $ fmap (acc . getSum) zs
+ pm1 = a / (fromIntegral i - 1 + a)
+
+-- Tables at the Chinese Restaurant
+newtype CRPTables c = CRP {
+ getCRPTables :: IM.IntMap c
+ } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)
+
+-- Initial state of the CRP : one customer sitting at table #0
+crpInitial :: CRPTables (Sum Integer)
+crpInitial = crpInsert 0 mempty
+
+-- Seat one customer at table 'k'
+crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
+crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts
+