commit 309f8268dbe45adc940ffbf30209e33b644eb19a
parent 7c0df4efdb90e88688760fb18cad3fffd13b62fa
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 16 Mar 2016 18:33:51 +1300
Add collapsed sampler.
Diffstat:
2 files changed, 143 insertions(+), 5 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_conditional.r b/finite-gaussian-mixture/src/fmm_conditional.r
@@ -19,7 +19,7 @@ conditional_mixing_model = function(y, k, z, a) {
})
drop(rdirichlet(1, concentration))
- }
+}
conditional_label_model = function(y, p, m, s) {
scorer = function(mix, mu, prec) {
@@ -44,7 +44,7 @@ conditional_label_model = function(y, p, m, s) {
, MARGIN = 1
, function(row) { sample(seq_along(m), size = 1, prob = row) }
)
- }
+}
conditional_location_model = function(y, z, s, l, r) {
clustered = group_by(data.frame(value = y, L1 = z), L1)
@@ -75,7 +75,7 @@ conditional_location_model = function(y, z, s, l, r) {
v = 1 / (n * s + r)
mapply(rnorm, 1, m, sqrt(v))
- }
+}
conditional_precision_model = function(y, z, m, b, w) {
labelled = data.frame(value = y, L1 = z)
@@ -105,7 +105,7 @@ conditional_precision_model = function(y, z, m, b, w) {
bet = (w * b + ss) / a
mapply(function(a, b) rgamma(1, a, b), a, bet)
- }
+}
inverse_model = function(n, k, y, a, l, r, b, w) {
gibbs = function(p0, m0, s0) {
@@ -132,5 +132,5 @@ inverse_model = function(n, k, y, a, l, r, b, w) {
acc$l = c(acc$l, params$l)
}
acc
- }
+}
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_conditional_collapsed.r b/finite-gaussian-mixture/src/fmm_multivariate_conditional_collapsed.r
@@ -0,0 +1,138 @@
+require(mvtnorm)
+
+cluster_statistics = function(cluster, l, b, w) {
+  n = nrow(cluster)
+  ybar = colMeans(cluster)
+  centered = sweep(as.matrix(cluster), 2, ybar) # sweep centres row-wise
+  ss = t(centered) %*% centered
+  ln = (l + n * ybar) / (1 + n)
+  tn = solve(solve(w) + ss + n / (1 + n) * ((l - ybar) %*% t(l - ybar)))
+  bn = b + n
+  df = bn - ncol(cluster) + 1
+  mu = ln
+  coef = n / ((n + 1) * df)
+  v = coef * solve(tn)
+  list(
+      n = n, ln = ln, tn = tn, bn = bn
+    , df = df, mu = mu, v = v
+    )
+}
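+
+# A quick sanity check of the statistics above (a sketch; the prior values
+# here are placeholders, not part of this commit):
+#
+#   cs = cluster_statistics(
+#       data.frame(x = rnorm(10), y = rnorm(10))
+#     , l = rep(0, 2), b = 2, w = diag(1, 2))
+#   cs$df # b + n - d + 1 = 2 + 10 - 2 + 1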
+
+# FIXME (jtobin): more efficient to cache sufficient statistics in gibbs loop
+conditional_label_model = function(y, k, z, a, l, r, b, w) {
+  cluster_labels = seq(k)
+  rows = sample(seq(nrow(y)))
+
+  initial_clusters = sapply(
+      cluster_labels
+    , function(j) { y[which(z == j),] }
+    , simplify = F)
+
+  sufficient_statistics = lapply(
+      initial_clusters
+    , function(c) { cluster_statistics(c, l, b, w) })
+
+  relabel = function(i) {
+    old_label = z[i]
+    val = as.numeric(y[i,]) # dmvt wants a numeric vector, not a 1-row frame
+    y_censored = y[-i,]
+    z_censored = z[-i]
+    n_censored = sapply(
+        cluster_labels
+      , function(j) { length(which(z_censored == j)) })
+
+    score_by_cluster = function(j) {
+      sufficient_stats = if (j == old_label) {
+          cluster = y_censored[which(z_censored == j), ]
+          cluster_statistics(cluster, l, b, w)
+        } else {
+          sufficient_statistics[[j]]
+        }
+      dmvt(
+          val
+        , df = sufficient_stats$df
+        , sigma = sufficient_stats$v
+        , delta = sufficient_stats$mu
+        , log = T
+        )
+    }
+
+    scores = exp(sapply(cluster_labels, score_by_cluster))
+    weight = n_censored + a / k
+    probs = scores * weight / sum(scores * weight)
+    new_label = sample(cluster_labels, size = 1, prob = probs)
+
+    # MUTATION: <<- updates the enclosing z and statistics so subsequent
+    # relabels see them; plain <- would only bind local copies
+    z[i] <<- new_label
+    new_stats = cluster_statistics(y[which(z == new_label),], l, b, w)
+    sufficient_statistics[[new_label]] <<- new_stats
+
+    new_label
+  }
+  for (i in rows) { relabel(i) } # relabel in random order, mutating z
+  z
+}
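+
+# NB: the relabelling step implements the usual collapsed conditional for a
+# symmetric Dirichlet prior over the mixing weights,
+#
+#   p(z_i = j | z_-i, y) is proportional to (n_{-i,j} + a / k) * t_j(y_i),
+#
+# where t_j is component j's posterior predictive (a multivariate t) and
+# n_{-i,j} counts the points other than i currently assigned to j.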
+
+inverse_model = function(n, y, k, a, l, r, b, w) {
+  # FIXME (jtobin): add likelihood calculation
+  gibbs = function(z0) {
+    list(z = conditional_label_model(y, k, z0, a, l, r, b, w))
+  }
+  params = list(z = sample(seq(k), size = nrow(y), replace = T))
+  acc = params
+  # FIXME (jtobin): can use replicate
+  for (j in seq(n - 1)) {
+    params = gibbs(params$z)
+    acc$z = rbind(acc$z, params$z)
+  }
+  acc
+}
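+
+# A possible shape for the likelihood FIXME above (a sketch only; scoring
+# each point under its own cluster's posterior predictive is an assumption,
+# not something this commit specifies):
+#
+#   log_likelihood = function(y, z, l, b, w) {
+#     sum(sapply(unique(z), function(j) {
+#       cluster = y[which(z == j),]
+#       stats = cluster_statistics(cluster, l, b, w)
+#       sum(dmvt(as.matrix(cluster), delta = stats$mu, sigma = stats$v
+#         , df = stats$df, log = T))
+#     }))
+#   }
+#
+# And one way to address the replicate FIXME: since each sweep depends on
+# the previous one, a fold via Reduce (rather than replicate proper) could
+# stand in for the loop:
+#
+#   acc$z = Reduce(
+#       function(chain, i) rbind(chain, gibbs(chain[nrow(chain),])$z)
+#     , seq(n - 1)
+#     , init = matrix(params$z, nrow = 1))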
+
+
+
+# development
+
+require(reshape2) # FIXME move to sim
+require(ggplot2)
+require(gridExtra)
+
+d = list(
+    t(replicate(250, rnorm(2, c(5, 5))))
+  , t(replicate(250, rnorm(2, c(-5, -5))))
+  , t(replicate(500, rnorm(2))))
+dn = lapply(d, function(j) { data.frame(x = j[,1], y = j[,2]) })
+m = melt(dn, id.vars = c('x', 'y'))
+
+dimension = 2
+
+config = list(
+    k = 3
+  , m = dimension
+  , a = 1
+  , l = rep(0, dimension)
+  , r = diag(0.05, dimension)
+  , b = 2
+  , w = diag(1, dimension)
+  , n = 1000
+  )
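+
+# NB: per their use above, k is the number of components, a the symmetric
+# Dirichlet concentration, l the prior mean, b the prior degrees of freedom
+# and w the prior scale matrix; r is threaded through but unused by the
+# collapsed sampler. m and n are not consumed by inverse_model below.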
+
+y = m[, c('x', 'y')]
+
+foo = inverse_model(
+    100, y, config$k, config$a, config$l, config$r, config$b, config$w)
+
+early = data.frame(x = y$x, y = y$y, variable = foo$z[1,])
+mid = data.frame(x = y$x, y = y$y, variable = foo$z[80,])
+late = data.frame(x = y$x, y = y$y, variable = foo$z[nrow(foo$z),])
+
+p_early =
+  ggplot(early, aes(x, y, colour = factor(variable), fill = factor(variable))) +
+  geom_point(alpha = 0.5)
+
+p_mid =
+  ggplot(mid, aes(x, y, colour = factor(variable), fill = factor(variable))) +
+  geom_point(alpha = 0.5)
+
+p_late =
+  ggplot(late, aes(x, y, colour = factor(variable), fill = factor(variable))) +
+  geom_point(alpha = 0.5)
+
+inferred_plots = grid.arrange(p_early, p_mid, p_late, ncol = 3)
+
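+# To write the comparison out (the filename is arbitrary):
+#
+#   png('inferred_clusters.png', width = 1200, height = 400)
+#   grid.arrange(p_early, p_mid, p_late, ncol = 3)
+#   dev.off()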