deanie

An embedded probabilistic programming language.
git clone git://git.jtobin.io/deanie.git
Log | Files | Refs | README | LICENSE

Measure.hs (2977B)


      1 
      2 module Deanie.Measure (
      3     measure
      4 
      5   -- * queries
      6 
      7   , integrate
      8   , expectation
      9   , variance
     10   , mgf
     11   , cgf
     12   , cdf
     13   ) where
     14 
     15 import Control.Monad
     16 import Data.List (foldl')
     17 import Deanie.Language
     18 import Control.Foldl (Fold)
     19 import Numeric.Integration.TanhSinh
     20 import Numeric.SpecFunctions
     21 
     22 newtype Measure a = Measure ((a -> Double) -> Double)
     23 
     24 integrate :: (a -> Double) -> Measure a -> Double
     25 integrate f (Measure nu) = nu f
     26 
     27 expectation :: Measure Double -> Double
     28 expectation = integrate id
     29 
     30 variance :: Measure Double -> Double
     31 variance nu = integrate (^ 2) nu - expectation nu ^ 2
     32 
     33 mgf :: Measure Double -> Double -> Double
     34 mgf mu t = integrate (\x -> exp (t * x)) mu
     35 
     36 cgf :: Measure Double -> Double -> Double
     37 cgf mu = log . mgf mu
     38 
     39 cdf :: Measure Double -> Double -> Double
     40 cdf nu x = integrate (negativeInfinity `to` x) nu where
     41   negativeInfinity :: Double
     42   negativeInfinity = negate (1 / 0)
     43 
     44   to :: (Num a, Ord a) => a -> a -> a -> a
     45   to a b x
     46     | x >= a && x <= b = 1
     47     | otherwise        = 0
     48 
     49 instance Functor Measure where
     50   fmap f nu = Measure $ \g ->
     51     integrate (g . f) nu
     52 
     53 instance Applicative Measure where
     54   pure x = Measure (\f -> f x)
     55   Measure h <*> Measure g = Measure $ \f ->
     56     h (\k -> g (f . k))
     57 
     58 instance Monad Measure where
     59   return x  = Measure (\f -> f x)
     60   rho >>= g = Measure $ \f ->
     61     integrate (\nu -> integrate f (g nu)) rho
     62 
     63 fromMassFunction :: Foldable f => (a -> Double) -> f a -> Measure a
     64 fromMassFunction f support = Measure $ \g ->
     65   foldl' (\acc x -> acc + f x * g x) 0 support
     66 
     67 fromDensityFunction :: (Double -> Double) -> Measure Double
     68 fromDensityFunction d = Measure $ \f ->
     69     quadratureTanhSinh (\x -> f x * d x)
     70   where
     71     quadratureTanhSinh = result . last . everywhere trap
     72 
     73 mbernoulli :: Double -> Measure Bool
     74 mbernoulli p = fromMassFunction (pmf p) [False, True] where
     75   pmf p x
     76     | p < 0 || p > 1 = 0
     77     | otherwise      = if x then p else 1 - p
     78 
     79 mbeta :: Double -> Double -> Measure Double
     80 mbeta a b = fromDensityFunction (density a b) where
     81   density a b p
     82     | p < 0 || p > 1 = 0
     83     | otherwise      = 1 / exp (logBeta a b) * p ** (a - 1) * (1 - p) ** (b - 1)
     84 
     85 mgamma :: Double -> Double -> Measure Double
     86 mgamma a b = fromDensityFunction (density a b) where
     87   density a b x
     88     | a < 0 || b < 0 = 0
     89     | otherwise  =
     90        b ** a / exp (logGamma a) * x ** (a - 1) * exp (negate (b * x))
     91 
     92 mgaussian :: Double -> Double -> Measure Double
     93 mgaussian m s = fromDensityFunction (density m s) where
     94   density m s x
     95     | s <= 0    = 0
     96     | otherwise =
     97         1 / (s * sqrt (2 * pi)) *
     98           exp (negate ((x - m) ^^ 2) / (2 * (s ^^ 2)))
     99 
    100 measure :: Program a -> Measure a
    101 measure = iterM $ \case
    102     ProgramF (InL term) -> evalAlg term
    103     ProgramF (InR term) -> join (runAp measure term)
    104   where
    105     evalAlg = \case
    106       BernoulliF p k  -> mbernoulli p >>= k
    107       BetaF a b k     -> mbeta a b >>= k
    108       GammaF a b k    -> mgamma a b >>= k
    109       GaussianF m s k -> mgaussian m s >>= k
    110