commit a3012909061aaf30553e5c5d0451c80d2f4b1755
parent bbe26ba96a8917c526438b9031ac951cf547a3d1
Author: Marco Zocca <ocramz>
Date: Mon, 20 Jan 2020 16:24:29 +0100
add Pitman-Yor sampler
Diffstat:
2 files changed, 60 insertions(+), 0 deletions(-)
diff --git a/mwc-probability.cabal b/mwc-probability.cabal
@@ -55,6 +55,7 @@ library
hs-source-dirs: src
build-depends:
base >= 4.8 && < 6
+ , containers
, mwc-random > 0.13 && < 0.15
, primitive >= 0.6 && < 1.0
, transformers >= 0.5 && < 1.0
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -1,3 +1,5 @@
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
@@ -85,6 +87,7 @@ module System.Random.MWC.Probability (
, negativeBinomial
, multinomial
, poisson
+ , pitmanYor
) where
import Control.Applicative
@@ -92,11 +95,13 @@ import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
+import Data.Monoid (Sum(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
+import qualified Data.IntMap as IM
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
@@ -445,3 +450,57 @@ zipf a = do
else go
go
{-# INLINABLE zipf #-}
+
+
+-- * Chinese Restaurant process
+
+-- | Pitman-Yor process
+pitmanYor :: (PrimMonad f) =>
+ Double -- ^ a \in [0, 1]
+ -> Double -- ^ b > 0
+ -> Int -- ^ number of samples
+ -> Gen (PrimState f)
+ -> f [Integer]
+pitmanYor a b n gen = do
+ ts <- go crpInitial (n - 1)
+ pure $ map getSum $ customers ts
+ where
+ go acc 0 = pure acc
+ go acc i = do
+ acc' <- sample (pitmanYorSingle a b acc) gen
+ go acc' (i - 1)
+
+pitmanYorSingle :: (PrimMonad m, Integral a) =>
+ Double -- a \in [0, 1]
+ -> Double -- b > 0
+ -> CRPTables (Sum a)
+ -> Prob m (CRPTables (Sum a))
+pitmanYorSingle a b zs = do
+ zn1 <- categorical probs
+ pure $ crpInsert zn1 zs
+ where
+ m = fromIntegral $ uniques zs
+ n = fromIntegral $ numCustomers zs
+ d = n + b
+ probs = pms <> [pm1]
+ pm1 = (m * a + b) / d
+ pms = map (\x -> (fromIntegral (getSum x) - a) / d) $ customers zs
+
+-- | Tables at the Chinese Restaurant
+newtype CRPTables c = CRP {
+ getCRPTables :: IM.IntMap c } deriving (Eq, Show, Functor, Foldable)
+
+crpInitial :: CRPTables (Sum Integer)
+crpInitial = crpInsert 0 $ CRP IM.empty
+
+crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
+crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts
+
+uniques :: CRPTables a -> Int
+uniques (CRP ts) = length ts
+
+customers :: CRPTables c -> [c]
+customers = map snd . IM.toList . getCRPTables
+
+numCustomers :: (Foldable t, Functor t, Num a) => t (Sum a) -> a
+numCustomers = sum . fmap getSum