mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit a3012909061aaf30553e5c5d0451c80d2f4b1755
parent bbe26ba96a8917c526438b9031ac951cf547a3d1
Author: Marco Zocca <ocramz>
Date:   Mon, 20 Jan 2020 16:24:29 +0100

add Pitman-Yor sampler

Diffstat:
Mmwc-probability.cabal | 1+
Msrc/System/Random/MWC/Probability.hs | 59+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 60 insertions(+), 0 deletions(-)

diff --git a/mwc-probability.cabal b/mwc-probability.cabal @@ -55,6 +55,7 @@ library hs-source-dirs: src build-depends: base >= 4.8 && < 6 + , containers , mwc-random > 0.13 && < 0.15 , primitive >= 0.6 && < 1.0 , transformers >= 0.5 && < 1.0 diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE CPP #-} {-# OPTIONS_GHC -Wall #-} @@ -85,6 +87,7 @@ module System.Random.MWC.Probability ( , negativeBinomial , multinomial , poisson + , pitmanYor ) where import Control.Applicative @@ -92,11 +95,13 @@ import Control.Monad import Control.Monad.Primitive import Control.Monad.IO.Class import Control.Monad.Trans.Class +import Data.Monoid (Sum(..)) #if __GLASGOW_HASKELL__ < 710 import Data.Foldable (Foldable) #endif import qualified Data.Foldable as F import Data.List (findIndex) +import qualified Data.IntMap as IM import System.Random.MWC as MWC hiding (uniform, uniformR) import qualified System.Random.MWC as QMWC import qualified System.Random.MWC.Distributions as MWC.Dist @@ -445,3 +450,57 @@ zipf a = do else go go {-# INLINABLE zipf #-} + + +-- * Chinese Restaurant process + +-- | Pitman-Yor process +pitmanYor :: (PrimMonad f) => + Double -- ^ a \in [0, 1] + -> Double -- ^ b > 0 + -> Int -- ^ number of samples + -> Gen (PrimState f) + -> f [Integer] +pitmanYor a b n gen = do + ts <- go crpInitial (n - 1) + pure $ map getSum $ customers ts + where + go acc 0 = pure acc + go acc i = do + acc' <- sample (pitmanYorSingle a b acc) gen + go acc' (i - 1) + +pitmanYorSingle :: (PrimMonad m, Integral a) => + Double -- a \in [0, 1] + -> Double -- b > 0 + -> CRPTables (Sum a) + -> Prob m (CRPTables (Sum a)) +pitmanYorSingle a b zs = do + zn1 <- categorical probs + pure $ crpInsert zn1 zs + where + m = fromIntegral $ uniques zs + n = fromIntegral $ numCustomers zs + d = n + b + probs = pms <> [pm1] + pm1 = (m * a + b) / d + pms = map (\x -> (fromIntegral (getSum x) - a) / d) $ customers zs + +-- | Tables at the Chinese Restaurant +newtype CRPTables c = CRP { + getCRPTables :: IM.IntMap c } deriving (Eq, Show, Functor, Foldable) + +crpInitial :: CRPTables (Sum Integer) +crpInitial = crpInsert 0 $ CRP IM.empty + +crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a) +crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts + +uniques :: CRPTables a -> Int +uniques (CRP ts) = length ts + +customers :: CRPTables c -> [c] +customers = map snd . IM.toList . getCRPTables + +numCustomers :: (Foldable t, Functor t, Num a) => t (Sum a) -> a +numCustomers = sum . fmap getSum