commit 911d5644aae7e31ddc462f512a9aac80fb7c3fa1
parent 864440abffc82e3fbdac5cfd9c47142b441fc363
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 10 Mar 2016 22:18:50 +1300
More skeleton work.
Diffstat:
2 files changed, 65 insertions(+), 36 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_conditional.r b/finite-gaussian-mixture/src/fmm_multivariate_conditional.r
@@ -49,7 +49,6 @@ conditional_mixing_model = function(y, k, z, a) {
drop(rdirichlet(1, concentration))
}
-# FIXME (jtobin): seems ok but need to check
conditional_label_model = function(y, p, m, s) {
scorer = function(mix, mu, prec) {
exp(log(mix) + dmvnorm(y, mu, solve(prec), log = T))
@@ -77,7 +76,6 @@ conditional_label_model = function(y, p, m, s) {
unname(clusters)
}
-# FIXME (jtobin): this will change quite a bit
conditional_location_model = function(y, z, s, l, r) {
labelled = cbind(y, L1 = z)
@@ -86,59 +84,81 @@ conditional_location_model = function(y, z, s, l, r) {
}
clustered = lapply(seq_along(s), function(j) { cluster(labelled, j) })
- lengths = lapply(clustered, nrow)
- sums = lapply(clustered, function(foo) { apply(foo, MARGIN = 2, sum) })
+ n = lapply(clustered, nrow)
+ yt = lapply(clustered, function(foo) { apply(foo, MARGIN = 2, sum) })
- n = lengths
- yt = sums
+ # FIXME (jtobin): move out of function
+ listcols = function(mat) {
+ lapply(seq(ncol(mat)), function(j) t(matrix(mat[, j])))
+ }
+
+ # FIXME (jtobin): reduce duplication
+ listcolsSquare = function(mat) {
+ lapply(
+ seq(ncol(mat))
+ , function(j) t(matrix(mat[, j], nrow = sqrt(nrow(mat))))
+ )
+ }
+
+ muls = function(a, b) {
+ v = mapply('*', a, b)
+ listcolsSquare(v)
+ }
+
+ num0 = listcols(mapply('%*%', yt, s))
+ num1 = l %*% r
+ num = lapply(num0, function(z) z + num1)
+ den = lapply(muls(n, s), function(z) z + r)
- # FIXME (jtobin): these must be multivariate quantities
- m = (yt * s + l * r) / (n * s + r)
- v = 1 / (n * s + r)
+ v = lapply(den, solve)
+ m = listcols(mapply('%*%', num, v))
- # FIXME (jtobin): needs to be rmvnorm
- mapply(rnorm, 1, m, v)
+ listcols(mapply(rmvnorm, 1, m, v))
}
-# FIXME (jtobin): this will change quite a bit
conditional_precision_model = function(y, z, m, b, w) { # one rWishart draw per cluster, built from each cluster's centered cross-product
- labelled = data.frame(value = y, L1 = z)
- clustered = group_by(labelled, L1)
- acc = list()
+ labelled = cbind(y, L1 = z) # attach the labels z to the observations as column L1
+ cluster = function(d, j) { # rows of d whose L1 equals j, with the L1 column dropped
+ vals = d[which(d$L1 == j), !(names(d) %in% 'L1')] # value of this assignment is what the helper returns
+ }
+
+ clustered = lapply(seq_along(m), function(j) { cluster(labelled, j) }) # one subset per entry of m
+ yt = lapply(clustered, function(foo) { apply(foo, MARGIN = 2, sum) }) # per-cluster column sums; NOTE(review): yt is not referenced again in this function
+
+ centered = list() # filled by the loop below, one entry per cluster
for (j in seq_along(m)) {
- acc[[j]] = labelled[which(labelled$L1 == j), 'value']
- }
+ centered[[j]] = clustered[[j]] - m[[j]] # NOTE(review): relies on how R recycles m[[j]] against a multi-row subset -- confirm shapes
+ }
- centered = mapply("-", acc, m)
- squared = lapply(centered, function(x) x ^ 2)
- ss = unlist(lapply(squared, sum))
+ ss = lapply(centered, function(x) t(as.matrix(x)) %*% as.matrix(x)) # per-cluster cross-product t(X) %*% X of the centered rows
+ n = lapply(clustered, nrow) # per-cluster row counts, kept as a list
- n = sapply(seq_along(m),
- function(cluster) {
- lengths = summarise(clustered, value = n())
- idx = which(lengths$L1 == cluster)
- if (length(idx) != 0) {
- lengths$value[idx]
- } else {
- 0
- }
- })
+ # FIXME reduce duplication
+ listcolsSquare = function(mat) { # list with one matrix per column of mat
+ lapply(
+ seq(ncol(mat))
+ , function(j) t(matrix(mat[, j], nrow = sqrt(nrow(mat)))) # column j refilled into sqrt(nrow(mat)) rows, then transposed
+ )
+ }
- a = b + n
- bet = (w * b + ss) / a
+ a = lapply(n, function(j) j + b) # per-cluster n + b
+ bet0 = lapply(ss, function(j) { (j + w * b) }) # per-cluster ss + w * b
+ bet1 = mapply('/', bet0, a) # bet0[[j]] / a[[j]]; the next line indexes the result as a matrix, so it presumably simplifies to one column per cluster
+ bet = listcolsSquare(bet1) # back to a list of square matrices
- mapply(function(a, b) rgamma(1, a, b), a, bet)
+ listcolsSquare(mapply(function(a, b) rWishart(1, a, b), a, bet)) # one draw per cluster, returned as a list; NOTE(review): rWishart's third argument is a scale matrix -- confirm bet is the intended scale
}
+# FIXME (jtobin): not correct
inverse_model = function(n, k, y, a, l, r, b, w) {
gibbs = function(p0, m0, s0) {
z = conditional_label_model(y, p0, m0, s0)
p1 = conditional_mixing_model(y, k, z, a)
m1 = conditional_location_model(y, z, s0, l, r)
s1 = conditional_precision_model(y, z, m1, b, w)
- l = lmodel(y, z, p1, m1, s1)
- list(p = p1, m = m1, s = s1, z = z, l = l)
+ # l = lmodel(y, z, p1, m1, s1)
+ list(p = p1, m = m1, s = s1, z = z) # l = l)
}
p0 = mixing_model(k, a)
@@ -153,7 +173,7 @@ inverse_model = function(n, k, y, a, l, r, b, w) {
acc$m = rbind(acc$m, params$m)
acc$s = rbind(acc$s, params$s)
acc$z = rbind(acc$z, params$z)
- acc$l = c(acc$l, params$l)
+ # acc$l = c(acc$l, params$l)
}
acc
}
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_generative.r b/finite-gaussian-mixture/src/fmm_multivariate_generative.r
@@ -32,6 +32,15 @@ data_model = function(config) {
model = function(m, k, n) parameter_model(m, k, n) %>% data_model
+# Complete-data log-likelihood of the observations given their labels:
+#   sum_i [ log p[z_i] + log N(y_i | m[[z_i]], solve(t[[z_i]])) ]
+#
+# y  observations, one per row (coerced with as.matrix)
+# z  cluster label for each row of y; labels index into p, m and t
+# p  mixing probabilities, one per cluster
+# m  cluster means -- assumed to be a list with one mean per cluster
+#    (TODO confirm against callers)
+# t  cluster precision matrices -- assumed to be a list with one square
+#    matrix per cluster (TODO confirm against callers)
+#
+# Returns a single number.  Labels outside seq_along(m) are not scored.
+lmodel = function(y, z, p, m, t) {
+  obs = as.matrix(y)
+
+  # Score one cluster at a time so each precision is inverted only once.
+  score = vapply(seq_along(m), function(j) {
+    rows = which(z == j)
+    if (length(rows) == 0) {
+      return(0)  # a cluster with no members contributes nothing
+    }
+    mu    = as.numeric(m[[j]])
+    sigma = solve(t[[j]])  # dmvnorm takes a covariance; t holds precisions
+    sum(log(p[j]) + dmvnorm(obs[rows, , drop = FALSE], mu, sigma, log = TRUE))
+  }, numeric(1))
+
+  sum(score)
+}
+
# utilities
rinvwishart = function(n, v, S) {