bnp

Some older Bayesian nonparametrics research.
git clone git://git.jtobin.io/bnp.git
Log | Files | Refs | README | LICENSE

fmm_conditional.r (3317B)


      1 require(dplyr)
      2 
      3 source('fmm_generative.r')
      4 source('fmm_utils.r')
      5 
      6 conditional_mixing_model = function(y, k, z, a) {
      7   labelled = data.frame(value = y, L1 = z)
      8   counts   = summarise(group_by(labelled, L1), count = n())
      9 
     10   concentration = sapply(
     11       seq(k)
     12     , function(cluster) {
     13         idx = which(counts$L1 == cluster)
     14         if (length(idx) != 0) {
     15           counts$count[idx] + a / k
     16         } else {
     17           a / k
     18         }
     19       })
     20 
     21   drop(rdirichlet(1, concentration))
     22 }
     23 
     24 conditional_label_model = function(y, p, m, s) {
     25   scorer     = function(mix, mu, prec) {
     26     exp(log(mix) + dnorm(y, mu, sqrt(1 / prec), log = T))
     27   }
     28   unweighted = mapply(scorer, p, m, s)
     29   weights    = 1 / apply(unweighted, MARGIN = 1, sum)
     30   weighted   = weights * unweighted
     31 
     32   probabilize = function(row) {
     33     rs = sum(row)
     34     if (rs == 0 || is.na(rs) || is.nan(rs)) {
     35       drop(rdirichlet(1, rep(1, length(m))))
     36     } else {
     37       row
     38     }
     39   }
     40 
     41   probs = t(apply(weighted, MARGIN = 1, probabilize))
     42   apply(
     43       probs
     44     , MARGIN = 1
     45     , function(row) { sample(seq_along(m), size = 1, prob = row) }
     46     )
     47 }
     48 
     49 conditional_location_model = function(y, z, s, l, r) {
     50   clustered = group_by(data.frame(value = y, L1 = z), L1)
     51   lengths   = summarise(clustered, value = n())
     52   sums      = summarise(clustered, value = sum(value))
     53 
     54   n = sapply(seq_along(s),
     55     function(cluster) {
     56       idx = which(lengths$L1 == cluster)
     57       if (length(idx) != 0) {
     58         lengths$value[idx]
     59       } else {
     60         0
     61       }
     62     })
     63 
     64   yt = sapply(seq_along(s),
     65     function(cluster) {
     66       idx = which(sums$L1 == cluster)
     67       if (length(idx) != 0) {
     68         sums$value[idx]
     69       } else {
     70         0
     71       }
     72     })
     73 
     74   m  = (yt * s + l * r) / (n * s + r)
     75   v  = 1 / (n * s + r)
     76 
     77   mapply(rnorm, 1, m, sqrt(v))
     78 }
     79 
     80 conditional_precision_model = function(y, z, m, b, w) {
     81   labelled  = data.frame(value = y, L1 = z)
     82   clustered = group_by(labelled, L1)
     83 
     84   acc = list()
     85   for (j in seq_along(m)) {
     86     acc[[j]] = labelled[which(labelled$L1 == j), 'value']
     87     }
     88 
     89   centered = mapply("-", acc, m)
     90   squared  = lapply(centered, function(x) x ^ 2)
     91   ss       = unlist(lapply(squared, sum))
     92 
     93   n = sapply(seq_along(m),
     94     function(cluster) {
     95       lengths = summarise(clustered, value = n())
     96       idx     = which(lengths$L1 == cluster)
     97       if (length(idx) != 0) {
     98         lengths$value[idx]
     99         } else {
    100         0
    101         }
    102       })
    103 
    104   a   = b + n
    105   bet = (w * b + ss) / a
    106 
    107   mapply(function(a, b) rgamma(1, a, b), a, bet)
    108 }
    109 
    110 inverse_model = function(n, k, y, a, l, r, b, w) {
    111   gibbs = function(p0, m0, s0) {
    112     z  = conditional_label_model(y, p0, m0, s0)
    113     p1 = conditional_mixing_model(y, k, z, a)
    114     m1 = conditional_location_model(y, z, s0, l, r)
    115     s1 = conditional_precision_model(y, z, m1, b, w)
    116     l  = lmodel(y, p1, m1, s1)
    117     list(p = p1, m = m1, s = s1, z = z, l = l)
    118     }
    119 
    120   p0     = mixing_model(k, a)
    121   m0     = location_model(k, l, r)
    122   s0     = precision_model(k, b, w)
    123   params = list(p = p0, m = m0, s = s0)
    124 
    125   acc = params
    126   for (j in seq(n - 1)) {
    127       params = gibbs(params$p, params$m, params$s)
    128       acc$p  = rbind(acc$p, params$p)
    129       acc$m  = rbind(acc$m, params$m)
    130       acc$s  = rbind(acc$s, params$s)
    131       acc$z  = rbind(acc$z, params$z)
    132       acc$l  = c(acc$l, params$l)
    133     }
    134   acc
    135 }
    136