commit 955a54fefe6798a893f91144488311f09139ef63
parent 68e9a9d0ad44f13657398e747d168c3db7f9b1e3
Author: Jared Tobin <jared@jtobin.ca>
Date: Fri, 12 Feb 2016 20:36:33 +1300
Clean up some code. Inference still problematic.
Diffstat:
1 file changed, 24 insertions(+), 22 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm.r b/finite-gaussian-mixture/src/fmm.r
@@ -25,13 +25,14 @@ rp_conditional = function(n, a0) {
rc_model = function(n, p) drop(rmultinom(1, size = n, prob = p))
-# FIXME generating NA in probability vector for at least one y.
+# FIXME too easy to generate NaN probabilities when precisions are too large
+# may not be correct at all
rc_conditional = function(y, p0, mu, s) {
k = length(p0)
- config = data.frame(p = p0, mu = mu, s = s)
- reducer = function(row) { row[1] * dnorm(y, row[2], 1 / row[3]) }
- elements = apply(config, MARGIN = 1, reducer)
+ reducer = function(p, m, prec) { p * dnorm(y, m, 1 / prec) }
+ elements = mapply(reducer, p, m, s)
p1 = elements / sum(elements)
+
sample(1:k, size = 1, prob = p1)
}
@@ -42,7 +43,7 @@ rmu_model = function(k, l, r) rnorm(k, l, 1 / r)
rmu_conditional = function(y, s, l, r) {
k = length(s)
c = sapply(y, length)
- ybar = sapply(y, safe_mean) # FIXME safe_mean returns 0 on length 0; safe?
+ ybar = sapply(y, safe_mean)
m = (ybar * c * s + l * r) / (c * s + r)
v = 1 / (c * s + r)
rnorm(k, m, v)
@@ -53,8 +54,9 @@ rmu_conditional = function(y, s, l, r) {
rs_model = function(k, b, w) rgamma(k, b, 1 / w)
rs_conditional = function(y, mu, b, w) {
- k = length(y)
- c = sapply(y, length)
+
+ k = length(y)
+ c = sapply(y, length)
centered = mapply("-", y, mu)
squared = lapply(centered, function(x) { x ^ 2 })
@@ -104,7 +106,7 @@ rmodel_conditional = function(n, y) {
raw = unlist(y)
kernel = function(p0, mu0, s0) {
- z = sapply(raw, function(x) rc_conditional(x, p0, mu0, s0)) # FIXME slow
+ z = sapply(raw, function(x) rc_conditional(x, p0, mu0, s0)) # FIXME slow
labelled = data.frame(label = z, y = raw)
id_query = function(c) filter(labelled, label == c)$y
@@ -113,7 +115,7 @@ rmodel_conditional = function(n, y) {
p1 = rp_conditional(counts, a)
mu1 = rmu_conditional(clustered, s0, l, r)
- s1 = rs_conditional(clustered, mu1, b, w) # FIXME precisions explode
+ s1 = rs_conditional(clustered, mu1, b, w)
list(p = p1, mu = mu1, s = s1)
}
@@ -137,12 +139,22 @@ rmodel_conditional = function(n, y) {
gibbs(n, init, p0, mu0, s0)
}
+# utilities
+
+safe_mean = function(x) {
+ if (is.null(x) || (length(x) == 0)) {
+ return(0)
+ } else {
+ mean(x)
+ }
+ }
+
# debug
test_data = list(
- rnorm(101, 3.5, 1)
- , rnorm(38, 0.3, 0.8)
- , rnorm(90, -4.2, 0.5)
+ rnorm(101, 10.5, 1)
+ , rnorm(38, 0.3, 1)
+ , rnorm(90, -8.2, 0.5)
)
# y = list(
# rnorm(801, 3.5, 1)
@@ -162,13 +174,3 @@ test_data = list(
# mu0 = mu
# s0 = s
-# utilities
-
-safe_mean = function(x) {
- if (is.null(x) || (length(x) == 0)) {
- return(0)
- } else {
- mean(x)
- }
- }
-