commit 6664f3f66c3ec1caafbed4389595868fa1f0f498
parent 2b9422c6c0c732fa4f835f17ba4e737512dd0bb7
Author: Jared Tobin <jared@jtobin.ca>
Date: Thu, 10 Mar 2016 10:42:39 +1300
Clean up multivariate models.
Diffstat:
1 file changed, 14 insertions(+), 5 deletions(-)
diff --git a/finite-gaussian-mixture/src/fmm_multivariate_generative.r b/finite-gaussian-mixture/src/fmm_multivariate_generative.r
@@ -2,15 +2,24 @@ require(gtools)
require(magrittr)
require(mvtnorm)
-mixing_model = function(k, a) drop(rdirichlet(1, (rep(a, k))))
-label_model = function(n, p) drop(rmultinom(1, size = n, prob = p))
-location_model = function(k, l, r) rmvnorm(k, l, solve(r))
+mixing_model = function(k, a) drop(rdirichlet(1, (rep(a, k))))
+
+label_model = function(n, p) {
+ vals = drop(rmultinom(1, size = n, prob = p))
+ delabel(lapply(vals, list))
+}
+
+location_model = function(k, l, r) {
+ vals = rmvnorm(k, l, solve(r))
+ delabel(apply(vals, MARGIN = 1, list))
+}
+
precision_model = function(k, b, w) rinvwishart(k, b, solve(w))
parameter_model = function(m, k, n) {
p = mixing_model(k, 1)
- c = delabel(lapply(label_model(n, p), list))
- mu = delabel(apply(location_model(k, rep(0, m), diag(0.05, m)), MARGIN = 1, list))
+ c = label_model(n, p)
+ mu = location_model(k, rep(0, m), diag(0.05, m))
s = precision_model(k, 2, diag(1, m))
list(c, mu, s)
}