commit 88bf84e5e14cb7be679b940636980178cd9b4e43
parent b7d1454a2be517ecbfd79b46dd580c4359824796
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 9 Feb 2016 11:16:14 +1300
Add some FMM stuff.
Diffstat:
3 files changed, 81 insertions(+), 0 deletions(-)
diff --git a/dirichlet-process-mixture/src/Generative.hs b/dirichlet-process-mixture/src/Generative.hs
@@ -0,0 +1,51 @@
+{-# LANGUAGE NoMonomorphismRestriction #-}
+
+import Control.Monad
+import System.Random.MWC.Probability
+
+mixing :: Int -> Prob IO [Double]
+mixing k = do
+ a <- inverseGamma 1 1
+ symmetricDirichlet k a
+
+params :: Int -> Double -> Double -> Prob IO [(Double, Double)]
+params k muy vary = do
+ l <- normal muy (sqrt vary)
+ r <- gamma 1 (recip (sqrt vary))
+ mus <- replicateM k (normal l (sqrt (recip r)))
+
+ b <- inverseGamma 1 1
+ w <- gamma 1 vary
+ ss <- replicateM k (gamma b (recip w))
+
+ return $ zip mus ss
+
+fmm :: Int -> Double -> Double -> Prob IO [Double]
+fmm k muy vary = do
+ (mus, ss) <- fmap unzip (params k muy vary)
+ pis <- mixing k
+
+ xs <- zipWithM normal mus (fmap (sqrt . recip) ss)
+
+ return $ zipWith (*) pis xs
+
+conditional n k muy vary = do
+ (mus, ss) <- fmap unzip (params k muy vary)
+ let fs = zipWithM normal mus (fmap (sqrt . recip) ss)
+ pis <- mixing k
+
+ replicateM n $ do
+ f <- fs
+ return $ zipWith (*) pis f
+
+main :: IO ()
+main = do
+ samples <- withSystemRandom . asGenIO $ \gen ->
+ -- replicateM 5000 (sample (fmm 5 1 1) gen)
+ sample (conditional 5000 5 1 1) gen
+ let pretty = putStrLn . filter (`notElem` "[]") . show
+ mapM_ pretty samples
+
+
+
+
diff --git a/dirichlet-process-mixture/src/explore_generative.r b/dirichlet-process-mixture/src/explore_generative.r
@@ -0,0 +1,12 @@
+require(ggplot2)
+require(reshape)
+require(dplyr)
+
+d = read.csv('tmp.dat', header = F, colClasses = 'numeric')
+
+melted = melt(d)
+
+mixturePriorPlot =
+ ggplot(melted, aes(value, fill = variable, colour = variable)) +
+ geom_density(alpha = 0.2)
+
diff --git a/dirichlet-process-mixture/src/spiral.r b/dirichlet-process-mixture/src/spiral.r
@@ -0,0 +1,18 @@
+require(MASS)
+require(scatterplot3d)
+
+set.seed(42)
+
+n = 800
+t = sort(runif(n) * 4 * pi)
+
+x = (13 - 0.5 * t) * cos(t)
+y = (13 - 0.5 * t) * sin(t)
+Z = mvrnorm(n, mu = rep(0, 3), Sigma = 0.5 * diag(3))
+
+X = matrix(data = c(x, y, t), nrow = n, ncol = 3) + Z
+
+# visualization
+
+# quartz()
+# scatterplot3d(X, highlight.3d = T, pch = 19)