hnuts

No U-Turn Sampling in Haskell.
git clone git://git.jtobin.io/hnuts.git
Log | Files | Refs | README | LICENSE

commit 5ff45efda4530a1465e5521b13b6c94f08d11f34
parent c943afc861b6d0973d78796e35523bbb478ce566
Author: Jared Tobin <jared@jtobin.ca>
Date:   Wed, 16 Oct 2013 10:24:38 +1300

Cabalize.

Diffstat:
ALICENSE | 30++++++++++++++++++++++++++++++
Ahnuts.cabal | 30++++++++++++++++++++++++++++++
Dsrc/Numeric/MCMC/Examples/Examples.hs | 56--------------------------------------------------------
Asrc/Numeric/MCMC/NUTS/Examples.hs | 58++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 118 insertions(+), 56 deletions(-)

diff --git a/LICENSE b/LICENSE @@ -0,0 +1,30 @@ +Copyright (c) 2013, Jared Tobin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of Jared Tobin nor the names of other + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/hnuts.cabal b/hnuts.cabal @@ -0,0 +1,30 @@ +name: hnuts +version: 1.0.0.0 +synopsis: Automatic gradient-based sampling. +homepage: github.com/jtobin/hnuts +license: BSD3 +license-file: LICENSE +author: Jared Tobin +maintainer: jared@jtobin.ca +category: Math +build-type: Simple +extra-source-files: README.md +cabal-version: >= 1.10 + +library + exposed-modules: Numeric.MCMC.NUTS, + Numeric.MCMC.NUTS.Examples + + other-extensions: ScopedTypeVariables + + build-depends: base >= 4.6, + monad-loops >= 0.4, + primitive >= 0.5, + mwc-random >= 0.12, + statistics >= 0.10, + ad >= 3.4 + + hs-source-dirs: src + + default-language: Haskell2010 + diff --git a/src/Numeric/MCMC/Examples/Examples.hs b/src/Numeric/MCMC/Examples/Examples.hs @@ -1,56 +0,0 @@ --- Various examples, using NUTS with dual-averaging. Insert whatever trace --- (rosenbrockTrace, bnnTrace, etc.) you want into 'main' in order to spit out --- some observations. --- --- A convenient R script to display these traces: --- --- require(ggplot2) --- system('runhaskell Examples.hs > trace.csv') --- d = read.csv('../tests/trace.dat', header = F) --- ggplot(d, aes(V1, V2)) + geom_point(alpha = 0.05, col = 'darkblue') --- - -import Numeric.AD -import Numeric.MCMC.NUTS -import System.Random.MWC - -logRosenbrock :: RealFloat a => [a] -> a -logRosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) - -rosenbrockTrace :: IO [Parameters] -rosenbrockTrace = withSystemRandom . asGenST $ - nutsDualAveraging logRosenbrock (grad logRosenbrock) 10000 1000 [0.0, 0.0] - -logHimmelblau :: RealFloat a => [a] -> a -logHimmelblau [x0, x1] = negate ((x0 ^ 2 + x1 - 11) ^ 2 + (x0 + x1 ^ 2 - 7) ^ 2) - -himmelblauTrace :: IO [Parameters] -himmelblauTrace = withSystemRandom . asGenST $ - nutsDualAveraging logHimmelblau (grad logHimmelblau) 10000 1000 [0.0, 0.0] - -logBnn :: RealFloat a => [a] -> a -logBnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) - -bnnTrace :: IO [Parameters] -bnnTrace = withSystemRandom . asGenST $ - nutsDualAveraging logBnn (grad logBnn) 10000 1000 [0.0, 0.0] - -logBeale :: RealFloat a => [a] -> a -logBeale [x0, x1] - | and [x0 >= -4.5, x0 <= 4.5, x1 >= -4.5, x1 <= 4.5] - = negate $ - (1.5 - x0 + x0 * x1) ^ 2 - + (2.25 - x0 + x0 * x1 ^ 2) ^ 2 - + (2.625 - x0 + x0 * x1 ^ 3) ^ 2 - | otherwise = - (1 / 0) - -bealeTrace :: IO [Parameters] -bealeTrace = withSystemRandom . asGenST $ - nutsDualAveraging logBeale (grad logBeale) 10000 1000 [0.0, 0.0] - -printTrace :: Show a => [a] -> IO () -printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show) - -main :: IO () -main = bnnTrace >>= printTrace - diff --git a/src/Numeric/MCMC/NUTS/Examples.hs b/src/Numeric/MCMC/NUTS/Examples.hs @@ -0,0 +1,58 @@ +-- Various examples, using NUTS with dual-averaging. Insert whatever trace +-- (rosenbrockTrace, bnnTrace, etc.) you want into 'main' in order to spit out +-- some observations. +-- +-- A convenient R script to display these traces: +-- +-- require(ggplot2) +-- system('runhaskell Examples.hs > trace.csv') +-- d = read.csv('../tests/trace.dat', header = F) +-- ggplot(d, aes(V1, V2)) + geom_point(alpha = 0.05, col = 'darkblue') +-- + +module Numeric.MCMC.NUTS.Examples where + +import Numeric.AD +import Numeric.MCMC.NUTS +import System.Random.MWC + +logRosenbrock :: RealFloat a => [a] -> a +logRosenbrock [x0, x1] = negate (5 * (x1 - x0 ^ 2) ^ 2 + 0.05 * (1 - x0) ^ 2) + +rosenbrockTrace :: IO [Parameters] +rosenbrockTrace = withSystemRandom . asGenST $ + nutsDualAveraging logRosenbrock (grad logRosenbrock) 10000 1000 [0.0, 0.0] + +logHimmelblau :: RealFloat a => [a] -> a +logHimmelblau [x0, x1] = negate ((x0 ^ 2 + x1 - 11) ^ 2 + (x0 + x1 ^ 2 - 7) ^ 2) + +himmelblauTrace :: IO [Parameters] +himmelblauTrace = withSystemRandom . asGenST $ + nutsDualAveraging logHimmelblau (grad logHimmelblau) 10000 1000 [0.0, 0.0] + +logBnn :: RealFloat a => [a] -> a +logBnn [x0, x1] = -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1) + +bnnTrace :: IO [Parameters] +bnnTrace = withSystemRandom . asGenST $ + nutsDualAveraging logBnn (grad logBnn) 10000 1000 [0.0, 0.0] + +logBeale :: RealFloat a => [a] -> a +logBeale [x0, x1] + | and [x0 >= -4.5, x0 <= 4.5, x1 >= -4.5, x1 <= 4.5] + = negate $ + (1.5 - x0 + x0 * x1) ^ 2 + + (2.25 - x0 + x0 * x1 ^ 2) ^ 2 + + (2.625 - x0 + x0 * x1 ^ 3) ^ 2 + | otherwise = - (1 / 0) + +bealeTrace :: IO [Parameters] +bealeTrace = withSystemRandom . asGenST $ + nutsDualAveraging logBeale (grad logBeale) 10000 1000 [0.0, 0.0] + +printTrace :: Show a => [a] -> IO () +printTrace = mapM_ (putStrLn . filter (`notElem` "[]") . show) + +main :: IO () +main = bnnTrace >>= printTrace +