sim_fmm_2d_conditional.r (2038B)
# Simulate a two-dimensional data set made of three Gaussian clusters, fit it
# with inverse_model() from fmm_multivariate_conditional.r, and plot the
# sampler output (mixing-weight chains, component means, and component
# assignments at early, middle and late iterations).

library(ggplot2)
library(gridExtra)
library(reshape2)

source("fmm_multivariate_conditional.r")

dimension <- 2

# Arguments forwarded to inverse_model(); their meanings are defined in
# fmm_multivariate_conditional.r. NOTE(review): k appears to be the number of
# mixture components (the simulated data has three clusters) and n the number
# of sampler iterations (rows of params$z are indexed by it below) -- confirm
# against inverse_model().
config <- list(
  k = 3,
  m = dimension,
  a = 1,
  l = rep(0, dimension),
  r = diag(0.05, dimension),
  b = 2,
  w = diag(0.05, dimension),
  n = 1000
)

set.seed(222)

# Three unit-variance Gaussian clusters: 250 points around (5, 5), 250 around
# (-5, -5) and 500 around the origin.
d <- list(
  t(replicate(250, rnorm(dimension, c(5, 5)))),
  t(replicate(250, rnorm(dimension, c(-5, -5)))),
  t(replicate(500, rnorm(dimension)))
)
dn <- lapply(d, function(j) data.frame(x = j[, 1], y = j[, 2]))

# melt() on a list of data frames that contain only id variables stacks them
# and labels each row's source list element in an L1 column.
m <- melt(dn, id.vars = c("x", "y"))

set.seed(222)

params <- inverse_model(
  config$n, config$k, m[, c("x", "y")],
  config$a,
  config$l, config$r,
  config$b, config$w
)

dp <- melt(data.frame(params$p))
# Aesthetics are evaluated on the whole layer before faceting, so
# seq_along(value) inside aes() would number rows continuously across all
# variables and shift every facet after the first along the x axis.
# Number the iterations within each variable instead.
dp$iteration <- ave(seq_along(dp$value), dp$variable, FUN = seq_along)

dm <- melt(lapply(params$m, data.frame), id.vars = c("x", "y"))
dl <- melt(as.data.frame(params$l))

py <- ggplot(m, aes(x, y)) +
  geom_point()

pp <- ggplot(dp, aes(iteration, value, colour = variable)) +
  geom_line() +
  facet_grid(. ~ variable)

pm <- ggplot(dm, aes(x, y, colour = factor(L1), fill = factor(L1))) +
  geom_point(alpha = 0.5)

pl <- ggplot(dl, aes(x = seq_along(value), y = value)) +
  geom_line(colour = "darkblue")

# Snapshots of the assignment vector stored in rows of params$z, one value
# per data point in m.
early <- data.frame(x = m$x, y = m$y, variable = params$z[1, ])
mid <- data.frame(x = m$x, y = m$y, variable = params$z[round(config$n / 2), ])
# NOTE(review): this takes row n - 1 rather than row n; confirm whether
# params$z has n rows and the final iteration was intended.
late <- data.frame(x = m$x, y = m$y, variable = params$z[config$n - 1, ])

# Scatter one assignment snapshot (columns x, y, variable), coloured by label.
plot_assignments <- function(snapshot) {
  ggplot(
    snapshot,
    aes(x, y, colour = factor(variable), fill = factor(variable))
  ) +
    geom_point(alpha = 0.5)
}

p_early <- plot_assignments(early)
p_mid <- plot_assignments(mid)
p_late <- plot_assignments(late)

mean_convergence_plots <-
  ggplot(dm, aes(x, y, colour = factor(L1), fill = factor(L1))) +
  geom_point(alpha = 0.2) +
  facet_grid(. ~ L1)

chain_plots <- grid.arrange(pp, mean_convergence_plots, nrow = 2)

inferred_plots <- grid.arrange(py, p_early, p_mid, p_late, nrow = 2, ncol = 2)