fmm_multivariate_conditional.r (3266B)
library(mvtnorm)

source('fmm_multivariate_generative.r')
source('fmm_utils.r')

# NOTE (jtobin): must load dplyr after plyr
# NOTE (review): no function in this file calls dplyr any more; it is kept
# loaded in case the sourced helpers rely on it -- TODO confirm and drop if
# unused.
library(dplyr)

# Sample the mixing weights p | z from their Dirichlet full conditional.
#
# y  observations; not used (the label counts alone determine the
#    conditional), kept so the conditional samplers share a signature.
# k  number of mixture components.
# z  component label of each observation, expected in 1..k (labels outside
#    that range are ignored).
# a  total Dirichlet concentration, spread evenly as a / k per component.
#
# Returns one draw from rdirichlet() with its matrix shape dropped --
# presumably a length-k weight vector.
conditional_mixing_model = function(y, k, z, a) {
  # tabulate() yields a count for every component 1..k (0 when a component
  # is empty), so empty components still receive their a / k prior mass.
  concentration = tabulate(z, nbins = k) + a / k
  drop(rdirichlet(1, concentration))
}

# Sample a component label for every observation given the current mixture.
#
# y  observations, one row per observation (as consumed by dmvnorm()).
# p  mixing weights, one per component.
# m  list of component means.
# s  list of component precision matrices (inverted to covariances here).
#
# Returns an unnamed vector with one label in seq_along(m) per observation.
conditional_label_model = function(y, p, m, s) {
  log_scorer = function(mix, mu, prec) {
    log(mix) + dmvnorm(y, mu, solve(prec), log = TRUE)
  }

  # One row per observation, one column per component. matrix() keeps the
  # shape even when y holds a single observation (mapply would otherwise
  # collapse the result to a plain vector).
  log_scores = matrix(mapply(log_scorer, p, m, s), ncol = length(m))

  # Normalise on the log scale: exponentiating raw log-densities underflows
  # to 0 for observations far from every component, which used to make the
  # row sum 0 and the probabilities NaN. Subtracting each row's maximum
  # leaves the normalised probabilities unchanged.
  row_max = apply(log_scores, MARGIN = 1, max)
  unnormalised = exp(log_scores - row_max)
  probs = unnormalised / rowSums(unnormalised)

  clusters = apply(
      probs
    , MARGIN = 1
    , function(row) { sample(seq_along(m), size = 1, prob = row) }
    )
  unname(clusters)
}

# Sample each component's location and precision from the Normal-Wishart
# full conditional given the current labels.
#
# y  observations, one row per observation.
# k  number of mixture components.
# z  component label of each observation.
# l  prior mean of the component locations (prior scale factor fixed at 1).
# r  not used here. NOTE (review): presumably part of the location prior in
#    the generative model -- confirm it is meant to be ignored.
# b  prior Wishart degrees of freedom.
# w  prior Wishart scale matrix.
#
# Returns list(m = locations, s = precisions), one entry per component.
# Each location is a 1-row matrix as returned by rmvnorm().
conditional_cluster_parameters_model = function(y, k, z, l, r, b, w) {
  # Coerce through data.frame() as before, so feature column names (and
  # hence the names carried by the sampled locations) are unchanged.
  features = as.matrix(data.frame(y))
  clustered = lapply(seq_len(k), function(j) {
    features[which(z == j), , drop = FALSE]
  })

  ybar = lapply(clustered, colMeans)  # NaN for empty components
  n = lapply(clustered, nrow)

  # Posterior location mean; empty components fall back to the prior mean.
  pl = function(lj, nj, ybarj) {
    if (nj == 0) lj else (lj + nj * ybarj) / (1 + nj)
  }
  ln = mapply(pl, list(l), n, ybar, SIMPLIFY = FALSE)

  # Centre each component's rows on its own column means. sweep() is
  # required: `x - ybarj` recycles ybarj down the columns of x and subtracts
  # the wrong mean from most entries whenever there is more than one feature.
  centered = mapply(
    function(x, mu) sweep(x, 2, mu), clustered, ybar, SIMPLIFY = FALSE
  )
  ss = lapply(centered, crossprod)

  # NOTE (jtobin): the extra 'solve' calls here helped; came from
  # http://thaines.com/content/misc/gaussian_conjugate_prior_cheat_sheet.pdf
  # murphy's famous reference at
  # http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf is incorrect.
  pt = function(wj, ssj, nj, ybarj) {
    if (nj == 0) {
      # Also avoids 0 * NaN from the empty component's undefined mean.
      wj
    } else {
      dev = l - ybarj
      solve(solve(wj) + ssj + nj / (1 + nj) * tcrossprod(dev))
    }
  }

  tn = mapply(pt, list(w), ss, n, ybar, SIMPLIFY = FALSE)
  bn = lapply(n, function(nj) nj + b)
  prec = mapply(
    function(df, scale) drop(rWishart(1, df, scale)), bn, tn, SIMPLIFY = FALSE
  )

  # Normal-Wishart: mu_j | Lambda_j ~ N(ln_j, ((1 + n_j) Lambda_j)^-1), so
  # the location is drawn given the precision just sampled (previously the
  # Wishart scale tn was used, which is off by roughly a factor of bn).
  # as.matrix() keeps a 1 x 1 matrix when there is a single feature.
  cov = mapply(
    function(nj, precj) solve((nj + 1) * as.matrix(precj)),
    n, prec, SIMPLIFY = FALSE
  )
  loc = mapply(rmvnorm, 1, ln, cov, SIMPLIFY = FALSE)

  if (anyNA(unlist(loc))) {
    stop("sampled a NaN/NA component location", call. = FALSE)
  }
  list(m = loc, s = prec)
}

# Run the Gibbs sampler for n states: one initial draw from the priors
# followed by n - 1 Gibbs sweeps.
#
# n           total number of states to record (initial draw included).
# k           number of mixture components.
# y           observations.
# a           Dirichlet concentration for the mixing weights.
# l, r, b, w  location / precision prior parameters, forwarded to the
#             samplers.
#
# Returns a list with
#   p  matrix of mixing weights, one row per state;
#   m  list (one per component) of location matrices, one row per state;
#   s  list of per-state precision lists;
#   z  matrix of labels, one row per Gibbs sweep (NULL when n <= 1);
#   l  vector of lmodel() values, one per Gibbs sweep (NULL when n <= 1).
inverse_model = function(n, k, y, a, l, r, b, w) {
  gibbs = function(p0, m0, s0) {
    z = conditional_label_model(y, p0, m0, s0)
    p1 = conditional_mixing_model(y, k, z, a)
    ps = conditional_cluster_parameters_model(y, k, z, l, r, b, w)
    list(p = p1, m = ps$m, s = ps$s, z = z, l = lmodel(y, p1, ps$m, ps$s))
  }

  init = list(
      p = mixing_model(k, a)
    , m = lapply(
        location_model(k, l, r)
      , function(j) { matrix(j, ncol = length(j)) })
    , s = precision_model(k, b, w)
    )

  # Preallocate one slot per sweep. max(n - 1, 0) makes n = 1 run zero
  # sweeps (seq(n - 1) used to run two).
  draws = vector("list", max(n - 1, 0))
  params = init
  for (j in seq_along(draws)) {
    params = gibbs(params$p, params$m, params$s)
    draws[[j]] = params
  }

  collect = function(name) lapply(draws, function(d) d[[name]])

  # Stack each component's locations across states, initial draw first.
  m_history = lapply(seq_along(init$m), function(cl) {
    do.call(rbind, c(list(init$m[[cl]]), lapply(draws, function(d) d$m[[cl]])))
  })
  names(m_history) = names(init$m)

  list(
      p = do.call(rbind, c(list(init$p), collect("p")))
    , m = m_history
    , s = c(list(init$s), collect("s"))
    , z = do.call(rbind, collect("z"))
    , l = unlist(collect("l"))
    )
}