commit 205bf1b2060f0086bf7a0668a83c4e36e6f4190d
parent a9738c6782dfbfd4670996f10974ebeecfcf5517
Author: Jared Tobin <jared@jtobin.ca>
Date: Tue, 23 Feb 2016 17:23:23 +1300
Add conditional precision model.
Diffstat:
1 file changed, 58 insertions(+), 20 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_conditional.r b/finite-gaussian-mixture/src/fmm_conditional.r
@@ -55,27 +55,65 @@ conditional_mixing_model = function(y, k, z, a) {
rdirichlet(1, concentration)
}
+conditional_location_model = function(y, z, s, l, r) {
+ clustered = group_by(data.frame(value = y, L1 = z), L1)
+ lengths = summarise(clustered, value = n())
+ sums = summarise(clustered, value = sum(value))
+ n = sapply(seq_along(s),
+ function(cluster) {
+ idx = which(lengths$L1 == cluster)
+ if (length(idx) != 0) {
+ lengths$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ yt = sapply(seq_along(s),
+ function(cluster) {
+ idx = which(sums$L1 == cluster)
+ if (length(idx) != 0) {
+ sums$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ m = (yt * s + l * r) / (n * s + r)
+ v = 1 / (n * s + r)
+ mapply(rnorm, 1, m, v)
+ }
+
+conditional_precision_model = function(y, z, m, b, w) {
+ labelled = data.frame(value = y, L1 = z)
+ clustered = group_by(labelled, L1)
+
+ acc = list()
+ for (j in seq_along(m)) {
+ acc[[j]] = labelled[which(labelled$L1 == j), 'value']
+ }
+
+ centered = mapply("-", acc, m)
+ squared = lapply(centered, function(x) x ^ 2)
+ ss = unlist(lapply(squared, sum))
+
+ n = sapply(seq_along(s),
+ function(cluster) {
+ lengths = summarise(clustered, value = n())
+ idx = which(lengths$L1 == cluster)
+ if (length(idx) != 0) {
+ lengths$value[idx]
+ } else {
+ 0
+ }
+ })
+
+ a = b + n
+ bet = (w * b + ss) / a
+
+ mapply(function(a, b) rgamma(1, a, b), a, bet)
+ }
-# mixing_model = function(k, a) drop(rdirichlet(1, (rep(a, k))))
-# label_model = function(n, p) drop(rmultinom(1, size = n, prob = p))
-# location_model = function(k, l, r) rnorm(k, l, 1 / r)
-# precision_model = function(k, b, w) rgamma(k, b, 1 / w)
-#
-# parameter_model = function(k, n) {
-# p = mixing_model(k, 1)
-# c = label_model(n, p)
-# mu = location_model(k, 0, 0.1)
-# s = precision_model(k, 1, 1)
-# list(c, mu, s)
-# }
-#
-# data_model = function(config) {
-# sampler = function(y, m, s) rnorm(y, m, 1 / s)
-# mapply(sampler, config[[1]], config[[2]], config[[3]])
-# }
-#
-# model = function(k, n) parameter_model(k, n) %>% data_model
-#