bnp

Some older Bayesian nonparametrics research.
git clone git://git.jtobin.io/bnp.git
Log | Files | Refs | README | LICENSE

fmm_multivariate_conditional_collapsed.r (2983B)


      1 require(mvtnorm)
      2 
      3 source('fmm_multivariate_generative.r')
      4 
      #' Posterior and posterior-predictive statistics for a single cluster.
      #'
      #' Updates a conjugate prior with location `l`, degrees of freedom `b` and
      #' scale matrix `w` using the rows of `cluster`. The `1 + n` denominators
      #' mean the prior location counts as one pseudo-observation. The returned
      #' `df`, `mu` and `v` are the arguments `conditional_label_model` hands to
      #' `dmvt`.
      #'
      #' @param cluster Observations in the cluster, one per row. A bare vector
      #'   (a single dropped row, or `numeric(0)` for an empty cluster) is
      #'   reshaped to `ncol(w)` columns.
      #' @param l Prior location vector.
      #' @param b Prior degrees of freedom.
      #' @param w Prior scale matrix; its column count fixes the dimension.
      #' @return A list with `n` (cluster size), `ln`, `tn`, `bn` (updated
      #'   location, scale matrix, degrees of freedom) and `df`, `mu`, `v`
      #'   (Student-t degrees of freedom, location and scale matrix).
      cluster_statistics <- function(cluster, l, b, w) {
        # Indexing one row (or none) out of a matrix drops the dim attribute, so
        # rebuild the matrix shape from the prior's dimension.
        obs <- if (is.null(dim(cluster))) {
          matrix(cluster, ncol = ncol(w))
        } else {
          as.matrix(cluster)
        }
        m <- ncol(obs)
        n <- nrow(obs)

        if (n == 0) {
          # No data: the posterior is the prior. colMeans() of a zero-row matrix
          # is NaN, so the general formulas below must not be used here.
          ln <- l
          tn <- w
        } else {
          ybar <- colMeans(obs)
          # sweep() subtracts the mean of each column from that column; a plain
          # `obs - ybar` would recycle ybar down the columns instead.
          centered <- sweep(obs, 2, ybar)
          ss <- crossprod(centered)
          dev <- l - ybar
          ln <- (l + n * ybar) / (1 + n)
          tn <- solve(solve(w) + ss + n / (1 + n) * (dev %*% t(dev)))
        }

        bn <- b + n
        df <- bn - m + 1
        mu <- ln
        # Predictive scale is (kn + 1) / (kn * df) * solve(tn) with kn = 1 + n,
        # matching the unit prior sample size used for ln and tn above. For an
        # empty cluster this is the prior predictive rather than a zero matrix.
        # NOTE(review): assumes the Normal-Wishart model implied by the ln/tn
        # updates -- confirm against fmm_multivariate_generative.r.
        coef <- (n + 2) / ((n + 1) * df)
        v <- coef * solve(tn)

        list(
          n = n, ln = ln, tn = tn, bn = bn,
          df = df, mu = mu, v = v
        )
      }
     35 
     #' One Gibbs sweep over the cluster labels.
     #'
     #' Visits the rows of `y` in order. For each row it removes that row from
     #' its current cluster, evaluates the row's log density under every cluster
     #' (`dmvt` with the `df`, `v` and `mu` from `cluster_statistics`), weights
     #' each density by the held-out cluster size plus `a / k`, and draws a new
     #' label from the normalised weights. Later rows see the labels already
     #' drawn for earlier rows.
     #'
     #' @param y Matrix of observations, one per row.
     #' @param k Number of cluster labels; labels are `1..k`.
     #' @param z Current labels, one per row of `y`.
     #' @param a Added (as `a / k`) to every cluster count when weighting;
     #'   presumably a Dirichlet concentration -- confirm against the generative
     #'   model.
     #' @param l,b,w Hyperparameters forwarded unchanged to `cluster_statistics`.
     #' @return Integer vector with the label drawn for each row, in row order.
     conditional_label_model <- function(y, k, z, a, l, b, w) {
       cluster_labels <- seq_len(k)
       n_obs <- nrow(y)

       # drop = FALSE keeps one-row and zero-row selections as matrices.
       stats_for <- function(rows) {
         cluster_statistics(y[rows, , drop = FALSE], l, b, w)
       }
       stats <- lapply(cluster_labels, function(j) stats_for(which(z == j)))

       labels <- integer(n_obs)
       for (i in seq_len(n_obs)) {
         old_label <- z[i]
         val <- y[i, ]
         z_censored <- z[-i]

         # Only the cluster that currently owns row i changes when i is held out.
         stats[[old_label]] <- stats_for(setdiff(which(z == old_label), i))
         n_censored <- vapply(
           cluster_labels,
           function(j) length(which(z_censored == j)),
           integer(1)
         )

         log_scores <- vapply(
           stats,
           function(s) dmvt(val, df = s$df, sigma = s$v, delta = s$mu, log = TRUE),
           numeric(1)
         )

         # Shift by the maximum before exponentiating so the densities cannot
         # all underflow to zero; the shift cancels in the normalisation.
         scores <- exp(log_scores - max(log_scores))
         weight <- n_censored + a / k
         probs <- scores * weight / sum(scores * weight)
         new_label <- sample.int(k, size = 1, prob = probs)

         z[i] <- new_label
         stats[[new_label]] <- stats_for(which(z == new_label))
         labels[i] <- new_label
       }
       labels
     }
     89 
     #' Run the collapsed Gibbs sampler and record its trace.
     #'
     #' Starts from a uniformly random labelling and repeatedly applies
     #' `conditional_label_model`. After each sweep it summarises every cluster
     #' by its share of the rows, its column means and the inverse of its sample
     #' covariance, and passes those to `lmodel`.
     #'
     #' @param n Number of labellings to return, counting the random
     #'   initialisation; `n - 1` Gibbs sweeps are run. Must be at least 1.
     #' @param k Number of cluster labels.
     #' @param y Matrix of observations, one per row.
     #' @param a,l,b,w Forwarded unchanged to `conditional_label_model`.
     #' @return A list with `z`, an `n`-row matrix whose first row is the initial
     #'   labelling and whose later rows are the successive sweeps, and `ll`, the
     #'   `lmodel` value after each sweep (one entry fewer than `z` has rows;
     #'   `NULL` when `n` is 1).
     inverse_model <- function(n, k, y, a, l, b, w) {
       stopifnot(n >= 1)
       n_obs <- nrow(y)

       gibbs <- function(z0) {
         z <- conditional_label_model(y, k, z0, a, l, b, w)
         # drop = FALSE keeps a single-row cluster as a 1 x m matrix instead of
         # letting it collapse to a vector that as.matrix() would turn into a
         # column. Such a cluster still has no usable sample covariance, so
         # solve(cov(.)) is expected to fail for it.
         clustered <- lapply(
           seq_len(k),
           function(j) as.matrix(y[which(z == j), , drop = FALSE])
         )

         ps <- lapply(clustered, function(members) nrow(members) / n_obs)
         mus <- lapply(clustered, colMeans)
         precs <- lapply(clustered, function(members) solve(cov(members)))
         list(z = z, ll = lmodel(y, ps, mus, precs))
       }

       # Preallocate the trace rather than growing it with rbind()/c().
       z_trace <- matrix(NA_integer_, nrow = n, ncol = n_obs)
       ll_trace <- vector("list", n - 1)

       z_trace[1, ] <- sample.int(k, size = n_obs, replace = TRUE)
       # seq_len() yields an empty loop for n == 1.
       for (iter in seq_len(n - 1)) {
         step <- gibbs(z_trace[iter, ])
         z_trace[iter + 1, ] <- step$z
         ll_trace[[iter]] <- step$ll
       }

       list(z = z_trace, ll = unlist(ll_trace))
     }
    115