commit 92da1ee67b8c6eb0d03bdfaee6f9f4d686f78906
parent d42d0b689baa0a50138e81d61f2382e478898c78
Author: Jared Tobin <jared@jtobin.ca>
Date: Wed, 9 Mar 2016 11:12:28 +1300
Add gibbs for FMM.
Diffstat:
2 files changed, 180 insertions(+), 0 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_conditional.r b/finite-gaussian-mixture/src/fmm_conditional.r
@@ -0,0 +1,135 @@
+require(dplyr)
+require(gtools)
+
+source('fmm_generative.r')
+
+conditional_mixing_model = function(y, k, z, a) {
+ labelled = data.frame(value = y, L1 = z)
+ counts = summarise(group_by(labelled, L1), count = n())
+
+ concentration = sapply(
+ seq(k)
+ , function(cluster) {
+ idx = which(counts$L1 == cluster)
+ if (length(idx) != 0) {
+ counts$count[idx] + a / k
+ } else {
+ a / k
+ }
+ })
+
+ drop(rdirichlet(1, concentration))
+ }
+
+conditional_label_model = function(y, p, m, s) {
+ scorer = function(mix, mu, prec) {
+ exp(log(mix) + dnorm(y, mu, sqrt(1 / prec), log = T))
+ }
+ unweighted = mapply(scorer, p, m, s)
+ weights = 1 / apply(unweighted, MARGIN = 1, sum)
+ weighted = weights * unweighted
+
+ probabilize = function(row) {
+ rs = sum(row)
+ if (rs == 0 || is.na(rs) || is.nan(rs)) {
+ drop(rdirichlet(1, rep(1, length(m))))
+ } else {
+ row
+ }
+ }
+
+ probs = t(apply(weighted, MARGIN = 1, probabilize))
+ apply(
+ probs
+ , MARGIN = 1
+ , function(row) { sample(seq_along(m), size = 1, prob = row) }
+ )
+ }
+
+conditional_location_model = function(y, z, s, l, r) {
+ clustered = group_by(data.frame(value = y, L1 = z), L1)
+ lengths = summarise(clustered, value = n())
+ sums = summarise(clustered, value = sum(value))
+
+ n = sapply(seq_along(s),
+ function(cluster) {
+ idx = which(lengths$L1 == cluster)
+ if (length(idx) != 0) {
+ lengths$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ yt = sapply(seq_along(s),
+ function(cluster) {
+ idx = which(sums$L1 == cluster)
+ if (length(idx) != 0) {
+ sums$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ # FIXME (jtobin): check these against tim hopper's
+ m = (yt * s + l * r) / (n * s + r)
+ v = 1 / (n * s + r)
+
+ # FIXME (jtobin): check this
+ mapply(rnorm, 1, m, v)
+ }
+
+conditional_precision_model = function(y, z, m, b, w) {
+ labelled = data.frame(value = y, L1 = z)
+ clustered = group_by(labelled, L1)
+
+ acc = list()
+ for (j in seq_along(m)) {
+ acc[[j]] = labelled[which(labelled$L1 == j), 'value']
+ }
+
+ centered = mapply("-", acc, m)
+ squared = lapply(centered, function(x) x ^ 2)
+ ss = unlist(lapply(squared, sum))
+
+ n = sapply(seq_along(m),
+ function(cluster) {
+ lengths = summarise(clustered, value = n())
+ idx = which(lengths$L1 == cluster)
+ if (length(idx) != 0) {
+ lengths$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ a = b + n
+ bet = (w * b + ss) / a
+
+ mapply(function(a, b) rgamma(1, a, b), a, bet)
+ }
+
+inverse_model = function(n, k, y, a, l, r, b, w) {
+ kernel = function(p0, m0, s0) {
+ z = conditional_label_model(y, p0, m0, s0)
+ p1 = conditional_mixing_model(y, k, z, a)
+ m1 = conditional_location_model(y, z, s0, l, r)
+ s1 = conditional_precision_model(y, z, m1, b, w)
+ list(p = p1, m = m1, s = s1)
+ }
+
+ p0 = mixing_model(k, a)
+ m0 = location_model(k, l, r)
+ s0 = precision_model(k, b, w)
+ params = list(p = p0, m = m0, s = s0)
+
+ acc = params
+ for (j in seq(n - 1)) {
+ params = kernel(params$p, params$m, params$s)
+ acc$p = rbind(acc$p, params$p)
+ acc$m = rbind(acc$m, params$m)
+ acc$s = rbind(acc$s, params$s)
+ }
+ acc
+ }
+
diff --git a/finite-gaussian-mixture/src/simulation_conditional.r b/finite-gaussian-mixture/src/simulation_conditional.r
@@ -0,0 +1,45 @@
+set.seed(42)
+
+require(ggplot2)
+require(reshape2)
+
+source('fmm_generative.r')
+
+config = list(
+ k = 3
+ , a = 1
+ , l = 0
+ , r = 0.1
+ , b = 1
+ , w = 1
+ , n = 500
+ )
+
+origin = list(
+ p = mixing_model(config$k, config$a)
+ , m = location_model(config$k, config$l, config$r)
+ , s = precision_model(config$k, config$b, config$w)
+ )
+
+d = melt(model(config$k, config$n))
+
+params = inverse_model(
+ config$n, config$k, data$value
+ , config$a, config$l, config$r
+ , config$b, config$w
+ )
+
+dp = melt(as.data.frame(params$p))
+dm = melt(as.data.frame(params$m))
+ds = melt(as.data.frame(params$s))
+
+pp = ggplot(dp, aes(x = seq_along(value), y = value, colour = variable))
+ + geom_line()
+
+pm = ggplot(dm, aes(x = seq_along(value), y = value, colour = variable))
+ + geom_line()
+
+ps = ggplot(ds, aes(x = seq_along(value), y = value, colour = variable))
+ + geom_line()
+
+