mwc-probability

Sampling function-based probability distributions.
git clone git://git.jtobin.io/mwc-probability.git
Log | Files | Refs | README | LICENSE

Probability.hs (16447B)


      1 {-# LANGUAGE DeriveFoldable #-}
      2 {-# LANGUAGE DeriveFunctor #-}
      3 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
      4 {-# LANGUAGE TypeFamilies #-}
      5 {-# LANGUAGE CPP #-}
      6 {-# OPTIONS_GHC -Wall #-}
      7 
      8 -- |
      9 -- Module: System.Random.MWC.Probability
     10 -- Copyright: (c) 2015-2018 Jared Tobin, Marco Zocca
     11 -- License: MIT
     12 --
     13 -- Maintainer: Jared Tobin <jared@jtobin.ca>, Marco Zocca <zocca.marco gmail>
     14 -- Stability: unstable
     15 -- Portability: ghc
     16 --
     17 -- A probability monad based on sampling functions, implemented as a thin
     18 -- wrapper over the
     19 -- [mwc-random](https://hackage.haskell.org/package/mwc-random) library.
     20 --
     21 -- Probability distributions are abstract constructs that can be represented in
     22 -- a variety of ways.  The sampling function representation is particularly
     23 -- useful -- it's computationally efficient, and collections of samples are
     24 -- amenable to much practical work.
     25 --
     26 -- Probability monads propagate uncertainty under the hood.  An expression like
     27 -- @'beta' 1 8 >>= 'binomial' 10@ corresponds to a
     28 -- <https://en.wikipedia.org/wiki/Beta-binomial_distribution beta-binomial>
     29 -- distribution in which the uncertainty captured by @'beta' 1 8@ has been
     30 -- marginalized out.
     31 --
     32 -- The distribution resulting from a series of effects is called the
     33 -- /predictive distribution/ of the model described by the corresponding
     34 -- expression.  The monadic structure lets one piece together a hierarchical
     35 -- structure from simpler, local conditionals:
     36 --
     37 -- > hierarchicalModel = do
     38 -- >   [c, d, e, f] <- replicateM 4 $ uniformR (1, 10)
     39 -- >   a <- gamma c d
     40 -- >   b <- gamma e f
     41 -- >   p <- beta a b
     42 -- >   n <- uniformR (5, 10)
     43 -- >   binomial n p
     44 --
     45 -- The functor instance allows one to transform the support of a distribution
     46 -- while leaving its density structure invariant.  For example, @'uniform'@ is
     47 -- a distribution over the 0-1 interval, but @fmap (+ 1) uniform@ is the
     48 -- translated distribution over the 1-2 interval:
     49 --
     50 -- >>> create >>= sample (fmap (+ 1) uniform)
     51 -- 1.5480073474340754
     52 --
     53 -- The applicative instance guarantees that the generated samples are generated
     54 -- independently:
     55 --
     56 -- >>> create >>= sample ((,) <$> uniform <*> uniform)
     57 
     58 module System.Random.MWC.Probability (
     59     module MWC
     60   , Prob(..)
     61   , samples
     62 
     63   , uniform
     64   , uniformR
     65   , normal
     66   , standardNormal
     67   , isoNormal
     68   , logNormal
     69   , exponential
     70   , inverseGaussian
     71   , laplace
     72   , gamma
     73   , inverseGamma
     74   , normalGamma
     75   , weibull
     76   , chiSquare
     77   , beta
     78   , gstudent
     79   , student
     80   , pareto
     81   , dirichlet
     82   , symmetricDirichlet
     83   , discreteUniform
     84   , zipf
     85   , categorical
     86   , discrete
     87   , bernoulli
     88   , binomial
     89   , negativeBinomial
     90   , multinomial
     91   , poisson
     92   , crp
     93   ) where
     94 
     95 import Control.Applicative
     96 import Control.Monad
     97 import Control.Monad.Primitive
     98 import Control.Monad.IO.Class
     99 import Control.Monad.Trans.Class
    100 import Data.Monoid (Sum(..))
    101 #if __GLASGOW_HASKELL__ < 710
    102 import Data.Foldable (Foldable)
    103 #endif
    104 import qualified Data.Foldable as F
    105 import Data.List (findIndex)
    106 import qualified Data.IntMap as IM
    107 import System.Random.MWC as MWC hiding (uniform, uniformR)
    108 import qualified System.Random.MWC as QMWC
    109 import qualified System.Random.MWC.Distributions as MWC.Dist
    110 import System.Random.MWC.CondensedTable
    111 
    112 -- | A probability distribution characterized by a sampling function.
    113 --
    114 -- >>> gen <- createSystemRandom
    115 -- >>> sample uniform gen
    116 -- 0.4208881170464097
    117 newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }
    118 
    119 -- | Sample from a model 'n' times.
    120 --
    121 -- >>> samples 2 uniform gen
    122 -- [0.6738707766845254,0.9730405951541817]
    123 samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
    124 samples n model gen = sequenceA (replicate n (sample model gen))
    125 {-# INLINABLE samples #-}
    126 
    127 instance Functor m => Functor (Prob m) where
    128   fmap h (Prob f) = Prob (fmap h . f)
    129 
    130 instance Monad m => Applicative (Prob m) where
    131   pure  = Prob . const . pure
    132   (<*>) = ap
    133 
    134 instance Monad m => Monad (Prob m) where
    135   return = pure
    136   m >>= h = Prob $ \g -> do
    137     z <- sample m g
    138     sample (h z) g
    139   {-# INLINABLE (>>=) #-}
    140 
    141 instance (Monad m, Num a) => Num (Prob m a) where
    142   (+)         = liftA2 (+)
    143   (-)         = liftA2 (-)
    144   (*)         = liftA2 (*)
    145   abs         = fmap abs
    146   signum      = fmap signum
    147   fromInteger = pure . fromInteger
    148 
    149 instance MonadTrans Prob where
    150   lift m = Prob $ const m
    151 
    152 instance MonadIO m => MonadIO (Prob m) where
    153   liftIO m = Prob $ const (liftIO m)
    154 
    155 instance PrimMonad m => PrimMonad (Prob m) where
    156   type PrimState (Prob m) = PrimState m
    157   primitive = lift . primitive
    158   {-# INLINE primitive #-}
    159 
    160 -- | The uniform distribution at a specified type.
    161 --
    162 --   Note that `Double` and `Float` variates are defined over the unit
    163 --   interval.
    164 --
    165 --   >>> sample uniform gen :: IO Double
    166 --   0.29308497534914946
    167 --   >>> sample uniform gen :: IO Bool
    168 --   False
    169 uniform :: (PrimMonad m, Variate a) => Prob m a
    170 uniform = Prob QMWC.uniform
    171 {-# INLINABLE uniform #-}
    172 
    173 -- | The uniform distribution over the provided interval.
    174 --
    175 --   >>> sample (uniformR (0, 1)) gen
    176 --   0.44984153252922365
    177 uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
    178 uniformR r = Prob $ QMWC.uniformR r
    179 {-# INLINABLE uniformR #-}
    180 
    181 -- | The discrete uniform distribution.
    182 --
    183 --   >>> sample (discreteUniform [0..10]) gen
    184 --   6
    185 --   >>> sample (discreteUniform "abcdefghijklmnopqrstuvwxyz") gen
    186 --   'a'
    187 discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
    188 discreteUniform cs = do
    189   j <- uniformR (0, length cs - 1)
    190   return $ F.toList cs !! j
    191 {-# INLINABLE discreteUniform #-}
    192 
    193 -- | The standard normal or Gaussian distribution with mean 0 and standard
    194 --   deviation 1.
    195 standardNormal :: PrimMonad m => Prob m Double
    196 standardNormal = Prob MWC.Dist.standard
    197 {-# INLINABLE standardNormal #-}
    198 
    199 -- | The normal or Gaussian distribution with specified mean and standard
    200 --   deviation.
    201 --
    202 --   Note that `sd` should be positive.
    203 normal :: PrimMonad m => Double -> Double -> Prob m Double
    204 normal m sd = Prob $ MWC.Dist.normal m sd
    205 {-# INLINABLE normal #-}
    206 
    207 -- | The log-normal distribution with specified mean and standard deviation.
    208 --
    209 --   Note that `sd` should be positive.
    210 logNormal :: PrimMonad m => Double -> Double -> Prob m Double
    211 logNormal m sd = exp <$> normal m sd
    212 {-# INLINABLE logNormal #-}
    213 
    214 -- | The exponential distribution with provided rate parameter.
    215 --
    216 --   Note that `r` should be positive.
    217 exponential :: PrimMonad m => Double -> Prob m Double
    218 exponential r = Prob $ MWC.Dist.exponential r
    219 {-# INLINABLE exponential #-}
    220 
    221 -- | The Laplace or double-exponential distribution with provided location and
    222 --   scale parameters.
    223 --
    224 --   Note that `sigma` should be positive.
    225 laplace :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
    226 laplace mu sigma = do
    227   u <- uniformR (-0.5, 0.5)
    228   let b = sigma / sqrt 2
    229   return $ mu - b * signum u * log (1 - 2 * abs u)
    230 {-# INLINABLE laplace #-}
    231 
    232 -- | The Weibull distribution with provided shape and scale parameters.
    233 --
    234 --   Note that both parameters should be positive.
    235 weibull :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
    236 weibull a b = do
    237   x <- uniform
    238   return $ (- 1/a * log (1 - x)) ** 1/b
    239 {-# INLINABLE weibull #-}
    240 
    241 -- | The gamma distribution with shape parameter `a` and scale parameter `b`.
    242 --
    243 --   This is the parameterization used more traditionally in frequentist
    244 --   statistics.  It has the following corresponding probability density
    245 --   function:
    246 --
    247 -- > f(x; a, b) = 1 / (Gamma(a) * b ^ a) x ^ (a - 1) e ^ (- x / b)
    248 --
    249 --   Note that both parameters should be positive.
    250 gamma :: PrimMonad m => Double -> Double -> Prob m Double
    251 gamma a b = Prob $ MWC.Dist.gamma a b
    252 {-# INLINABLE gamma #-}
    253 
    254 -- | The inverse-gamma distribution with shape parameter `a` and scale
    255 --   parameter `b`.
    256 --
    257 --   Note that both parameters should be positive.
    258 inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
    259 inverseGamma a b = recip <$> gamma a b
    260 {-# INLINABLE inverseGamma #-}
    261 
    262 -- | The Normal-Gamma distribution.
    263 --
    264 --   Note that the `lambda`, `a`, and `b` parameters should be positive.
    265 normalGamma :: PrimMonad m => Double -> Double -> Double -> Double -> Prob m Double
    266 normalGamma mu lambda a b = do
    267   tau <- gamma a b
    268   let xsd = sqrt (recip (lambda * tau))
    269   normal mu xsd
    270 {-# INLINABLE normalGamma #-}
    271 
    272 -- | The chi-square distribution with the specified degrees of freedom.
    273 --
    274 --   Note that `k` should be positive.
    275 chiSquare :: PrimMonad m => Int -> Prob m Double
    276 chiSquare k = Prob $ MWC.Dist.chiSquare k
    277 {-# INLINABLE chiSquare #-}
    278 
    279 -- | The beta distribution with the specified shape parameters.
    280 --
    281 --   Note that both parameters should be positive.
    282 beta :: PrimMonad m => Double -> Double -> Prob m Double
    283 beta a b = do
    284   u <- gamma a 1
    285   w <- gamma b 1
    286   return $ u / (u + w)
    287 {-# INLINABLE beta #-}
    288 
    289 -- | The Pareto distribution with specified index `a` and minimum `xmin`
    290 --   parameters.
    291 --
    292 --   Note that both parameters should be positive.
    293 pareto :: PrimMonad m => Double -> Double -> Prob m Double
    294 pareto a xmin = do
    295   y <- exponential a
    296   return $ xmin * exp y
    297 {-# INLINABLE pareto #-}
    298 
    299 -- | The Dirichlet distribution with the provided concentration parameters.
    300 --   The dimension of the distribution is determined by the number of
    301 --   concentration parameters supplied.
    302 --
    303 --   >>> sample (dirichlet [0.1, 1, 10]) gen
    304 --   [1.2375387187120799e-5,3.4952460651813816e-3,0.9964923785476316]
    305 --
    306 --   Note that all concentration parameters should be positive.
    307 dirichlet
    308   :: (Traversable f, PrimMonad m) => f Double -> Prob m (f Double)
    309 dirichlet as = do
    310   zs <- traverse (`gamma` 1) as
    311   return $ fmap (/ sum zs) zs
    312 {-# INLINABLE dirichlet #-}
    313 
    314 -- | The symmetric Dirichlet distribution with dimension `n`.  The provided
    315 --   concentration parameter is simply replicated `n` times.
    316 --
    317 --   Note that `a` should be positive.
    318 symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
    319 symmetricDirichlet n a = dirichlet (replicate n a)
    320 {-# INLINABLE symmetricDirichlet #-}
    321 
    322 -- | The Bernoulli distribution with success probability `p`.
    323 bernoulli :: PrimMonad m => Double -> Prob m Bool
    324 bernoulli p = (< p) <$> uniform
    325 {-# INLINABLE bernoulli #-}
    326 
    327 -- | The binomial distribution with number of trials `n` and success
    328 --   probability `p`.
    329 --
    330 --   >>> sample (binomial 10 0.3) gen
    331 --   4
    332 binomial :: PrimMonad m => Int -> Double -> Prob m Int
    333 binomial n p = fmap (length . filter id) $ replicateM n (bernoulli p)
    334 {-# INLINABLE binomial #-}
    335 
    336 -- | The negative binomial distribution with number of trials `n` and success
    337 --   probability `p`.
    338 --
    339 --   >>> sample (negativeBinomial 10 0.3) gen
    340 --   21
    341 negativeBinomial :: (PrimMonad m, Integral a) => a -> Double -> Prob m Int
    342 negativeBinomial n p = do
    343   y <- gamma (fromIntegral n) ((1 - p) / p)
    344   poisson y
    345 {-# INLINABLE negativeBinomial #-}
    346 
    347 -- | The multinomial distribution of `n` trials and category probabilities
    348 --   `ps`.
    349 --
    350 --   Note that the supplied probability container should consist of non-negative
    351 --   values but is not required to sum to one.
    352 multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
    353 multinomial n ps = do
    354     let (cumulative, total) = runningTotals (F.toList ps)
    355     replicateM n $ do
    356       z <- uniformR (0, total)
    357       case findIndex (> z) cumulative of
    358         Just g  -> return g
    359         Nothing -> error "mwc-probability: invalid probability vector"
    360   where
    361     -- Note: this is significantly faster than any
    362     -- of the recursions one might write by hand.
    363     runningTotals :: Num a => [a] -> ([a], a)
    364     runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
    365 {-# INLINABLE multinomial #-}
    366 
    367 -- | Generalized Student's t distribution with location parameter `m`, scale
    368 --   parameter `s`, and degrees of freedom `k`.
    369 --
    370 --   Note that the `s` and `k` parameters should be positive.
    371 gstudent :: PrimMonad m => Double -> Double -> Double -> Prob m Double
    372 gstudent m s k = do
    373   sd <- fmap sqrt (inverseGamma (k / 2) (s * 2 / k))
    374   normal m sd
    375 {-# INLINABLE gstudent #-}
    376 
    377 -- | Student's t distribution with `k` degrees of freedom.
    378 --
    379 --   Note that `k` should be positive.
    380 student :: PrimMonad m => Double -> Prob m Double
    381 student = gstudent 0 1
    382 {-# INLINABLE student #-}
    383 
    384 -- | An isotropic or spherical Gaussian distribution with specified mean
    385 --   vector and scalar standard deviation parameter.
    386 --
    387 --   Note that `sd` should be positive.
    388 isoNormal
    389   :: (Traversable f, PrimMonad m) => f Double -> Double -> Prob m (f Double)
    390 isoNormal ms sd = traverse (`normal` sd) ms
    391 {-# INLINABLE isoNormal #-}
    392 
    393 -- | The inverse Gaussian (also known as Wald) distribution with mean parameter
    394 --   `mu` and shape parameter `lambda`.
    395 --
    396 --   Note that both 'mu' and 'lambda' should be positive.
    397 inverseGaussian :: PrimMonad m => Double -> Double -> Prob m Double
    398 inverseGaussian lambda mu = do
    399   nu <- standardNormal
    400   let y = nu ** 2
    401       s =  sqrt (4 * mu * lambda * y + mu ** 2  * y ** 2)
    402       x = mu * (1 + 1 / (2 * lambda) * (mu * y - s))
    403       thresh = mu / (mu + x)
    404   z <- uniform
    405   if z <= thresh
    406     then return x
    407     else return (mu ** 2 / x)
    408 {-# INLINABLE inverseGaussian #-}
    409 
    410 -- | The Poisson distribution with rate parameter `l`.
    411 --
    412 --   Note that `l` should be positive.
    413 poisson :: PrimMonad m => Double -> Prob m Int
    414 poisson l = Prob $ genFromTable table where
    415   table = tablePoisson l
    416 {-# INLINABLE poisson #-}
    417 
    418 -- | A categorical distribution defined by the supplied probabilities.
    419 --
    420 --   Note that the supplied probability container should consist of non-negative
    421 --   values but is not required to sum to one.
    422 categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
    423 categorical ps = do
    424   xs <- multinomial 1 ps
    425   case xs of
    426     [x] -> return x
    427     _   -> error "mwc-probability: invalid probability vector"
    428 {-# INLINABLE categorical #-}
    429 
    430 -- | A categorical distribution defined by the supplied support.
    431 --
    432 --   Note that the supplied probabilities should be non-negative, but are not
    433 --   required to sum to one.
    434 --
    435 --   >>> samples 10 (discrete [(0.1, "yeah"), (0.9, "nah")]) gen
    436 --   ["yeah","nah","nah","nah","nah","yeah","nah","nah","nah","nah"]
    437 discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
    438 discrete d = do
    439   let (ps, xs) = unzip (F.toList d)
    440   idx <- categorical ps
    441   pure (xs !! idx)
    442 {-# INLINABLE discrete #-}
    443 
    444 -- | The Zipf-Mandelbrot distribution.
    445 --
    446 --  Note that `a` should be positive, and that values close to 1 should be
    447 --  avoided as they are very computationally intensive.
    448 --
    449 --  >>> samples 10 (zipf 1.1) gen
    450 --  [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
    451 --
    452 --  >>> samples 10 (zipf 1.5) gen
    453 --  [19,3,3,1,1,2,1,191,2,1]
    454 zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
    455 zipf a = do
    456   let
    457     b = 2 ** (a - 1)
    458     go = do
    459         u <- uniform
    460         v <- uniform
    461         let xInt = floor (u ** (- 1 / (a - 1)))
    462             x = fromIntegral xInt
    463             t = (1 + 1 / x) ** (a - 1)
    464         if v * x * (t - 1) / (b - 1) <= t / b
    465           then return xInt
    466           else go
    467   go
    468 {-# INLINABLE zipf #-}
    469 
    470 -- | The Chinese Restaurant Process with concentration parameter `a` and number
    471 --   of customers `n`.
    472 --
    473 --   See Griffiths and Ghahramani, 2011 for details.
    474 --
    475 --   >>> sample (crp 1.8 50) gen
    476 --   [22,10,7,1,2,2,4,1,1]
    477 crp
    478   :: PrimMonad m
    479   => Double            -- ^ concentration parameter (> 1)
    480   -> Int               -- ^ number of customers
    481   -> Prob m [Integer]
    482 crp a n = do
    483     ts <- go crpInitial 1
    484     pure $ F.toList (fmap getSum ts)
    485   where
    486     go acc i
    487       | i == n = pure acc
    488       | otherwise = do
    489           acc' <- crpSingle i acc a
    490           go acc' (i + 1)
    491 {-# INLINABLE crp #-}
    492 
    493 -- | Update step of the CRP
    494 crpSingle :: (PrimMonad m, Integral b) =>
    495              Int
    496           -> CRPTables (Sum b)
    497           -> Double
    498           -> Prob m (CRPTables (Sum b))
    499 crpSingle i zs a = do
    500     zn1 <- categorical probs
    501     pure $ crpInsert zn1 zs
    502   where
    503     probs = pms <> [pm1]
    504     acc m = fromIntegral m / (fromIntegral i - 1 + a)
    505     pms = F.toList $ fmap (acc . getSum) zs
    506     pm1 = a / (fromIntegral i - 1 + a)
    507 
    508 -- Tables at the Chinese Restaurant
    509 newtype CRPTables c = CRP {
    510     getCRPTables :: IM.IntMap c
    511   } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)
    512 
    513 -- Initial state of the CRP : one customer sitting at table #0
    514 crpInitial :: CRPTables (Sum Integer)
    515 crpInitial = crpInsert 0 mempty
    516 
    517 -- Seat one customer at table 'k'
    518 crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
    519 crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts
    520