Probability.hs (16447B)
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}

-- |
-- Module: System.Random.MWC.Probability
-- Copyright: (c) 2015-2018 Jared Tobin, Marco Zocca
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>, Marco Zocca <zocca.marco gmail>
-- Stability: unstable
-- Portability: ghc
--
-- A probability monad based on sampling functions, implemented as a thin
-- wrapper over the
-- [mwc-random](https://hackage.haskell.org/package/mwc-random) library.
--
-- Probability distributions are abstract constructs that can be represented in
-- a variety of ways. The sampling function representation is particularly
-- useful -- it's computationally efficient, and collections of samples are
-- amenable to much practical work.
--
-- Probability monads propagate uncertainty under the hood. An expression like
-- @'beta' 1 8 >>= 'binomial' 10@ corresponds to a
-- <https://en.wikipedia.org/wiki/Beta-binomial_distribution beta-binomial>
-- distribution in which the uncertainty captured by @'beta' 1 8@ has been
-- marginalized out.
--
-- The distribution resulting from a series of effects is called the
-- /predictive distribution/ of the model described by the corresponding
-- expression. The monadic structure lets one piece together a hierarchical
-- structure from simpler, local conditionals:
--
-- > hierarchicalModel = do
-- >   [c, d, e, f] <- replicateM 4 $ uniformR (1, 10)
-- >   a <- gamma c d
-- >   b <- gamma e f
-- >   p <- beta a b
-- >   n <- uniformR (5, 10)
-- >   binomial n p
--
-- The functor instance allows one to transform the support of a distribution
-- while leaving its density structure invariant. For example, @'uniform'@ is
-- a distribution over the unit interval, but @fmap (+ 1) uniform@ is the
-- translated distribution over the 1-2 interval:
--
-- >>> create >>= sample (fmap (+ 1) uniform)
-- 1.5480073474340754
--
-- The applicative instance guarantees that the combined samples are drawn
-- independently of one another:
--
-- >>> create >>= sample ((,) <$> uniform <*> uniform)

module System.Random.MWC.Probability (
    module MWC
  , Prob(..)
  , samples

    -- * Distributions
  , uniform
  , uniformR
  , normal
  , standardNormal
  , isoNormal
  , logNormal
  , exponential
  , inverseGaussian
  , laplace
  , gamma
  , inverseGamma
  , normalGamma
  , weibull
  , chiSquare
  , beta
  , gstudent
  , student
  , pareto
  , dirichlet
  , symmetricDirichlet
  , discreteUniform
  , zipf
  , categorical
  , discrete
  , bernoulli
  , binomial
  , negativeBinomial
  , multinomial
  , poisson
  , crp
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Monoid (Sum(..))
-- Only needed on GHC < 7.10, whose Prelude does not export 'Foldable'.
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
import qualified Data.IntMap as IM
-- mwc-random's own uniform/uniformR are hidden because this module defines
-- 'Prob'-based versions with the same names; the originals remain reachable
-- through the qualified QMWC import below.
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable

-- | A probability distribution characterized by a sampling function.
--
-- A @'Prob' m a@ wraps a function that, given an mwc-random generator, runs
-- an action in @m@ producing one sample of type @a@; 'sample' unwraps it.
--
-- >>> gen <- createSystemRandom
-- >>> sample uniform gen
-- 0.4208881170464097
newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }

-- | Sample from a model @n@ times.
--
-- The same generator is threaded through every draw, so the @n@ samples are
-- taken one after another and returned in draw order.
--
-- >>> samples 2 uniform gen
-- [0.6738707766845254,0.9730405951541817]
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n (sample model gen)
{-# INLINABLE samples #-}

-- | Mapping over a 'Prob' transforms every sample it produces.
instance Functor m => Functor (Prob m) where
  fmap f m = Prob $ \gen -> f <$> sample m gen

-- | 'pure' ignores the generator; '<*>' runs the function sampler first and
-- the argument sampler second, on the same generator.
instance Monad m => Applicative (Prob m) where
  pure x = Prob $ \_ -> pure x
  mf <*> mx = ap mf mx

-- | Binding feeds a sample into the continuation, which is then run on the
-- same generator.
instance Monad m => Monad (Prob m) where
  return = pure
  m >>= k = Prob $ \gen ->
    sample m gen >>= \z -> sample (k z) gen
  {-# INLINABLE (>>=) #-}

-- | Arithmetic on samplers acts on their samples; binary operators sample the
-- left operand before the right one.
instance (Monad m, Num a) => Num (Prob m a) where
  x + y = (+) <$> x <*> y
  x - y = (-) <$> x <*> y
  x * y = (*) <$> x <*> y
  abs = fmap abs
  signum = fmap signum
  fromInteger n = pure (fromInteger n)

-- | Lifting ignores the generator and simply runs the underlying action.
instance MonadTrans Prob where
  lift action = Prob (\_ -> action)

instance MonadIO m => MonadIO (Prob m) where
  liftIO io = Prob (\_ -> liftIO io)

-- | 'Prob' shares its state token with the underlying monad.
instance PrimMonad m => PrimMonad (Prob m) where
  type PrimState (Prob m) = PrimState m
  primitive f = lift (primitive f)
  {-# INLINE primitive #-}

-- | The uniform distribution at a specified type.
--
-- This delegates to mwc-random's @uniform@: fixed-width integral types range
-- over the whole type, while 'Double' and 'Float' variates are drawn from the
-- half-open unit interval (0, 1], so a floating-point sample is never zero.
--
-- >>> sample uniform gen :: IO Double
-- 0.29308497534914946
-- >>> sample uniform gen :: IO Bool
-- False
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob QMWC.uniform
{-# INLINABLE uniform #-}

-- | The uniform distribution over the provided @(low, high)@ interval,
-- delegating to mwc-random's @uniformR@.
--
-- >>> sample (uniformR (0, 1)) gen
-- 0.44984153252922365
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR bounds = Prob (QMWC.uniformR bounds)
{-# INLINABLE uniformR #-}

-- | The discrete uniform distribution.
--
-- Each position of the container is equally likely to be drawn.
--
-- >>> sample (discreteUniform [0..10]) gen
-- 6
-- >>> sample (discreteUniform "abcdefghijklmnopqrstuvwxyz") gen
-- 'a'
--
-- Calls 'error' if the container is empty.
discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
discreteUniform cs = case F.toList cs of
  [] -> error "mwc-probability: discreteUniform: empty container"
  xs -> do
    -- Flatten once instead of traversing the container for both its length
    -- and its elements.
    j <- uniformR (0, length xs - 1)
    return (xs !! j)
{-# INLINABLE discreteUniform #-}

-- | The standard normal or Gaussian distribution with mean 0 and standard
-- deviation 1.
standardNormal :: PrimMonad m => Prob m Double
standardNormal = Prob MWC.Dist.standard
{-# INLINABLE standardNormal #-}

-- | The normal or Gaussian distribution with specified mean and standard
-- deviation.
--
-- Note that `sd` should be positive.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob $ MWC.Dist.normal m sd
{-# INLINABLE normal #-}

-- | The log-normal distribution: the exponential of a normal variate with
-- mean @m@ and standard deviation @sd@.
--
-- Note that @m@ and @sd@ parameterize the underlying normal distribution;
-- they are not the mean and standard deviation of the log-normal samples
-- themselves.
--
-- Note that `sd` should be positive.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = exp <$> normal m sd
{-# INLINABLE logNormal #-}

-- | The exponential distribution with provided rate parameter (its mean is
-- @1 / r@).
--
-- Note that `r` should be positive.
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob $ MWC.Dist.exponential r
{-# INLINABLE exponential #-}

-- | The Laplace or double-exponential distribution with location @mu@ and
-- standard deviation @sigma@.
--
-- Note that @sigma@ is the standard deviation rather than the usual Laplace
-- scale parameter: the scale used internally is @b = sigma / sqrt 2@ (a
-- Laplace variate with scale @b@ has variance @2 * b ^ 2@). Samples are
-- drawn by inverting the CDF at a draw from @uniformR (-0.5, 0.5)@.
--
-- Note that `sigma` should be positive.
laplace :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
laplace mu sigma = do
  u <- uniformR (-0.5, 0.5)
  let b = sigma / sqrt 2
  return $ mu - b * signum u * log (1 - 2 * abs u)
{-# INLINABLE laplace #-}

-- | The Weibull distribution with provided shape and scale parameters, in
-- that order.
--
-- Note that both parameters should be positive.
--
-- With shape @a@ and scale @b@ the CDF is @1 - exp (-((x / b) ** a))@; for
-- @a = 1@ this is the exponential distribution with mean @b@.
weibull :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
weibull a b = do
  u <- uniform
  -- Inverse-CDF sampling: @b * e ** (1 / a)@ with @e@ standard exponential.
  -- The exponent must be grouped (here via @recip a@): @e ** 1/a@ would
  -- parse as @(e ** 1) / a@, because '**' binds tighter than '/'.
  -- 'uniform' draws floating-point values from (0, 1], so @- log u@ is a
  -- finite standard-exponential variate (whereas @- log (1 - u)@ would be
  -- infinite for @u == 1@).
  return $ b * negate (log u) ** recip a
{-# INLINABLE weibull #-}

-- | The gamma distribution with shape parameter `a` and scale parameter `b`.
--
-- This is the parameterization used more traditionally in frequentist
-- statistics. It has the following corresponding probability density
-- function:
--
-- > f(x; a, b) = x ^ (a - 1) * exp (- x / b) / (Gamma(a) * b ^ a)
--
-- Note that both parameters should be positive.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob $ MWC.Dist.gamma a b
{-# INLINABLE gamma #-}

-- | The inverse-gamma distribution, sampled as the reciprocal of a
-- @'gamma' a b@ variate.
--
-- Because 'gamma' takes a /scale/ parameter @b@, the result has shape @a@
-- and scale @1 / b@ in the usual inverse-gamma parameterization, i.e. its
-- density is proportional to @x ** (-a - 1) * exp (-1 / (b * x))@.
--
-- Note that both parameters should be positive.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = recip <$> gamma a b
{-# INLINABLE inverseGamma #-}

-- | The Normal-Gamma distribution, marginalized over its precision.
--
-- A precision @tau@ is drawn from @'gamma' a b@ (shape @a@, scale @b@), and
-- the returned value is drawn from a normal with mean @mu@ and standard
-- deviation @1 / sqrt (lambda * tau)@. Only that normal sample is returned.
--
-- Note that the `lambda`, `a`, and `b` parameters should be positive.
normalGamma :: PrimMonad m => Double -> Double -> Double -> Double -> Prob m Double
normalGamma mu lambda a b = do
  tau <- gamma a b
  let xsd = sqrt (recip (lambda * tau))
  normal mu xsd
{-# INLINABLE normalGamma #-}

-- | The chi-square distribution with the specified degrees of freedom.
--
-- Note that `k` should be positive.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob $ MWC.Dist.chiSquare k
{-# INLINABLE chiSquare #-}

-- | The beta distribution with the specified shape parameters.
--
-- Note that both parameters should be positive.
--
-- Sampled as @u / (u + w)@ for independent unit-scale gamma variates @u@ and
-- @w@ with shapes @a@ and @b@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = do
  u <- gamma a 1
  w <- gamma b 1
  return $ u / (u + w)
{-# INLINABLE beta #-}

-- | The Pareto distribution with specified index `a` and minimum `xmin`
-- parameters.
--
-- Sampled as @xmin * exp y@ where @y@ is exponential with rate @a@.
--
-- Note that both parameters should be positive.
pareto :: PrimMonad m => Double -> Double -> Prob m Double
pareto a xmin = do
  y <- exponential a
  return $ xmin * exp y
{-# INLINABLE pareto #-}

-- | The Dirichlet distribution with the provided concentration parameters.
-- The dimension of the distribution is determined by the number of
-- concentration parameters supplied.
--
-- >>> sample (dirichlet [0.1, 1, 10]) gen
-- [1.2375387187120799e-5,3.4952460651813816e-3,0.9964923785476316]
--
-- Note that all concentration parameters should be positive.
dirichlet
  :: (Traversable f, PrimMonad m) => f Double -> Prob m (f Double)
dirichlet as = do
  zs <- traverse (`gamma` 1) as
  -- Normalize independent gamma draws; the total is computed once.
  let total = sum zs
  return $ fmap (/ total) zs
{-# INLINABLE dirichlet #-}

-- | The symmetric Dirichlet distribution with dimension `n`. The provided
-- concentration parameter is simply replicated `n` times.
--
-- Note that `a` should be positive.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n a = dirichlet (replicate n a)
{-# INLINABLE symmetricDirichlet #-}

-- | The Bernoulli distribution with success probability `p`.
--
-- 'uniform' draws from (0, 1], so the comparison uses '<=': that way
-- @bernoulli 1@ always succeeds (with '<' a draw of exactly 1 would fail)
-- and @bernoulli 0@ never does.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (<= p) <$> uniform
{-# INLINABLE bernoulli #-}

-- | The binomial distribution with number of trials `n` and success
-- probability `p`, computed by counting successes among @n@ independent
-- 'bernoulli' trials.
--
-- >>> sample (binomial 10 0.3) gen
-- 4
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = fmap (length . filter id) $ replicateM n (bernoulli p)
{-# INLINABLE binomial #-}

-- | The negative binomial distribution with success count @n@ and success
-- probability @p@: the number of failures observed before the @n@-th success.
--
-- Sampled as a gamma-Poisson mixture: a rate is drawn from
-- @'gamma' n ((1 - p) / p)@ (shape, scale), then a 'poisson' count with that
-- rate, giving mean @n * (1 - p) / p@.
--
-- >>> sample (negativeBinomial 10 0.3) gen
-- 21
negativeBinomial :: (PrimMonad m, Integral a) => a -> Double -> Prob m Int
negativeBinomial n p = do
  y <- gamma (fromIntegral n) ((1 - p) / p)
  poisson y
{-# INLINABLE negativeBinomial #-}

-- | The multinomial distribution of `n` trials and category probabilities
-- `ps`.
--
-- The result is the list of the @n@ sampled category indices (0-based, in
-- draw order), not per-category counts.
--
-- Note that the supplied probability container should consist of non-negative
-- values but is not required to sum to one. Whenever at least one draw is
-- made, 'error' is called if the weights do not have a positive total (for
-- example an empty or all-zero container).
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = replicateM n draw
  where
    (cumulative, total) = runningTotals (F.toList ps)

    -- z lies in (0, total] and can equal total exactly, so the search uses
    -- '>=': with '>' a draw of exactly total matched no category. Because
    -- z > 0, a zero-weight category (whose running total equals that of its
    -- predecessor) is never chosen.
    draw = do
      z <- uniformR (0, total)
      case findIndex (>= z) cumulative of
        Just g | total > 0 -> return g
        _ -> error "mwc-probability: invalid probability vector"

    -- Note: this is significantly faster than any
    -- of the recursions one might write by hand. The total is the last
    -- running sum (not a separately computed 'sum'), so the largest possible
    -- draw always equals the final cumulative weight.
    runningTotals :: Num a => [a] -> ([a], a)
    runningTotals xs =
      let adds = scanl1 (+) xs
      in (adds, if null adds then 0 else last adds)
{-# INLINABLE multinomial #-}

-- | Generalized Student's t distribution with location parameter `m`, scale
-- parameter `s`, and degrees of freedom `k`: the law of @m + s * t@ where
-- @t@ follows a standard Student's t distribution with @k@ degrees of
-- freedom.
--
-- Note that the `s` and `k` parameters should be positive.
gstudent :: PrimMonad m => Double -> Double -> Double -> Prob m Double
gstudent m s k = do
  -- Scale mixture of normals: draw a precision tau from a gamma with shape
  -- k / 2 and scale 2 / (k * s ^ 2), then a normal with standard deviation
  -- 1 / sqrt tau. The gamma scale must use s ^ 2 for s to act as a scale;
  -- using s itself would give the result a scale of 1 / sqrt s.
  tau <- gamma (k / 2) (2 / (k * s * s))
  normal m (sqrt (recip tau))
{-# INLINABLE gstudent #-}

-- | Student's t distribution with `k` degrees of freedom.
--
-- Note that `k` should be positive.
--
-- Equivalent to @'gstudent' 0 1 k@.
student :: PrimMonad m => Double -> Prob m Double
student k = gstudent 0 1 k
{-# INLINABLE student #-}

-- | An isotropic or spherical Gaussian distribution with specified mean
-- vector and scalar standard deviation parameter: each component is drawn
-- independently from a normal centred on the matching mean.
--
-- Note that `sd` should be positive.
isoNormal
  :: (Traversable f, PrimMonad m) => f Double -> Double -> Prob m (f Double)
isoNormal ms sd = traverse (\mu -> normal mu sd) ms
{-# INLINABLE isoNormal #-}

-- | The inverse Gaussian (also known as Wald) distribution with mean parameter
-- @mu@ and shape parameter @lambda@. Note the argument order: the shape
-- @lambda@ comes first, then the mean @mu@.
--
-- Samples are produced by transforming a squared standard normal and picking
-- one of the two resulting roots at random (Michael, Schucany and Haas).
--
-- Note that both @mu@ and @lambda@ should be positive.
inverseGaussian :: PrimMonad m => Double -> Double -> Prob m Double
inverseGaussian lambda mu = do
  nu <- standardNormal
  let y      = nu ** 2
      root   = sqrt (4 * mu * lambda * y + mu ** 2 * y ** 2)
      x      = mu * (1 + 1 / (2 * lambda) * (mu * y - root))
      accept = mu / (mu + x)
  z <- uniform
  if z <= accept then pure x else pure (mu ** 2 / x)
{-# INLINABLE inverseGaussian #-}

-- | The Poisson distribution with rate parameter `l`, sampled from an
-- mwc-random condensed lookup table built once per 'poisson' value.
--
-- Note that `l` should be positive.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l =
  let table = tablePoisson l
  in Prob (genFromTable table)
{-# INLINABLE poisson #-}

-- | A categorical distribution defined by the supplied probabilities; the
-- result is a 0-based index into the container.
--
-- Note that the supplied probability container should consist of non-negative
-- values but is not required to sum to one.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = multinomial 1 ps >>= \draws -> case draws of
  [x] -> return x
  _   -> error "mwc-probability: invalid probability vector"
{-# INLINABLE categorical #-}

-- | A categorical distribution defined by the supplied support.
--
-- Note that the supplied probabilities should be non-negative, but are not
-- required to sum to one.
--
-- >>> samples 10 (discrete [(0.1, "yeah"), (0.9, "nah")]) gen
-- ["yeah","nah","nah","nah","nah","yeah","nah","nah","nah","nah"]
discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
discrete d = do
  let (ps, xs) = unzip (F.toList d)
  -- 'categorical' returns an index into @ps@, which has the same length as
  -- @xs@, so the lookup below stays in range.
  idx <- categorical ps
  pure (xs !! idx)
{-# INLINABLE discrete #-}

-- | The Zipf (zeta) distribution with exponent @a@: it is supported on the
-- positive integers, with @P(X = k)@ proportional to @k@ raised to the
-- power @-a@.
--
-- Note that `a` must be greater than 1 ('error' is called otherwise), and
-- that values close to 1 should be avoided as they are very computationally
-- intensive and produce extremely large variates, which may overflow a
-- bounded result type such as 'Int'.
--
-- >>> samples 10 (zipf 1.1) gen
-- [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
--
-- >>> samples 10 (zipf 1.5) gen
-- [19,3,3,1,1,2,1,191,2,1]
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a
  -- For a <= 1 (or NaN) the sampler below either never accepts or returns
  -- nonsense, so reject such exponents up front.
  | not (a > 1) =
      error "mwc-probability: zipf requires an exponent greater than 1"
  | otherwise = go
  where
    -- Rejection sampler (Devroye): propose floor (u ** (-1 / (a - 1))) and
    -- accept it using the ratio test below.
    b = 2 ** (a - 1)
    go = do
      u <- uniform
      v <- uniform
      let xInt = floor (u ** (- 1 / (a - 1)))
          x = fromIntegral xInt
          t = (1 + 1 / x) ** (a - 1)
      if v * x * (t - 1) / (b - 1) <= t / b
        then return xInt
        else go
{-# INLINABLE zipf #-}

-- | The Chinese Restaurant Process with concentration parameter `a` and number
-- of customers `n`.
--
-- See Griffiths and Ghahramani, 2011 for details.
--
-- >>> sample (crp 1.8 50) gen
-- [22,10,7,1,2,2,4,1,1]
--
-- The result lists the number of customers at each occupied table, in the
-- order in which the tables were opened. A non-positive number of customers
-- yields the empty list.
crp
  :: PrimMonad m
  => Double  -- ^ concentration parameter (> 0)
  -> Int     -- ^ number of customers
  -> Prob m [Integer]
crp a n
  -- Without this guard, n <= 0 never satisfied the loop's stopping test.
  | n < 1 = pure []
  | otherwise = do
      ts <- go crpInitial 1
      pure $ F.toList (fmap getSum ts)
  where
    -- Seat the remaining customers one at a time; @i@ customers are already
    -- seated when @go acc i@ runs.
    go acc i
      | i >= n = pure acc
      | otherwise = do
          acc' <- crpSingle i acc a
          go acc' (i + 1)
{-# INLINABLE crp #-}

-- | Update step of the CRP: seat one more customer, given that @i@
-- customers are already seated at the tables @zs@.
--
-- The new customer joins existing table @k@ with weight proportional to its
-- occupancy, or opens a new table with weight proportional to @a@. The
-- shared denominator @(i - 1 + a)@ does not change the outcome, because
-- 'categorical' does not require its weights to sum to one.
crpSingle :: (PrimMonad m, Integral b)
  => Int
  -> CRPTables (Sum b)
  -> Double
  -> Prob m (CRPTables (Sum b))
crpSingle i zs a = do
  zn1 <- categorical probs
  pure $ crpInsert zn1 zs
  where
    -- Weights of the existing tables (in key order) followed by the weight
    -- of a new table, whose index is therefore the current number of tables.
    -- This relies on the table keys being exactly 0 .. k - 1.
    probs = pms <> [pm1]
    acc m = fromIntegral m / (fromIntegral i - 1 + a)
    pms = F.toList $ fmap (acc . getSum) zs
    pm1 = a / (fromIntegral i - 1 + a)

-- | Tables at the Chinese Restaurant: occupancy keyed by table index.
newtype CRPTables c = CRP
  { getCRPTables :: IM.IntMap c
  } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)

-- | Initial state of the CRP: one customer sitting at table #0.
crpInitial :: CRPTables (Sum Integer)
crpInitial = crpInsert 0 mempty

-- | Seat one customer at table @k@, opening the table if it is not yet
-- occupied.
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP $ IM.insertWith (<>) k (Sum 1) ts