sim_fmm_3d_conditional.r (1940B)
# Simulate 3-D data from three Gaussian clusters, fit it with inverse_model()
# (defined in fmm_multivariate_conditional.r), and plot the sampler output.

library(ggplot2)
library(gridExtra)
library(reshape2)

source("fmm_multivariate_conditional.r")

dimension <- 3

# Arguments forwarded positionally to inverse_model(). Their meanings are
# defined in fmm_multivariate_conditional.r; the notes below are assumptions
# to confirm there.
config <- list(
  k = 3,                      # presumably the number of mixture components
  m = dimension,              # dimension of each simulated observation
  a = 1,                      # presumably a prior hyperparameter
  l = rep(0, dimension),
  r = diag(0.05, dimension),
  b = dimension,
  w = diag(1, dimension),
  n = 5000                    # presumably the number of sampler iterations
)

set.seed(222)

# Three clusters with unit variance in every coordinate: 250 points centred
# at (5, 5, 5), 250 at (-5, -5, -5) and 500 at the origin. A scalar mean is
# recycled across all config$m coordinates by rnorm().
d <- list(
  t(replicate(250, rnorm(config$m, mean = 5))),
  t(replicate(250, rnorm(config$m, mean = -5))),
  t(replicate(500, rnorm(config$m)))
)
dn <- lapply(d, function(j) {
  data.frame(x = j[, 1], y = j[, 2], z = j[, 3])
})

# Stack the clusters into one data frame. melt() on a list adds an `L1`
# column holding each row's list index, i.e. its true cluster.
m <- melt(dn, id.vars = c("x", "y", "z"))

set.seed(990909)

params <- inverse_model(
  config$n, config$k, m[, c("x", "y", "z")],
  config$a,
  config$l, config$r,
  config$b, config$w
)

# Trace plot of params$p (presumably the sampled mixing weights, one column
# per component). An x index computed inside aes() would run over the whole
# stacked data frame, placing each facet's trace at a different x offset.
# Instead, number the draws within each variable.
dp <- melt(data.frame(params$p))
dp$iteration <- ave(seq_along(dp$value), dp$variable, FUN = seq_along)

pp <- ggplot(dp, aes(iteration, value, colour = variable)) +
  geom_line() +
  facet_grid(. ~ variable)

# Sampled means: one list element of params$m per component (index in L1).
dm <- melt(lapply(params$m, data.frame), id.vars = c("x", "y", "z"))

py <- ggplot(m, aes(x, y)) +
  geom_point()

# NOTE(review): pm is built but not included in either grid.arrange() below.
pm <- ggplot(dm, aes(x, y, colour = factor(L1), fill = factor(L1))) +
  geom_point(alpha = 0.5)

# Data frame of observed (x, y) coordinates labelled by row `iteration` of
# `labels` (params$z by default, presumably the sampled component labels).
assignment_frame <- function(iteration, data = m, labels = params$z) {
  data.frame(x = data$x, y = data$y, variable = labels[iteration, ])
}

# Scatter plot of (x, y), coloured and filled by the `variable` label.
plot_assignments <- function(frame) {
  ggplot(
    frame,
    aes(x, y, colour = factor(variable), fill = factor(variable))
  ) +
    geom_point(alpha = 0.5)
}

early <- assignment_frame(1)
mid <- assignment_frame(round(config$n / 2))
# NOTE(review): the "late" snapshot uses iteration config$n - 1, not
# config$n; confirm against how many rows inverse_model() stores in z.
late <- assignment_frame(config$n - 1)

p_early <- plot_assignments(early)
p_mid <- plot_assignments(mid)
p_late <- plot_assignments(late)

mean_convergence_plots <-
  ggplot(dm, aes(x, y, colour = factor(L1), fill = factor(L1))) +
  geom_point(alpha = 0.2) +
  facet_grid(. ~ L1)

# grid.arrange() draws the layout and returns the assembled grob.
chain_plots <- grid.arrange(pp, mean_convergence_plots, nrow = 2)

inferred_plots <- grid.arrange(py, p_early, p_mid, p_late, nrow = 2, ncol = 2)