commit c1f935cdd10d054b41c74b9c6d7ed885712b9b9e
parent 0a2a9085a46c8ac3d9b40eb3471ad8a783ffa16b
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 6 Oct 2015 16:32:25 +1300
Add 'samples', remove inlines.
Diffstat:
3 files changed, 9 insertions(+), 29 deletions(-)
diff --git a/mwc-probability.cabal b/mwc-probability.cabal
@@ -1,5 +1,5 @@
name: mwc-probability
-version: 0.3.3.0
+version: 0.3.4.0
homepage: http://github.com/jtobin/mwc-probability
license: MIT
license-file: LICENSE
@@ -49,7 +49,7 @@ library
default-language: Haskell2010
hs-source-dirs: src
build-depends:
- base >= 4.7 && < 4.8
+ base
, mwc-random
, primitive
, transformers
diff --git a/src/System/Random/MWC/Probability.hs b/src/System/Random/MWC/Probability.hs
@@ -3,6 +3,8 @@
module System.Random.MWC.Probability (
module MWC
, Prob(..)
+ , samples
+
, uniform
, uniformR
, discreteUniform
@@ -35,17 +37,18 @@ import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable
+-- | A probability distribution characterized by a sampling function.
newtype Prob m a = Prob { sample :: Gen (PrimState m) -> m a }
+samples :: PrimMonad m => Prob m a -> Int -> Gen (PrimState m) -> m [a]
+samples model n gen = replicateM n (sample model gen)
+
instance Monad m => Functor (Prob m) where
fmap h (Prob f) = Prob $ liftM h . f
- {-# INLINE fmap #-}
instance Monad m => Applicative (Prob m) where
pure = return
(<*>) = ap
- {-# INLINE pure #-}
- {-# INLINE (<*>) #-}
instance (Applicative m, Monad m, Num a) => Num (Prob m a) where
(+) = liftA2 (+)
@@ -60,79 +63,61 @@ instance Monad m => Monad (Prob m) where
m >>= h = Prob $ \g -> do
z <- sample m g
sample (h z) g
- {-# INLINE return #-}
- {-# INLINE (>>=) #-}
instance MonadTrans Prob where
lift m = Prob $ const m
- {-# INLINE lift #-}
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob QMWC.uniform
-{-# INLINE uniform #-}
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR r = Prob $ QMWC.uniformR r
-{-# INLINE uniformR #-}
discreteUniform :: PrimMonad m => [a] -> Prob m a
discreteUniform cs = do
j <- uniformR (0, length cs - 1)
return $ cs !! j
-{-# INLINE discreteUniform #-}
standard :: PrimMonad m => Prob m Double
standard = Prob MWC.Dist.standard
-{-# INLINE standard #-}
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob $ MWC.Dist.normal m sd
-{-# INLINE normal #-}
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = exp <$> normal m sd
-{-# INLINE logNormal #-}
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob $ MWC.Dist.exponential r
-{-# INLINE exponential #-}
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob $ MWC.Dist.gamma a b
-{-# INLINE gamma #-}
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = recip <$> gamma a b
-{-# INLINE inverseGamma #-}
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob $ MWC.Dist.chiSquare k
-{-# INLINE chiSquare #-}
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = do
u <- gamma a 1
w <- gamma b 1
return $ u / (u + w)
-{-# INLINE beta #-}
dirichlet :: PrimMonad m => [Double] -> Prob m [Double]
dirichlet as = do
zs <- mapM (`gamma` 1) as
return $ map (/ sum zs) zs
-{-# INLINE dirichlet #-}
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n a = dirichlet (replicate n a)
-{-# INLINE symmetricDirichlet #-}
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (< p) <$> uniform
-{-# INLINE bernoulli #-}
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = liftM (length . filter id) $ replicateM n (bernoulli p)
-{-# INLINE binomial #-}
multinomial :: PrimMonad m => Int -> [Double] -> Prob m [Int]
multinomial n ps = do
@@ -141,22 +126,18 @@ multinomial n ps = do
z <- uniform
let Just g = findIndex (> z) cumulative
return g
-{-# INLINE multinomial #-}
t :: PrimMonad m => Double -> Double -> Double -> Prob m Double
t m s k = do
sd <- sqrt <$> inverseGamma (k / 2) (s * 2 / k)
normal m sd
-{-# INLINE t #-}
isoGauss :: PrimMonad m => [Double] -> Double -> Prob m [Double]
isoGauss ms sd = mapM (\m -> normal m sd) ms
-{-# INLINE isoGauss #-}
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob $ genFromTable table where
table = tablePoisson l
-{-# INLINE poisson #-}
categorical :: PrimMonad m => [Double] -> Prob m Int
categorical ps = do
@@ -164,5 +145,4 @@ categorical ps = do
case xs of
[x] -> return x
_ -> error "categorical: invalid return value"
-{-# INLINE categorical #-}
diff --git a/stack.yaml b/stack.yaml
@@ -2,4 +2,4 @@ flags: {}
packages:
- '.'
extra-deps: []
-resolver: lts-2.21
+resolver: lts-3.3