commit c975daccadd9549baf4b693e55893c568899c468
parent 66b30451f61e5d2ecef7d9a22e01798d33fed043
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 22 Feb 2016 13:56:33 +1300
Add DPMM.
Diffstat:
2 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/dirichlet-process-mixture/src/dpmm.r b/dirichlet-process-mixture/src/dpmm.r
@@ -0,0 +1,41 @@
+set.seed(42)
+
+BNP_DIR = "/Users/jtobin/projects/bnp"
+SBP_SRC = paste(BNP_DIR, "stick-breaking-process/src/sbp.r", sep = "/")
+
+require(gtools)
+require(magrittr)
+source(SBP_SRC)
+
+mixing_model = function(n, a) {
+ if (n <= 1) {
+ stop("need > 1 observation.")
+ } else {
+ sbp(n - 1, a)
+ }
+ }
+
+label_model = function(n, p) drop(rmultinom(1, size = n, prob = p))
+location_model = function(k, l, r) rnorm(k, l, 1 / r)
+precision_model = function(k, b, w) rgamma(k, b, 1 / w)
+
+parameter_model = function(n, a) {
+ p = mixing_model(n, a)
+ k = length(p)
+ c = label_model(n, p)
+ mu = location_model(k, 0, 0.1)
+ s = precision_model(k, 1, 1)
+ list(c, mu, s)
+ }
+
+data_model = function(config) {
+ sampler = function(y, m, s) rnorm(y, m, 1 / s)
+ mapply(sampler, config[[1]], config[[2]], config[[3]])
+ }
+
+model = function(n, a) {
+ clusters = parameter_model(n, a) %>% data_model
+ nonempty = unlist(lapply(clusters, function(x) length(x) != 0))
+ clusters[nonempty]
+ }
+
diff --git a/stick-breaking-process/src/sbp.r b/stick-breaking-process/src/sbp.r
@@ -3,7 +3,7 @@ sbp = function(n, a) {
for (j in seq(n)) {
bundle = snap(bundle[[1]], bundle[[2]], a)
}
- bundle[[2]]
+ c(bundle[[2]], 1 - sum(bundle[[2]]))
}
snap = function(acc, bun, a) {