# fmm_multivariate_generative.r
library(mvtnorm)
library(plyr)

source('fmm_utils.r')

# Cluster probabilities via a symmetric Dirichlet distribution.
#
# k : integer, > 0, number of clusters
# a : numeric, > 0, concentration shared by every cluster
#
# Returns a numeric vector of 'k' probabilities. 'rdirichlet' is not defined
# in this file; presumably it comes from 'fmm_utils.r' -- NOTE(review): confirm.
mixing_model <- function(k, a) {
  drop(rdirichlet(1, rep(a, k)))
}

# Number of observations per cluster, by indicator.
#
# n : integer, > 0, total number of observations
# p : numeric, probability vector
#
# Returns a list of integer cluster sizes (summing to 'n'). The number of
# clusters is determined by the length of 'p'.
label_model <- function(n, p) {
  sizes <- drop(rmultinom(1, size = n, prob = p))
  as.list(sizes)
}

# Location, by cluster.
#
# k : integer, > 0
# l : numeric, mean of the location prior
# r : numeric, positive definite precision of the location prior; its inverse
#     is the covariance handed to 'rmvnorm'
#
# Returns a list of 'k' locations, each having same dimension as 'l'.
location_model <- function(k, l, r) {
  draws <- rmvnorm(k, l, solve(r))
  # One list element per sampled row (one location vector per cluster).
  lapply(seq_len(nrow(draws)), function(i) draws[i, ])
}

# Precision matrix, by cluster.
#
# k : integer, > 0
# b : numeric, Wishart degrees of freedom; must be at least the dimension of
#     'w' (stats::rWishart rejects smaller values)
# w : numeric, positive definite Wishart scale matrix
#
# Returns a list of 'k' precision matrices with same dimension as 'w'.
precision_model <- function(k, b, w) {
  draws <- rWishart(k, b, w)
  # rWishart returns a p x p x k array; split it along the third margin.
  alply(draws, 3)
}

# Parameter model for the finite Gaussian mixture model.
#
# k : integer, > 0, number of clusters
# l : numeric, mean of the location prior
# r : numeric, positive definite, precision of the location prior
# b : numeric, Wishart degrees of freedom; at least the dimension of 'w'
# w : numeric, positive definite, Wishart scale matrix
# n : integer, > 0, total number of observations
# a : numeric, > 0, symmetric Dirichlet concentration of the mixing weights
#     (defaults to 1, the value that used to be hard-coded)
#
# Returns a list of three components:
#   * n : list of length 'k' containing the size of the kth cluster
#   * m : list of length 'k' containing the location of the kth cluster
#   * s : list of length 'k' containing the precision of the kth cluster
parameter_model <- function(k, l, r, b, w, n, a = 1) {
  p <- mixing_model(k, a)
  sizes <- label_model(n, p)
  mu <- location_model(k, l, r)
  s <- precision_model(k, b, w)
  list(n = sizes, m = mu, s = s)
}

# Data model for the finite Gaussian mixture model.
#
# params : output of 'parameter_model' (a list with elements 'n', 'm', 's')
#
# Returns observations by cluster as a list: element k is an n_k-row matrix
# drawn from N(m_k, solve(s_k)), or numeric(0) when cluster k is empty.
data_model <- function(params) {
  safe_rmvnorm <- function(size, m, s) {
    if (size <= 0) numeric(0) else rmvnorm(size, m, solve(s))
  }
  # SIMPLIFY = FALSE: by default mapply collapses equally sized clusters (or a
  # single cluster) into one matrix instead of returning a per-cluster list.
  mapply(safe_rmvnorm, params$n, params$m, params$s, SIMPLIFY = FALSE)
}

# The finite Gaussian mixture model: draw parameters, then observations.
#
# Arguments are as for 'parameter_model'; returns the output of 'data_model'.
model <- function(k, l, r, b, w, n, a = 1) {
  params <- parameter_model(k, l, r, b, w, n, a)
  data_model(params)
}

# Log-likelihood for the finite Gaussian mixture model.
#
# y : numeric matrix of observations, one per row
# p : numeric, probability vector of mixing weights
# m : list of location vectors, one per cluster
# s : list of positive definite precision matrices, one per cluster
#
# Returns sum_i log(sum_k p_k * N(y_i | m_k, solve(s_k))). Observations whose
# mixture density is exactly 0 are floored at 'small' so the result is finite.
lmodel <- function(y, p, m, s) {
  score <- function(pr, mu, prec) pr * dmvnorm(y, mu, solve(prec))
  # cbind always yields an observations x clusters matrix; mapply's default
  # simplification produced a plain vector for a single observation, which
  # made apply(..., MARGIN = 1) fail.
  by_cluster <- do.call(cbind, Map(score, p, m, s))
  totalled <- rowSums(by_cluster)

  # NOTE (jtobin): adjusted for numerical stability
  small <- 1.379783e-316
  totalled[which(totalled == 0)] <- small
  sum(log(totalled))
}