commit 210120e28186da51995324022ac4b13c0706bbc3
parent fb42bf3c01bbdc10908d8d209a716cb3401ecf2b
Author: Jared Tobin <jared@jtobin.ca>
Date: Mon, 14 Mar 2016 10:19:51 +1300
Misc.
* Add documentation & fix up multivariate generative FMM.
* Add 3d example to simulation.
Diffstat:
2 files changed, 90 insertions(+), 36 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_generative.r b/finite-gaussian-mixture/src/fmm_multivariate_generative.r
@@ -1,57 +1,109 @@
require(gtools)
require(magrittr)
require(mvtnorm)
+require(plyr)
+# Cluster probabilities via a symmetric Dirichlet distribution.
+#
+# k : integer, > 0
+# a : numeric, > 0
+#
+# Returns a vector of probabilities of size k.
mixing_model = function(k, a) drop(rdirichlet(1, (rep(a, k))))
+# Number of observations per cluster, by indicator.
+#
+# n : integer, > 0
+# p : numeric, probability
+#
+# Returns a list of integer sizes corresponding to the given cluster. The
+# number of clusters is determined by the length of 'p'.
label_model = function(n, p) {
vals = drop(rmultinom(1, size = n, prob = p))
- delabel(lapply(vals, list))
+ as.list(vals)
}
+# Location, by cluster.
+#
+# k : integer, > 0
+# l : numeric
+# r : numeric, positive definite
+#
+# Returns a list of 'k' locations, each having same dimension as 'l'.
location_model = function(k, l, r) {
vals = rmvnorm(k, l, solve(r))
- delabel(apply(vals, MARGIN = 1, list))
+ alply(vals, 1)
}
-precision_model = function(k, b, w) rinvwishart(k, b, solve(w))
+# Precision matrix, by cluster.
+#
+# k : integer, > 0
+# b : numeric, >= 1
+# w : numeric, positive definite
+#
+# Returns a list of 'k' precision matrices with same dimension as 'w'.
+precision_model = function(k, b, w) {
+ vals = rWishart(k, b, w)
+ alply(vals, 3)
+}
-parameter_model = function(m, k, b, n) {
+# Parameter model for the finite Gaussian mixture model.
+#
+# k : integer, > 0
+# l : numeric
+# r : numeric, positive definite
+# b : numeric, >= 1
+# w : numeric, positive definite
+# n : integer, > 0
+#
+# Returns a list of three components:
+# * n : list of length 'k' containing the size of the kth cluster
+# * m : list of length 'k' containing the location of the kth cluster
+# * s : list of length 'k' containing the precision of the kth cluster
+parameter_model = function(k, l, r, b, w, n) {
p = mixing_model(k, 1)
c = label_model(n, p)
- mu = location_model(k, rep(0, m), diag(0.05, m))
- s = precision_model(k, b, diag(1, m))
+ mu = location_model(k, l, r)
+ s = precision_model(k, b, w)
list(n = c, m = mu, s = s)
}
-data_model = function(config) {
- mapply(safe_rmvnorm, config$n, config$m, config$s)
+# Data model for the finite Gaussian mixture model.
+#
+# params : output type of 'paramter_model'
+#
+# Returns observations by cluster as a list.
+data_model = function(params) {
+ safe_rmvnorm = function(c, m, s) {
+ if (c <= 0) { numeric(0) } else { rmvnorm(c, m, solve(s)) }
+ }
+ mapply(safe_rmvnorm, params$n, params$m, params$s)
}
-model = function(m, k, b, n) parameter_model(m, k, b, n) %>% data_model
-
-# FIXME (jtobin): checkme, not correct
-lmodel = function(y, z, p, m, s) {
-
- clustered = cbind(y, L1 = z)
- cluster = clustered$L1
-
- score = log(p[cluster]) +
- dmvnorm(clustered$value, m[cluster], solve(s[cluster]), log = T)
-
- sum(score)
+# The finite Gaussian mixture model.
+model = function(k, l, r, b, w, n) {
+ parameter_model(k, l, r, b, w, n) %>% data_model
}
-# utilities
-
-rinvwishart = function(n, v, S) {
- wishes = rWishart(n, v, solve(S))
- delabel(apply(wishes, MARGIN = 3, function(x) list(solve(x))))
- }
+# Log-likelihood for the finite Gaussian mixture model.
+#
+# y : numeric
+# p : numeric, probability
+# m : numeric
+# s : numeric, positive definite
+#
+# 'y' is a matrix of observations. 'p', 'm', and 's' are a probability vector,
+# list of location vectors, and list of precision matrices of the appropriate
+# dimensions.
+lmodel = function(y, p, m, s) {
+ score = function(pr, mu, prec) { pr * dmvnorm(y, mu, solve(prec)) }
+ by_cluster = mapply(score, p, m, s)
+ totalled = apply(by_cluster, MARGIN = 1, sum)
-delabel = function(x) lapply(x, "[[", 1)
+ # NOTE (jtobin): adjusted for numerical stability
+ small = 1.379783e-316
+ adjusted = totalled
+ adjusted[which(adjusted == 0)] = small
+ sum(log(adjusted))
+}
-safe_rmvnorm = function(c, m, s) {
- if (c <= 0) return(numeric(0))
- else rmvnorm(c, m, solve(s))
- }
diff --git a/finite-gaussian-mixture/src/simulation_multivariate.r b/finite-gaussian-mixture/src/simulation_multivariate.r
@@ -24,15 +24,17 @@ p = ggplot(melted, aes(x, y, colour = factor(L1))) + geom_point(alpha = 0.2)
set.seed(42)
-alt_config = list(
+config_3d = list(
m = 3
, v = 3
, k = 4
, n = 10000
)
-alt_d = model(alt_config$m, alt_config$k, alt_config$v, alt_config$n)
-alt_framed = lapply(alt_d,
+d_3d = model(config_3d$m, config_3d$k, config_3d$v, config_3d$n)
+framed_3d = lapply(d_3d,
function(mat) { data.frame(x = mat[,1], y = mat[,2], z = mat[,3]) })
-alt_melted = do.call(rbind, alt_framed)
-scatterplot3d(alt_melted, highlight.3d = T, pch = 19)
+melted_3d = do.call(rbind, framed_3d)
+
+scatterplot3d(melted_3d, highlight.3d = T, pch = 19)
+