hnuts

No-U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

Examples.hs (1913B)


      1 -- Example target densities sampled with NUTS (dual-averaging variant).
      2 -- Plug whichever trace you want (rosenbrockTrace, bnnTrace, etc.) into 'main'
      3 -- to print its samples to stdout, one comma-separated observation per line.
      4 --
      5 -- An R script to plot a trace (run it from the directory containing Examples.hs):
      6 --
      7 -- require(ggplot2)
      8 -- system('runhaskell Examples.hs > trace.csv')
      9 -- d = read.csv('trace.csv', header = F)
     10 -- ggplot(d, aes(V1, V2)) + geom_point(alpha = 0.05, col = 'darkblue')
     11 --
     12 
     13 module Numeric.MCMC.NUTS.Examples where
     14 
     15 import Numeric.AD
     16 import Numeric.MCMC.NUTS
     17 import System.Random.MWC
     18 
     19 logRosenbrock :: RealFloat a => [a] -> a
     20 logRosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2)
     21 
     22 rosenbrockTrace :: IO [Parameters]
     23 rosenbrockTrace = withSystemRandom . asGenST $
     24   nutsDualAveraging logRosenbrock (grad logRosenbrock) 10000 1000 [0.0, 0.0]
     25 
     26 logHimmelblau :: RealFloat a => [a] -> a
     27 logHimmelblau [x0, x1] = negate ((x0 ^ 2 + x1 - 11) ^ 2 + (x0 + x1 ^ 2 - 7) ^ 2)
     28 
     29 himmelblauTrace :: IO [Parameters]
     30 himmelblauTrace = withSystemRandom . asGenST $
     31   nutsDualAveraging logHimmelblau (grad logHimmelblau) 100000 10000 [0.0, 0.0]
     32 
     33 logBnn :: RealFloat a => [a] -> a
     34 logBnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1)
     35 
     36 bnnTrace :: IO [Parameters]
     37 bnnTrace = withSystemRandom . asGenST $
     38   nutsDualAveraging logBnn (grad logBnn) 10000 1000 [0.0, 0.0]
     39 
     40 logBeale :: RealFloat a => [a] -> a
     41 logBeale [x0, x1]
     42   | and [x0 >= -4.5, x0 <= 4.5, x1 >= -4.5, x1 <= 4.5]
     43       = negate $
     44           (1.5   - x0 + x0 * x1) ^ 2
     45         + (2.25  - x0 + x0 * x1 ^ 2) ^ 2
     46         + (2.625 - x0 + x0 * x1 ^ 3) ^ 2
     47   | otherwise = - (1 / 0)
     48 
     49 bealeTrace :: IO [Parameters]
     50 bealeTrace = withSystemRandom . asGenST $
     51   nutsDualAveraging logBeale (grad logBeale) 10000 1000 [0.0, 0.0]
     52 
     53 printTrace :: Show a => [a] -> IO ()
     54 printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show)
     55 
     56 main :: IO ()
     57 main = himmelblauTrace >>= printTrace
     58