commit 59d68c41432b67cbc77227af947dd74de1e08bec
parent a71a0f1d52f7500c859bc9d31247086bef715302
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 16 Mar 2016 22:37:17 +1300
Fix collapsed sampler.
Diffstat:
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_conditional_collapsed.r b/finite-gaussian-mixture/src/fmm_multivariate_conditional_collapsed.r
@@ -7,7 +7,12 @@ cluster_statistics = function(cluster, l, b, w) {
centered = as.matrix(cluster) - ybar
ss = t(centered) %*% centered
ln = (l + n * ybar) / (1 + n)
- tn = solve(solve(w) + ss + n / (1 + n) * ((l - ybar) %*% t(l - ybar)))
+ tn =
+ if (n == 0) {
+ w
+ } else {
+ solve(solve(w) + ss + n / (1 + n) * ((l - ybar) %*% t(l - ybar)))
+ }
bn = b + n
df = bn - m + 1
mu = ln
@@ -21,7 +26,7 @@ cluster_statistics = function(cluster, l, b, w) {
conditional_label_model = function(y, k, z, a, l, r, b, w) {
cluster_labels = seq(k)
- rows = sample(seq(nrow(y)))
+ rows = seq(nrow(y))
m = ncol(y)
initial_clusters = sapply(
@@ -44,8 +49,9 @@ conditional_label_model = function(y, k, z, a, l, r, b, w) {
score_by_cluster = function(j) {
sufficient_stats = if (j == old_label) {
- cluster = y_censored[which(z_censored == j), ]
- cluster_statistics(cluster, l, b, w)
+ cluster = y_censored[which(z_censored == j), ]
+ sufficient_statistics[[j]] <<- cluster_statistics(cluster, l, b, w)
+ sufficient_statistics[[j]]
} else {
sufficient_statistics[[j]]
}
@@ -64,7 +70,7 @@ conditional_label_model = function(y, k, z, a, l, r, b, w) {
new_label = sample(cluster_labels, size = 1, prob = probs)
z[i] <<- new_label
- new_stats = cluster_statistics(y[which(z == new_label),], l, b, w)
+ new_stats = cluster_statistics(y[which(z == new_label),], l, b, w) # FIXME: ???
sufficient_statistics[[new_label]] <<- new_stats
new_label
diff --git a/finite-gaussian-mixture/src/sim_fmm_2d_conditional_collapsed.r b/finite-gaussian-mixture/src/sim_fmm_2d_conditional_collapsed.r
@@ -26,7 +26,7 @@ d = list(
dn = lapply(d, function(j) { data.frame(x = j[,1], y = j[,2]) })
m = melt(dn, id.vars = c('x', 'y'))
-set.seed(222) #990909)
+set.seed(990909)
params = inverse_model(
config$n, config$k, m[, c('x', 'y')]