commit c797f2a3b2edcb87b1a95a1c6091dc0f09492583
parent 46c6ae601d4af640215eb9aa86e6b53487e4d6e4
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 21 Oct 2013 16:13:10 +1300
Bump up the polymorphism factor.
Diffstat:
2 files changed, 24 insertions(+), 16 deletions(-)
diff --git a/src/Examples.hs b/src/Examples.hs
@@ -13,7 +13,7 @@ import System.Random.MWC
import System.Random.MWC.Distributions
genGammaSamples
- :: PrimMonad m
+ :: (Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
@@ -22,7 +22,7 @@ genGammaSamples
genGammaSamples n a b g = replicateM n $ gamma a b g
genNormalSamples
- :: PrimMonad m
+ :: (Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
@@ -37,7 +37,7 @@ genNormalSamples n m t g = replicateM n $ normal m (1 / t) g
-- t ~ gamma(a, b)
-- (X, t) ~ NormalGamma(mu, lambda, a, b)
normalGammaMeasure
- :: (Fractional r, PrimMonad m)
+ :: (Fractional r, Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
@@ -58,7 +58,7 @@ normalGammaMeasure n a b mu lambda g = do
-- various return types. Here we have a probability distribution over hash
-- maps.
altNormalGammaMeasure
- :: (Fractional r, PrimMonad m)
+ :: (Fractional r, Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
@@ -76,7 +76,7 @@ altNormalGammaMeasure n a b mu lambda g = do
return $ HashMap.fromList [("location", location), ("precision", precision)]
normalNormalGammaMeasure
- :: (Fractional r, PrimMonad m)
+ :: (Fractional r, Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
@@ -90,7 +90,7 @@ normalNormalGammaMeasure n a b mu lambda g = do
fromObservations normalSamples
altNormalNormalGammaMeasure
- :: (Fractional r, PrimMonad m)
+ :: (Fractional r, Applicative m, PrimMonad m)
=> Int
-> Double
-> Double
diff --git a/src/Measurable/Generic.hs b/src/Measurable/Generic.hs
@@ -5,6 +5,7 @@ module Measurable.Generic where
import Control.Applicative
import Control.Monad
import Control.Monad.Trans.Cont
+import Data.Foldable (Foldable)
import qualified Data.Foldable as Foldable
import Data.List
import Data.Monoid
@@ -19,7 +20,10 @@ measureT :: MeasureT r m a -> (a -> m r) -> m r
measureT = runContT
-- | Create a measure from observations (samples) from some distribution.
-fromObservations :: (Monad m, Fractional r) => [a] -> MeasureT r m a
+fromObservations
+ :: (Applicative m, Monad m, Fractional r, Traversable f)
+ => f a
+ -> MeasureT r m a
fromObservations xs = ContT (`weightedAverageM` xs)
-- A mass function is close to universal when dealing with discrete objects, but
@@ -30,12 +34,12 @@ fromObservations xs = ContT (`weightedAverageM` xs)
-- Maybe we can use something like an 'observed support'. You can probably get
-- inspiration from how the Dirichlet process is handled in practice.
fromMassFunction
- :: (Num r, Applicative f)
+ :: (Num r, Applicative f, Traversable t)
=> (a -> f r)
- -> [a]
+ -> t a
-> MeasureT r f a
fromMassFunction p support = ContT $ \f ->
- fmap sum . traverse (liftA2 (liftA2 (*)) f p) $ support
+ fmap Foldable.sum . traverse (liftA2 (liftA2 (*)) f p) $ support
-- | Expectation is obtained by integrating against the identity function. We
-- provide an additional function for mapping the input type to the output
@@ -73,19 +77,23 @@ containing xs x | x `Set.member` set = 1
where set = Set.fromList xs
-- | Simple average.
-average :: Fractional a => [a] -> a
-average xs = fst $ foldl'
+average :: (Fractional a, Foldable f) => f a -> a
+average xs = fst $ Foldable.foldl'
(\(!m, !n) x -> (m + (x - m) / fromIntegral (n + 1), n + 1)) (0, 0) xs
{-# INLINE average #-}
-- | Weighted average.
-weightedAverage :: Fractional c => (a -> c) -> [a] -> c
-weightedAverage f = average . map f
+weightedAverage :: (Functor f, Foldable f, Fractional c) => (a -> c) -> f a -> c
+weightedAverage f = average . fmap f
{-# INLINE weightedAverage #-}
-- | Monadic weighted average.
-weightedAverageM :: (Fractional c, Monad m) => (a -> m c) -> [a] -> m c
-weightedAverageM f = liftM average . mapM f
+weightedAverageM
+ :: (Fractional c, Traversable f, Monad m, Applicative m)
+ => (a -> m c)
+ -> f a
+ -> m c
+weightedAverageM f = liftM average . traverse f
{-# INLINE weightedAverageM #-}