bnp

Some older Bayesian nonparametrics research.
git clone git://git.jtobin.io/bnp.git
Log | Files | Refs | README | LICENSE

sim_fmm_3d_conditional.r (1940B)


      # library() stops immediately if a package is missing; require() would
      # only return FALSE and let the script fail much later.
      library(ggplot2)
      library(gridExtra)
      library(reshape2)

      # Presumably defines inverse_model(), which is called below but not defined
      # in this script.
      source("fmm_multivariate_conditional.r")
      6 
      dimension <- 3

      # Settings passed positionally to inverse_model() further down. Meanings
      # below are inferred from names/usage in this script only.
      # NOTE(review): confirm against fmm_multivariate_conditional.r.
      config <- list(
        k = 3,                     # presumably number of mixture components
        m = dimension,             # data dimension
        a = 1,                     # presumably a concentration parameter
        l = rep(0, dimension),     # presumably a prior mean vector
        r = diag(0.05, dimension), # presumably a prior precision matrix
        b = dimension,             # presumably a degrees-of-freedom value
        w = diag(1, dimension),    # presumably a prior scale matrix
        n = 5000                   # presumably number of sampler iterations
      )
     19 
     set.seed(222)

     # Three Gaussian clusters with independent unit-variance coordinates:
     # 250 points centred at 5 in every coordinate (c(5, 5) is recycled by
     # rnorm), 250 centred at -5, and 500 centred at the origin. Clusters are
     # drawn in that order, so the random stream matches a sequential build.
     d <- Map(
       function(count, centre) t(replicate(count, rnorm(config$m, centre))),
       c(250, 250, 500),
       list(c(5, 5), c(-5, -5), 0)
     )

     # One data frame per cluster with the first three coordinates as x, y, z.
     dn <- lapply(d, function(draws) {
       data.frame(x = draws[, 1], y = draws[, 2], z = draws[, 3])
     })

     # Stack the clusters; melt() on a list records each element's index in L1.
     m <- melt(dn, id.vars = c("x", "y", "z"))
     28 
     # Re-seed so the sampler run is reproducible on its own, regardless of how
     # many random numbers the simulation above consumed.
     set.seed(990909)

     # All arguments are positional: iteration count, component count, the
     # pooled (x, y, z) observations, then the remaining config settings.
     params <- inverse_model(
       config$n, config$k, m[, c("x", "y", "z")],
       config$a,
       config$l, config$r,
       config$b, config$w
     )
     37 
     # params$p in long form: one row per (column, draw), column name in
     # `variable`, draw in `value`.
     dp <- melt(data.frame(params$p))

     # params$m in long form: each list element becomes a data frame, which is
     # expected to carry x, y, z columns; melt() tags rows with the element
     # index in L1.
     dm <- melt(lapply(params$m, data.frame), id.vars = c("x", "y", "z"))
     41 
     # Observed data projected onto its first two coordinates.
     py <- ggplot(m, aes(x, y)) + geom_point()

     # Trace of each params$p column, one facet per column. The x position is
     # the draw's index within its own column; a plain seq_along(value) would
     # keep counting across the stacked columns and push every facet after the
     # first to the right on the shared x axis.
     pp <- ggplot(dp, aes(ave(value, variable, FUN = seq_along), value,
                          colour = variable)) +
       geom_line() +
       facet_grid(. ~ variable) +
       labs(x = "iteration")

     # Draws from params$m, coloured by list element (L1 from melt()).
     pm <- ggplot(dm, aes(x, y, colour = factor(L1), fill = factor(L1))) +
       geom_point(alpha = 0.5)
     49 
     # Pair each observation's (x, y) coordinates with row `iteration` of
     # params$z (its allocation at that sampler iteration).
     allocation_frame <- function(iteration) {
       data.frame(x = m$x, y = m$y, variable = params$z[iteration, ])
     }

     # Scatter plot of one allocation frame, coloured and filled by allocation.
     # Shared by all three panels so they are built identically.
     plot_allocation <- function(frame) {
       ggplot(frame, aes(x, y,
                         colour = factor(variable),
                         fill = factor(variable))) +
         geom_point(alpha = 0.5)
     }

     early <- allocation_frame(1)
     mid <- allocation_frame(round(config$n / 2))
     # NOTE(review): row config$n - 1, not config$n; confirm how many rows
     # params$z has and whether the final draw was intended here.
     late <- allocation_frame(config$n - 1)

     p_early <- plot_allocation(early)
     p_mid <- plot_allocation(mid)
     p_late <- plot_allocation(late)
     65 
     # params$m draws split into one facet per list element (L1), drawn faintly
     # so regions the chain visits often appear darker.
     mean_convergence_plots <- ggplot(
       dm,
       aes(x, y, colour = factor(L1), fill = factor(L1))
     ) +
       geom_point(alpha = 0.2) +
       facet_grid(. ~ L1)

     # grid.arrange() draws each layout on the current device and also returns
     # the arranged grob.
     chain_plots <- grid.arrange(pp, mean_convergence_plots, nrow = 2)

     inferred_plots <- grid.arrange(
       py, p_early, p_mid, p_late,
       nrow = 2, ncol = 2
     )
     73