bnp

Some older Bayesian nonparametrics research.
Log | Files | Refs | README | LICENSE

commit 88bf84e5e14cb7be679b940636980178cd9b4e43
parent b7d1454a2be517ecbfd79b46dd580c4359824796
Author: Jared Tobin <jared@jtobin.ca>
Date:   Tue,  9 Feb 2016 11:16:14 +1300

Add some FMM stuff.

Diffstat:
Adirichlet-process-mixture/src/Generative.hs | 51+++++++++++++++++++++++++++++++++++++++++++++++++++
Adirichlet-process-mixture/src/explore_generative.r | 12++++++++++++
Adirichlet-process-mixture/src/spiral.r | 18++++++++++++++++++
3 files changed, 81 insertions(+), 0 deletions(-)

diff --git a/dirichlet-process-mixture/src/Generative.hs b/dirichlet-process-mixture/src/Generative.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE NoMonomorphismRestriction #-} + +import Control.Monad +import System.Random.MWC.Probability + +mixing :: Int -> Prob IO [Double] +mixing k = do + a <- inverseGamma 1 1 + symmetricDirichlet k a + +params :: Int -> Double -> Double -> Prob IO [(Double, Double)] +params k muy vary = do + l <- normal muy (sqrt vary) + r <- gamma 1 (recip (sqrt vary)) + mus <- replicateM k (normal l (sqrt (recip r))) + + b <- inverseGamma 1 1 + w <- gamma 1 vary + ss <- replicateM k (gamma b (recip w)) + + return $ zip mus ss + +fmm :: Int -> Double -> Double -> Prob IO [Double] +fmm k muy vary = do + (mus, ss) <- fmap unzip (params k muy vary) + pis <- mixing k + + xs <- zipWithM normal mus (fmap (sqrt . recip) ss) + + return $ zipWith (*) pis xs + +conditional n k muy vary = do + (mus, ss) <- fmap unzip (params k muy vary) + let fs = zipWithM normal mus (fmap (sqrt . recip) ss) + pis <- mixing k + + replicateM n $ do + f <- fs + return $ zipWith (*) pis f + +main :: IO () +main = do + samples <- withSystemRandom . asGenIO $ \gen -> + -- replicateM 5000 (sample (fmm 5 1 1) gen) + sample (conditional 5000 5 1 1) gen + let pretty = putStrLn . filter (`notElem` "[]") . show + mapM_ pretty samples + + + + diff --git a/dirichlet-process-mixture/src/explore_generative.r b/dirichlet-process-mixture/src/explore_generative.r @@ -0,0 +1,12 @@ +require(ggplot2) +require(reshape) +require(dplyr) + +d = read.csv('tmp.dat', header = F, colClasses = 'numeric') + +melted = melt(d) + +mixturePriorPlot = + ggplot(melted, aes(value, fill = variable, colour = variable)) + + geom_density(alpha = 0.2) + diff --git a/dirichlet-process-mixture/src/spiral.r b/dirichlet-process-mixture/src/spiral.r @@ -0,0 +1,18 @@ +require(MASS) +require(scatterplot3d) + +set.seed(42) + +n = 800 +t = sort(runif(n) * 4 * pi) + +x = (13 - 0.5 * t) * cos(t) +y = (13 - 0.5 * t) * sin(t) +Z = mvrnorm(n, mu = rep(0, 3), Sigma = 0.5 * diag(3)) + +X = matrix(data = c(x, y, t), nrow = n, ncol = 3) + Z + +# visualization + +# quartz() +# scatterplot3d(X, highlight.3d = T, pch = 19)