fmm_conditional.r (3317B)
# dplyr is attached exactly as before. The samplers below only need base R,
# but the sourced files may rely on it (not visible from here).
require(dplyr)

source('fmm_generative.r')
source('fmm_utils.r')

#' Draw mixing proportions conditional on the current labels.
#'
#' The Dirichlet concentration for cluster j is (number of labels equal to j)
#' plus a / k.
#'
#' @param y Observations. Unused here; kept for a uniform signature.
#' @param k Number of clusters.
#' @param z Cluster labels, expected to take values in 1..k.
#' @param a Total prior concentration, split evenly across the k clusters.
#' @return A vector of k mixing proportions, as produced by `rdirichlet`
#'   (defined outside this file; presumably returns a 1 x k matrix).
conditional_mixing_model <- function(y, k, z, a) {
  # tabulate() yields a count for every cluster 1..k, including empty ones.
  concentration <- tabulate(z, nbins = k) + a / k
  drop(rdirichlet(1, concentration))
}

#' Draw a cluster label for every observation.
#'
#' Observation i gets label j with probability proportional to
#' p[j] * dnorm(y[i], m[j], sqrt(1 / s[j])).
#'
#' @param y Numeric vector of observations.
#' @param p Mixing proportions, one per cluster.
#' @param m Cluster locations, one per cluster.
#' @param s Cluster precisions, one per cluster (sd is sqrt(1 / s)).
#' @return Integer vector of labels in 1..length(m), one per observation.
conditional_label_model <- function(y, p, m, s) {
  k <- length(m)

  # Log score of every observation (rows) under every cluster (columns).
  # The explicit matrix() keeps the shape right when there is a single
  # observation or a single cluster.
  log_scores <- vapply(
    seq_len(k),
    function(j) log(p[j]) + dnorm(y, m[j], sqrt(1 / s[j]), log = TRUE),
    numeric(length(y))
  )
  log_scores <- matrix(log_scores, nrow = length(y), ncol = k)

  draw_label <- function(i) {
    row <- log_scores[i, ]
    top <- max(row)
    if (is.finite(top)) {
      # Subtracting the row maximum before exp() avoids underflow to an
      # all-zero row; sample.int() normalises the weights itself.
      weights <- exp(row - top)
    } else {
      # Degenerate row (NA/NaN scores or no finite maximum): fall back to a
      # random probability vector, as the original code did.
      weights <- drop(rdirichlet(1, rep(1, k)))
    }
    sample.int(k, size = 1, prob = weights)
  }

  vapply(seq_len(nrow(log_scores)), draw_label, integer(1))
}

#' Draw cluster locations conditional on labels and precisions.
#'
#' For cluster j with n_j members summing to yt_j, the draw is normal with
#' mean (yt_j * s_j + l * r) / (n_j * s_j + r) and variance
#' 1 / (n_j * s_j + r).
#'
#' @param y Numeric vector of observations.
#' @param z Cluster labels, expected to take values in 1..length(s).
#' @param s Cluster precisions, one per cluster.
#' @param l Prior mean of the locations.
#' @param r Prior precision of the locations.
#' @return Numeric vector with one location per cluster.
conditional_location_model <- function(y, z, s, l, r) {
  k <- length(s)

  n <- tabulate(z, nbins = k)
  # which() drops NA labels so they do not contaminate any cluster's sum.
  yt <- vapply(
    seq_len(k),
    function(j) sum(y[which(z == j)]),
    numeric(1)
  )

  post_prec <- n * s + r
  post_mean <- (yt * s + l * r) / post_prec

  rnorm(length(post_mean), mean = post_mean, sd = sqrt(1 / post_prec))
}

#' Draw cluster precisions conditional on labels and locations.
#'
#' For cluster j with n_j members and within-cluster sum of squares ss_j
#' around m[j], the draw is rgamma with shape b + n_j and
#' rate (w * b + ss_j) / (b + n_j).
#'
#' @param y Numeric vector of observations.
#' @param z Cluster labels, expected to take values in 1..length(m).
#' @param m Cluster locations, one per cluster.
#' @param b,w Prior hyperparameters entering the shape and rate as above.
#' @return Numeric vector with one precision per cluster.
conditional_precision_model <- function(y, z, m, b, w) {
  k <- length(m)

  n <- tabulate(z, nbins = k)
  # Computed per cluster so the result always has length k. The previous
  # mapply("-", ...) collapsed to a matrix when clusters had equal sizes.
  ss <- vapply(
    seq_len(k),
    function(j) sum((y[which(z == j)] - m[j])^2),
    numeric(1)
  )

  a <- b + n
  bet <- (w * b + ss) / a

  rgamma(length(a), shape = a, rate = bet)
}

#' Run the Gibbs sampler and collect the draws.
#'
#' The chain starts from `mixing_model`, `location_model` and
#' `precision_model` (defined outside this file), then performs n - 1 sweeps
#' of labels -> mixing proportions -> locations -> precisions.
#'
#' @param n Number of draws to collect, counting the initial draw.
#' @param k Number of clusters.
#' @param y Numeric vector of observations.
#' @param a Concentration passed to the mixing-proportion models.
#' @param l,r Prior mean and precision passed to the location models.
#' @param b,w Hyperparameters passed to the precision models.
#' @return A list with `p`, `m`, `s` (draws bound by row, initial draw first),
#'   `z` (labels bound by row, one row per sweep) and `l` (concatenated values
#'   returned by `lmodel`, one per sweep). `z` and `l` are NULL when no sweep
#'   is run.
inverse_model <- function(n, k, y, a, l, r, b, w) {
  gibbs <- function(p0, m0, s0) {
    z <- conditional_label_model(y, p0, m0, s0)
    p1 <- conditional_mixing_model(y, k, z, a)
    m1 <- conditional_location_model(y, z, s0, l, r)
    s1 <- conditional_precision_model(y, z, m1, b, w)
    # Held under its own name so the prior mean `l` is not shadowed.
    score <- lmodel(y, p1, m1, s1)
    list(p = p1, m = m1, s = s1, z = z, l = score)
  }

  p0 <- mixing_model(k, a)
  m0 <- location_model(k, l, r)
  s0 <- precision_model(k, b, w)

  # seq_len() gives zero sweeps for n = 1, where seq(n - 1) gave c(1, 0).
  sweeps <- max(n - 1, 0)
  draws <- vector("list", sweeps)

  state <- list(p = p0, m = m0, s = s0)
  for (j in seq_len(sweeps)) {
    state <- gibbs(state$p, state$m, state$s)
    draws[[j]] <- state
  }

  # Bind once at the end instead of growing matrices inside the loop.
  pick <- function(name) lapply(draws, `[[`, name)

  list(
    p = do.call(rbind, c(list(p0), pick("p"))),
    m = do.call(rbind, c(list(m0), pick("m"))),
    s = do.call(rbind, c(list(s0), pick("s"))),
    z = do.call(rbind, pick("z")),
    l = do.call(c, pick("l"))
  )
}