bnp

Some older Bayesian nonparametrics research.
git clone git://git.jtobin.io/bnp.git
Log | Files | Refs | README | LICENSE

fmm_multivariate_conditional.r (3266B)


      1 require(mvtnorm)
      2 
      3 source('fmm_multivariate_generative.r')
      4 source('fmm_utils.r')
      5 
      6 # NOTE (jtobin): must load dplyr after plyr
      7 require(dplyr)
      8 
      9 conditional_mixing_model = function(y, k, z, a) {
     10   labelled      = cbind(y, L1 = z)
     11   counts        = summarise(group_by(labelled, L1), count = n())
     12   concentration = sapply(seq(k),
     13     function(cluster) {
     14       idx = which(counts$L1 == cluster)
     15       if (length(idx) != 0) {
     16         counts$count[idx] + a / k
     17       } else {
     18         a / k
     19       }
     20     })
     21   drop(rdirichlet(1, concentration))
     22 }
     23 
     24 conditional_label_model = function(y, p, m, s) {
     25   scorer = function(mix, mu, prec) {
     26     exp(log(mix) + dmvnorm(y, mu, solve(prec), log = T))
     27   }
     28 
     29   unweighted = mapply(scorer, p, m, s)
     30   weights    = 1 / apply(unweighted, MARGIN = 1, sum)
     31   probs      = weights * unweighted
     32 
     33   clusters = apply(
     34       probs
     35     , MARGIN = 1
     36     , function(row) { sample(seq_along(m), size = 1, prob = row) }
     37     )
     38   unname(clusters)
     39 }
     40 
     41 conditional_cluster_parameters_model = function(y, k, z, l, r, b, w) {
     42   labelled  = data.frame(y, L1 = z)
     43   clustered = lapply(seq(k),
     44     function(j) {
     45       vals = labelled[which(labelled$L1 == j), !(names(labelled) %in% 'L1')]
     46       as.matrix(vals)
     47     })
     48 
     49   ybar     = lapply(clustered, colMeans)
     50   n        = lapply(clustered, nrow)
     51   pl       = function(lj, nj, ybarj) {
     52     if (nj == 0) {
     53       lj
     54     } else {
     55       (lj + nj * ybarj) / (1 + nj)
     56     }
     57   }
     58   ln       = mapply(pl, list(l), n, ybar, SIMPLIFY = F)
     59   centered = mapply('-', clustered, ybar, SIMPLIFY = F)
     60   ss       = lapply(centered, function(x) t(x) %*% x)
     61 
     62   # NOTE (jtobin): the extra 'solve' calls here helped; came from
     63   # http://thaines.com/content/misc/gaussian_conjugate_prior_cheat_sheet.pdf
     64   # murphy's famous reference at
     65   # http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf is incorrect.
     66   pt = function(wj, ssj, nj, ybarj) {
     67     if (nj == 0) { wj } else {
     68       solve(solve(wj) + ssj + nj / (1 + nj) * ((l - ybarj) %*% t(l - ybarj)))
     69     }
     70   }
     71 
     72   tn   = mapply(pt, list(w), ss, n, ybar, SIMPLIFY = F)
     73   bn   = lapply(n, function(x) x + b)
     74   prec = mapply(function(i, j) drop(rWishart(1, i, j)), bn, tn, SIMPLIFY = F)
     75   cov  = mapply(function(i, j) solve((i + 1) * j), n, tn, SIMPLIFY = F)
     76   loc  = mapply(rmvnorm, 1, ln, cov, SIMPLIFY = F)
     77   if (any(is.nan(unlist(loc)))) { browser() }
     78   list(m = loc, s = prec)
     79 }
     80 
     81 inverse_model = function(n, k, y, a, l, r, b, w) {
     82   gibbs = function(p0, m0, s0) {
     83     z  = conditional_label_model(y, p0, m0, s0)
     84     p1 = conditional_mixing_model(y, k, z, a)
     85     ps = conditional_cluster_parameters_model(y, k, z, l, r, b, w)
     86     m1 = ps$m
     87     s1 = ps$s
     88     ll = lmodel(y, p1, m1, s1)
     89 
     90     list(p = p1, m = m1, s = s1, z = z, l = ll)
     91   }
     92 
     93   params = list(
     94       p = mixing_model(k, a)
     95     , m = lapply(
     96               location_model(k, l, r)
     97             , function(j) { matrix(j, ncol = length(j)) })
     98     , s = precision_model(k, b, w)
     99     )
    100 
    101   acc   = params
    102   acc$s = list(acc$s)
    103   for (j in seq(n - 1)) {
    104       params = gibbs(params$p, params$m, params$s)
    105 
    106       acc$p  = rbind(acc$p, params$p)
    107       acc$m  = mapply(rbind, acc$m, params$m, SIMPLIFY = F)
    108       acc$s  = c(acc$s, list(params$s))
    109       acc$z  = rbind(acc$z, params$z)
    110       acc$l  = c(acc$l, params$l)
    111     }
    112   acc
    113 }
    114