commit 4bc704aed8a469c0d232746fb83c45b09627d716
parent 80f818bc63b50b0b46e78872231b1178e93e1ccb
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 14 Mar 2016 16:28:00 +1300
Various tweaks to multivariate stuff (joint location/precision sampling, dplyr load order, seeding).
Diffstat:
3 files changed, 50 insertions(+), 17 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_conditional.r b/finite-gaussian-mixture/src/fmm_multivariate_conditional.r
@@ -1,9 +1,11 @@
-require(dplyr)
require(gtools)
require(mvtnorm)
source('fmm_multivariate_generative.r')
+# NOTE (jtobin): must load dplyr after plyr so plyr does not mask dplyr verbs (e.g. summarise)
+require(dplyr)
+
conditional_mixing_model = function(y, k, z, a) {
labelled = cbind(y, L1 = z)
counts = summarise(group_by(labelled, L1), count = n())
@@ -20,7 +22,7 @@ conditional_mixing_model = function(y, k, z, a) {
}
conditional_label_model = function(y, p, m, s) {
- scorer = function(mix, mu, prec) {
+ scorer = function(mix, mu, prec) {
exp(log(mix) + dmvnorm(y, mu, solve(prec), log = T))
}
@@ -46,16 +48,45 @@ conditional_label_model = function(y, p, m, s) {
unname(clusters)
}
-conditional_location_model = function(y, z, s, l, r) {
- labelled = cbind(y, L1 = z)
- cluster = function(d, j) {
- vals = d[which(d$L1 == j), !(names(d) %in% 'L1')]
+conditional_cluster_parameters_model = function(y, k, z, l, r, b, w) {
+ labelled = data.frame(y, L1 = z)
+ clustered = lapply(seq(k),
+ function(j) {
+ vals = labelled[which(labelled$L1 == j), !(names(labelled) %in% 'L1')]
+ as.matrix(vals)
+ })
+
+ # FIXME (jtobin): colMeans yields NaN for empty clusters; need to handle this?
+ ybar = lapply(clustered, colMeans)
+ n = lapply(clustered, nrow)
+ pl = function(lj, nj, ybarj) { (lj + nj * ybarj) / (1 + nj) }
+ ln = mapply(pl, list(l), n, ybar, SIMPLIFY = F)
+ centered = mapply('-', clustered, ybar, SIMPLIFY = F)
+ ss = lapply(centered, function(x) t(x) %*% x)
+
+ pt = function(wj, ssj, nj, ybarj) {
+ wj + ssj + nj / (1 + nj) * ((l - ybarj) %*% t(l - ybarj))
}
- clustered = lapply(seq_along(s), function(j) { cluster(labelled, j) })
+ tn = mapply(pt, list(w), ss, n, ybar, SIMPLIFY = F)
+ bn = lapply(n, function(x) x + b)
+ prec = mapply(function(i, j) drop(rWishart(1, i, j)), bn, tn, SIMPLIFY = F)
+ cov = mapply(function(i, j) solve((i + 1) * j), n, tn, SIMPLIFY = F)
+ loc = mapply(rmvnorm, 1, ln, cov, SIMPLIFY = F)
+ list(m = loc, s = prec)
+}
+
+conditional_location_model = function(y, z, s, l, r) {
+ labelled = data.frame(y, L1 = z)
+ clustered = lapply(seq_along(s),
+ function(j) {
+ labelled[which(labelled$L1 == j), !(names(labelled) %in% 'L1')]
+ })
+
n = lapply(clustered, nrow)
yt = lapply(clustered, function(j) { apply(j, MARGIN = 2, sum) })
num0 = mapply('%*%', yt, s, SIMPLIFY = F)
+
num = lapply(num0, function(z) { z + (l %*% r) })
den0 = mapply('*', n, s, SIMPLIFY = F)
den = lapply(den0, function(z) z + r)
@@ -88,14 +119,14 @@ conditional_precision_model = function(y, z, m, b, w) {
mapply(function(i, j) drop(rWishart(1, i, j)), a, bet, SIMPLIFY = F)
}
-# FIXME dubious
inverse_model = function(n, k, y, a, l, r, b, w) {
gibbs = function(p0, m0, s0) {
z = conditional_label_model(y, p0, m0, s0)
p1 = conditional_mixing_model(y, k, z, a)
- m1 = conditional_location_model(y, z, s0, l, r)
- s1 = conditional_precision_model(y, z, m1, b, w)
- l = lmodel(y, z, p1, m1, s1)
+ ps = conditional_cluster_parameters_model(y, k, z, l, r, b, w)
+ m1 = ps$m
+ s1 = ps$s
+ l = lmodel(y, p1, m1, s1)
list(p = p1, m = m1, s = s1, z = z, l = l)
}
diff --git a/finite-gaussian-mixture/src/simulation_multivariate.r b/finite-gaussian-mixture/src/simulation_multivariate.r
@@ -4,7 +4,7 @@ require(scatterplot3d)
source('fmm_multivariate_generative.r')
-# 2d
+# 2d example
config = list(
k = 4
@@ -26,7 +26,7 @@ framed = lapply(d, function(mat) { data.frame(x = mat[,1], y = mat[,2]) })
melted = melt(framed, id.vars = c('x', 'y'))
p = ggplot(melted, aes(x, y, colour = factor(L1))) + geom_point(alpha = 0.2)
-# 3d
+# 3d example
config_3d = list(
k = 4
diff --git a/finite-gaussian-mixture/src/simulation_multivariate_conditional.r b/finite-gaussian-mixture/src/simulation_multivariate_conditional.r
@@ -1,5 +1,3 @@
-set.seed(222)
-
require(ggplot2)
require(gridExtra)
require(reshape2)
@@ -19,6 +17,8 @@ config = list(
, n = 1000
)
+set.seed(222)
+
origin = list(
p = mixing_model(config$k, config$a)
, m = location_model(config$k, config$l, config$r)
@@ -26,12 +26,14 @@ origin = list(
)
# FIXME generate a known/non-pathological configuration first, to test
-d = melt(model(config$m, config$k, config$n), id.vars = c('x', 'y'))
+d = model(config$k, config$l, config$r, config$b, config$w, config$n)
+dn = lapply(d, function(j) { data.frame(x = j[,1], y = j[,2]) })
+m = melt(dn, id.vars = c('x', 'y'))
set.seed(990909)
params = inverse_model(
- config$n, config$k, d[, c('x', 'y')]
+ config$n, config$k, m[, c('x', 'y')]
, config$a
, config$l, config$r
, config$b, config$w