-- Examples.hs
-- Various examples, using NUTS with dual-averaging. Insert whatever trace
-- (rosenbrockTrace, bnnTrace, etc.) you want into 'main' in order to spit out
-- some observations.
--
-- A convenient R script to display these traces (it reads back the same
-- file that the 'system' call writes):
--
--   require(ggplot2)
--   system('runhaskell Examples.hs > trace.csv')
--   d = read.csv('trace.csv', header = F)
--   ggplot(d, aes(V1, V2)) + geom_point(alpha = 0.05, col = 'darkblue')
--
-- Each line printed by 'printTrace' is one element of the trace with the
-- list brackets stripped, so a trace of numeric lists is a header-less CSV.

module Numeric.MCMC.NUTS.Examples where

import Numeric.AD
import Numeric.MCMC.NUTS
import System.Random.MWC

-- | Unnormalised log-density built from a scaled Rosenbrock function:
--
-- > log p(x0, x1) = -(5 (x1 - x0^2)^2 + 0.05 (1 - x0)^2)
--
-- Polymorphic in 'RealFloat' so that 'grad' can differentiate it.
-- Expects exactly two coordinates; any other input is a caller error.
logRosenbrock :: RealFloat a => [a] -> a
logRosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
logRosenbrock _ = error "logRosenbrock: expected exactly two parameters"

-- | Trace of 'nutsDualAveraging' targeting 'logRosenbrock', with the
-- gradient supplied by 'grad' and the chain started at @[0, 0]@.
--
-- The generator is seeded by 'withSystemRandom' and the sampler runs in
-- 'ST' via 'asGenST'. NOTE(review): the arguments 10000 and 1000 are
-- presumably the number of iterations and of dual-averaging adaptation
-- steps — confirm against 'nutsDualAveraging'.
rosenbrockTrace :: IO [Parameters]
rosenbrockTrace = withSystemRandom . asGenST $
  nutsDualAveraging logRosenbrock (grad logRosenbrock) 10000 1000 [0.0, 0.0]

-- | Unnormalised log-density built from Himmelblau's function:
--
-- > log p(x0, x1) = -((x0^2 + x1 - 11)^2 + (x0 + x1^2 - 7)^2)
--
-- Himmelblau's function has four global minima, so this target is
-- multimodal. Expects exactly two coordinates.
logHimmelblau :: RealFloat a => [a] -> a
logHimmelblau [x0, x1] = negate ((x0 ^ 2 + x1 - 11) ^ 2 + (x0 + x1 ^ 2 - 7) ^ 2)
logHimmelblau _ = error "logHimmelblau: expected exactly two parameters"

-- | Trace of 'nutsDualAveraging' targeting 'logHimmelblau', with the
-- gradient supplied by 'grad' and the chain started at @[0, 0]@.
--
-- Uses ten times the counts of the other examples (100000 and 10000).
-- NOTE(review): presumably iterations and adaptation steps — confirm
-- against 'nutsDualAveraging'.
himmelblauTrace :: IO [Parameters]
himmelblauTrace = withSystemRandom . asGenST $
  nutsDualAveraging logHimmelblau (grad logHimmelblau) 100000 10000 [0.0, 0.0]

-- | Unnormalised log-density
--
-- > log p(x0, x1) = -0.5 (x0^2 x1^2 + x0^2 + x1^2 - 8 x0 - 8 x1)
--
-- For a fixed value of either coordinate the exponent is quadratic in the
-- other, so each conditional is Gaussian even though the joint is not.
-- Expects exactly two coordinates.
logBnn :: RealFloat a => [a] -> a
logBnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1)
logBnn _ = error "logBnn: expected exactly two parameters"

-- | Trace of 'nutsDualAveraging' targeting 'logBnn', with the gradient
-- supplied by 'grad' and the chain started at @[0, 0]@.
bnnTrace :: IO [Parameters]
bnnTrace = withSystemRandom . asGenST $
  nutsDualAveraging logBnn (grad logBnn) 10000 1000 [0.0, 0.0]

-- | Unnormalised log-density built from Beale's function, restricted to
-- the box @[-4.5, 4.5] x [-4.5, 4.5]@:
--
-- > log p(x0, x1) = -((1.5   - x0 + x0 x1  )^2
-- >                 + (2.25  - x0 + x0 x1^2)^2
-- >                 + (2.625 - x0 + x0 x1^3)^2)
--
-- Outside the box (or for NaN coordinates) the log-density is negative
-- infinity. Expects exactly two coordinates.
logBeale :: RealFloat a => [a] -> a
logBeale [x0, x1]
  | inBox x0 && inBox x1 =
      negate $
        (1.5 - x0 + x0 * x1) ^ 2
          + (2.25 - x0 + x0 * x1 ^ 2) ^ 2
          + (2.625 - x0 + x0 * x1 ^ 3) ^ 2
  | otherwise = negate (1 / 0)
  where
    -- Closed interval check; NaN fails both comparisons and lands outside.
    inBox x = x >= -4.5 && x <= 4.5
logBeale _ = error "logBeale: expected exactly two parameters"

-- | Trace of 'nutsDualAveraging' targeting 'logBeale', with the gradient
-- supplied by 'grad' and the chain started at @[0, 0]@, which lies inside
-- the support box.
bealeTrace :: IO [Parameters]
bealeTrace = withSystemRandom . asGenST $
  nutsDualAveraging logBeale (grad logBeale) 10000 1000 [0.0, 0.0]

-- | Print each element on its own line, using its 'show' representation with
-- every @[@ and @]@ character removed. For example, @[1.0,2.0]@ prints as
-- @1.0,2.0@, so a trace of numeric lists comes out as CSV rows.
printTrace :: Show a => [a] -> IO ()
printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show)

-- | Sample the Himmelblau example and print its trace. Swap in any of the
-- other @*Trace@ actions to look at a different target.
main :: IO ()
main = himmelblauTrace >>= printTrace