mwc-probability

Sampling function-based probability distributions.
Log | Files | Refs | README | LICENSE

commit c1f935cdd10d054b41c74b9c6d7ed885712b9b9e
parent 0a2a9085a46c8ac3d9b40eb3471ad8a783ffa16b
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue,  6 Oct 2015 16:32:25 +1300

Add 'samples', remove inlines.

Diffstat:
Mmwc-probability.cabal | 4++--
Msrc/System/Random/MWC/Probability.hs | 32++++++--------------------------
Mstack.yaml | 2+-
3 files changed, 9 insertions(+), 29 deletions(-)

diff --git a/mwc-probability.cabal b/mwc-probability.cabal @@ -1,5 +1,5 @@ name: mwc-probability -version: 0.3.3.0 +version: 0.3.4.0 homepage: http://github.com/jtobin/mwc-probability license: MIT license-file: LICENSE @@ -49,7 +49,7 @@ library default-language: Haskell2010 hs-source-dirs: src build-depends: - base >= 4.7 && < 4.8 + base , mwc-random , primitive , transformers diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs @@ -3,6 +3,8 @@ module System.Random.MWC.Probability ( module MWC , Prob(..) + , samples + , uniform , uniformR , discreteUniform @@ -35,17 +37,18 @@ import qualified System.Random.MWC as QMWC import qualified System.Random.MWC.Distributions as MWC.Dist import System.Random.MWC.CondensedTable +-- | A probability distribution characterized by a sampling function. newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a } +samples :: PrimMonad m => Prob m a -> Int -> Gen (PrimState m) -> m [a] +samples model n gen = replicateM n (sample model gen) + instance Monad m => Functor (Prob m) where fmap h (Prob f) = Prob $ liftM h . f - {-# INLINE fmap #-} instance Monad m => Applicative (Prob m) where pure = return (<*>) = ap - {-# INLINE pure #-} - {-# INLINE (<*>) #-} instance (Applicative m, Monad m, Num a) => Num (Prob m a) where (+) = liftA2 (+) @@ -60,79 +63,61 @@ instance Monad m => Monad (Prob m) where m >>= h = Prob $ \g -> do z <- sample m g sample (h z) g - {-# INLINE return #-} - {-# INLINE (>>=) #-} instance MonadTrans Prob where lift m = Prob $ const m - {-# INLINE lift #-} uniform :: (PrimMonad m, Variate a) => Prob m a uniform = Prob QMWC.uniform -{-# INLINE uniform #-} uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a uniformR r = Prob $ QMWC.uniformR r -{-# INLINE uniformR #-} discreteUniform :: PrimMonad m => [a] -> Prob m a discreteUniform cs = do j <- uniformR (0, length cs - 1) return $ cs !! j -{-# INLINE discreteUniform #-} standard :: PrimMonad m => Prob m Double standard = Prob MWC.Dist.standard -{-# INLINE standard #-} normal :: PrimMonad m => Double -> Double -> Prob m Double normal m sd = Prob $ MWC.Dist.normal m sd -{-# INLINE normal #-} logNormal :: PrimMonad m => Double -> Double -> Prob m Double logNormal m sd = exp <$> normal m sd -{-# INLINE logNormal #-} exponential :: PrimMonad m => Double -> Prob m Double exponential r = Prob $ MWC.Dist.exponential r -{-# INLINE exponential #-} gamma :: PrimMonad m => Double -> Double -> Prob m Double gamma a b = Prob $ MWC.Dist.gamma a b -{-# INLINE gamma #-} inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double inverseGamma a b = recip <$> gamma a b -{-# INLINE inverseGamma #-} chiSquare :: PrimMonad m => Int -> Prob m Double chiSquare k = Prob $ MWC.Dist.chiSquare k -{-# INLINE chiSquare #-} beta :: PrimMonad m => Double -> Double -> Prob m Double beta a b = do u <- gamma a 1 w <- gamma b 1 return $ u / (u + w) -{-# INLINE beta #-} dirichlet :: PrimMonad m => [Double] -> Prob m [Double] dirichlet as = do zs <- mapM (`gamma` 1) as return $ map (/ sum zs) zs -{-# INLINE dirichlet #-} symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double] symmetricDirichlet n a = dirichlet (replicate n a) -{-# INLINE symmetricDirichlet #-} bernoulli :: PrimMonad m => Double -> Prob m Bool bernoulli p = (< p) <$> uniform -{-# INLINE bernoulli #-} binomial :: PrimMonad m => Int -> Double -> Prob m Int binomial n p = liftM (length . filter id) $ replicateM n (bernoulli p) -{-# INLINE binomial #-} multinomial :: PrimMonad m => Int -> [Double] -> Prob m [Int] multinomial n ps = do @@ -141,22 +126,18 @@ multinomial n ps = do z <- uniform let Just g = findIndex (> z) cumulative return g -{-# INLINE multinomial #-} t :: PrimMonad m => Double -> Double -> Double -> Prob m Double t m s k = do sd <- sqrt <$> inverseGamma (k / 2) (s * 2 / k) normal m sd -{-# INLINE t #-} isoGauss :: PrimMonad m => [Double] -> Double -> Prob m [Double] isoGauss ms sd = mapM (\m -> normal m sd) ms -{-# INLINE isoGauss #-} poisson :: PrimMonad m => Double -> Prob m Int poisson l = Prob $ genFromTable table where table = tablePoisson l -{-# INLINE poisson #-} categorical :: PrimMonad m => [Double] -> Prob m Int categorical ps = do @@ -164,5 +145,4 @@ categorical ps = do case xs of [x] -> return x _ -> error "categorical: invalid return value" -{-# INLINE categorical #-} diff --git a/stack.yaml b/stack.yaml @@ -2,4 +2,4 @@ flags: {} packages: - '.' extra-deps: [] -resolver: lts-2.21 +resolver: lts-3.3