mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit 90a2c401d9d8d339adc0bbc6d439291a05378698
parent bbe26ba96a8917c526438b9031ac951cf547a3d1
Author: Jared Tobin <jared@jtobin.ca>
Date:   Wed, 29 Jan 2020 19:30:51 +0400

Merge pull request #17 from jtobin/crp

Add implementation of Chinese Restaurant process
Diffstat:
Mmwc-probability.cabal | 1+
Msrc/System/Random/MWC/Probability.hs | 70++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
2 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/mwc-probability.cabal b/mwc-probability.cabal @@ -55,6 +55,7 @@ library hs-source-dirs: src build-depends: base >= 4.8 && < 6 + , containers >= 0.6 , mwc-random > 0.13 && < 0.15 , primitive >= 0.6 && < 1.0 , transformers >= 0.5 && < 1.0 diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -1,3 +1,6 @@ +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE CPP #-} {-# OPTIONS_GHC -Wall #-} @@ -85,6 +88,7 @@ module System.Random.MWC.Probability ( , negativeBinomial , multinomial , poisson + , crp ) where import Control.Applicative @@ -92,11 +96,13 @@ import Control.Monad import Control.Monad.Primitive import Control.Monad.IO.Class import Control.Monad.Trans.Class +import Data.Monoid (Sum(..)) #if __GLASGOW_HASKELL__ < 710 import Data.Foldable (Foldable) #endif import qualified Data.Foldable as F import Data.List (findIndex) +import qualified Data.IntMap as IM import System.Random.MWC as MWC hiding (uniform, uniformR) import qualified System.Random.MWC as QMWC import qualified System.Random.MWC.Distributions as MWC.Dist @@ -422,14 +428,14 @@ categorical ps = do -- | The Zipf-Mandelbrot distribution. -- --- Note that `a` should be positive, and that values close to 1 should be --- avoided as they are very computationally intensive. +-- Note that `a` should be positive, and that values close to 1 should be +-- avoided as they are very computationally intensive. -- --- >>> samples 10 (zipf 1.1) gen --- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50] +-- >>> samples 10 (zipf 1.1) gen +-- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50] -- --- >>> samples 10 (zipf 1.5) gen --- [19,3,3,1,1,2,1,191,2,1] +-- >>> samples 10 (zipf 1.5) gen +-- [19,3,3,1,1,2,1,191,2,1] zipf :: (PrimMonad m, Integral b) => Double -> Prob m b zipf a = do let @@ -445,3 +451,55 @@ zipf a = do else go go {-# INLINABLE zipf #-} + +-- | The Chinese Restaurant Process with concentration parameter `a` and number +-- of customers `n`. +-- +-- See Griffiths and Ghahramani, 2011 for details. +-- +-- >>> sample (crp 1.8 50) gen +-- [22,10,7,1,2,2,4,1,1] +crp + :: PrimMonad m + => Double -- ^ concentration parameter (> 1) + -> Int -- ^ number of customers + -> Prob m [Integer] +crp a n = do + ts <- go crpInitial 1 + pure $ F.toList (fmap getSum ts) + where + go acc i + | i == n = pure acc + | otherwise = do + acc' <- crpSingle i acc a + go acc' (i + 1) +{-# INLINABLE crp #-} + +-- | Update step of the CRP +crpSingle :: (PrimMonad m, Integral b) => + Int + -> CRPTables (Sum b) + -> Double + -> Prob m (CRPTables (Sum b)) +crpSingle i zs a = do + zn1 <- categorical probs + pure $ crpInsert zn1 zs + where + probs = pms <> [pm1] + acc m = fromIntegral m / (fromIntegral i - 1 + a) + pms = F.toList $ fmap (acc . getSum) zs + pm1 = a / (fromIntegral i - 1 + a) + +-- Tables at the Chinese Restaurant +newtype CRPTables c = CRP { + getCRPTables :: IM.IntMap c + } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid) + +-- Initial state of the CRP : one customer sitting at table #0 +crpInitial :: CRPTables (Sum Integer) +crpInitial = crpInsert 0 mempty + +-- Seat one customer at table 'k' +crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a) +crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts +