NUTS.hs (12092B)
-- | See Hoffman, Gelman (2011) The No U-Turn Sampler: Adaptively Setting Path
--   Lengths in Hamiltonian Monte Carlo.
--
--   Provides NUTS with a fixed step size ('nuts') and with dual-averaging
--   step size adaptation ('nutsDualAveraging'). Vectors are plain lists.

module Numeric.MCMC.NUTS where

import Control.Monad
import Control.Monad.Loops
import Control.Monad.Primitive
import System.Random.MWC
import System.Random.MWC.Distributions hiding (gamma)
import Statistics.Distribution.Normal

-- | A position (or momentum) vector.
type Parameters = [Double]

-- | The log of the (unnormalized) target density.
type Density = Parameters -> Double

-- | The gradient of the log target density.
type Gradient = Parameters -> Parameters

-- | A (position, momentum) pair.
type Particle = (Parameters, Parameters)

-- | A tree of candidate states:
--   @(thetaMinus, rMinus, thetaPlus, rPlus, candidate, n, s)@, i.e. the
--   leftmost and rightmost positions and momenta, the proposed state, the
--   number of states inside the slice, and the keep-going flag (1 or 0).
type StateSpecs = ([Double], [Double], [Double], [Double], [Double], Int, Int)

-- | A wrapper around 'StateSpecs' with a debugging 'Show' instance.
newtype StateTree = StateTree {
    getStateTree :: StateSpecs
  }

-- The momenta are not included in the rendered output.
instance Show StateTree where
  show (StateTree (tm, _, tp, _, t', n, s)) =
       "\n" ++ "tm: " ++ show tm
    ++ "\n" ++ "tp: " ++ show tp
    ++ "\n" ++ "t': " ++ show t'
    ++ "\n" ++ "n : " ++ show n
    ++ "\n" ++ "s : " ++ show s

-- | Tuning parameters for dual-averaging step size adaptation.
data DualAveragingParameters = DualAveragingParameters {
    mAdapt :: Int     -- ^ number of iterations during which to adapt
  , delta  :: Double  -- ^ target mean acceptance statistic
  , mu     :: Double  -- ^ log step size that the iterates are pulled towards
  , gamma  :: Double  -- ^ scales how far the log step size may move from 'mu'
  , tau0   :: Double  -- ^ offset that damps the earliest iterations
  , kappa  :: Double  -- ^ decay exponent of the step size averaging weights
  } deriving Show

-- | The NUTS sampler with a fixed step size.
--
--   Runs @n@ NUTS transitions starting from the given position and returns
--   the @n@ resulting positions in chronological order. The initial position
--   itself is not included.
nuts
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Int                -- ^ number of transitions (samples) to draw
  -> Double             -- ^ leapfrog step size
  -> Parameters         -- ^ initial position
  -> Gen (PrimState m)
  -> m [Parameters]
nuts lTarget glTarget n e t g = go t 0 []
  where
    -- The accumulator is built newest-first, so it is reversed once at the
    -- end to hand back the chain in the order it was generated (the same
    -- order 'nutsDualAveraging' uses).
    go position j acc
      | j >= n    = return (reverse acc)
      | otherwise = do
          nextPosition <- nutsKernel lTarget glTarget e position g
          go nextPosition (succ j) (nextPosition : acc)

-- | The NUTS sampler with dual averaging.
--
--   Chooses an initial step size with 'findReasonableEpsilon', then runs
--   @n + nAdapt@ iterations, adapting the step size during the first
--   @nAdapt@ of them. The chain, which begins with the initial position, is
--   returned in chronological order with its first @nAdapt@ entries dropped,
--   leaving @n@ positions.
nutsDualAveraging
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Int                -- ^ number of positions to return
  -> Int                -- ^ number of adaptation (burn-in) iterations
  -> Parameters         -- ^ initial position
  -> Gen (PrimState m)
  -> m [Parameters]
nutsDualAveraging lTarget glTarget n nAdapt t g = do
  e0 <- findReasonableEpsilon lTarget glTarget t g
  let daParams = basicDualAveragingParameters e0 nAdapt
  chain <- unfoldrM (kernel daParams) (1, e0, 1, 0, t)
  return $ drop nAdapt chain
  where
    total = n + nAdapt

    -- Seed: (iteration m, step size, averaged step size, H-bar, position).
    -- Each step yields its current position. The NUTS transition to the next
    -- position is only run when that position will actually be yielded, so
    -- no (potentially very expensive) iterations are wasted past the end of
    -- the chain. On the last yielded step, the seed is simply advanced past
    -- 'total' so the unfold stops.
    kernel params (m, e, eAvg, h, position)
      | m > total  = return Nothing
      | m == total = return $ Just (position, (succ m, e, eAvg, h, position))
      | otherwise  = do
          (eNext, eAvgNext, hNext, next) <-
            nutsKernelDualAvg lTarget glTarget e eAvg h m params position g
          return $ Just (position, (succ m, eNext, eAvgNext, hNext, next))

-- | Default DA parameters, given a base step size and burn in period.
--
--   @mu@ is @log (10 * step)@; the target acceptance statistic @delta@ is
--   0.5, with @gamma = 0.05@, @tau0 = 10@ and @kappa = 0.75@.
basicDualAveragingParameters :: Double -> Int -> DualAveragingParameters
basicDualAveragingParameters step burnInPeriod = DualAveragingParameters {
    mu     = log (10 * step)
  , delta  = 0.5
  , mAdapt = burnInPeriod
  , gamma  = 0.05
  , tau0   = 10
  , kappa  = 0.75
  }

-- | A single iteration of dual-averaging NUTS.
--
--   Draws a fresh momentum and slice variable, then repeatedly doubles the
--   trajectory in a random direction until a U-turn (or a huge energy
--   error) is detected, and picks the next position among the candidates.
--   While @m <= mAdapt@ the step size is then updated by dual averaging;
--   afterwards the returned step size is the averaged step size @eAvg@.
--
--   Returns @(eNext, eAvgNext, hNext, nextPosition)@.
nutsKernelDualAvg
  :: PrimMonad m
  => Density                  -- ^ log target density
  -> Gradient                 -- ^ gradient of the log target density
  -> Double                   -- ^ step size used for this iteration
  -> Double                   -- ^ averaged step size so far
  -> Double                   -- ^ running dual-averaging statistic (H-bar)
  -> Int                      -- ^ iteration number m
  -> DualAveragingParameters
  -> [Double]                 -- ^ current position
  -> Gen (PrimState m)
  -> m (Double, Double, Double, Parameters)
nutsKernelDualAvg lTarget glTarget e eAvg h m daParams t g = do
  r0 <- replicateM (length t) (normal 0 1 g)
  z0 <- exponential 1 g
  -- Log slice variable: log u = log p(t, r0) - z0 with z0 ~ Exp(1), i.e.
  -- u ~ Uniform(0, p(t, r0)). It is computed directly in log space. Going
  -- through 'exp' and back with 'log' underflows to -Infinity once the log
  -- joint density drops below about -745, which would reject every state.
  let logu = (lTarget t - 0.5 * innerProduct r0 r0) - z0

      -- State: (thetaMinus, thetaPlus, rMinus, rPlus, current sample,
      --         depth j, in-slice count n, keep-going flag s,
      --         acceptance-statistic sum a, its count na)
      go (tn, tp, rn, rp, tm, j, n, s, a, na)
        | s == 1 = do
            vj <- symmetricCategorical [-1, 1] g
            z  <- uniform g

            (tnn, rnn, tpp, rpp, t1, n1, s1, a1, na1) <-
              if vj == -1
                then do
                  (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
                    buildTreeDualAvg lTarget glTarget g tn rn logu vj j e t r0
                  return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
                else do
                  (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
                    buildTreeDualAvg lTarget glTarget g tp rp logu vj j e t r0
                  return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')

            let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z
                n2 = n + n1
                s2 = s1 * stopCriterion tnn tpp rnn rpp
                t2 | accept    = t1
                   | otherwise = tm

            -- The acceptance statistics of the most recent subtree replace
            -- the previous ones rather than being added to them.
            go (tnn, tpp, rnn, rpp, t2, succ j, n2, s2, a1, na1)

        | otherwise = return (tm, a, na)

  (nextPosition, alpha, nalpha) <- go (t, t, r0, r0, t, 0, 1, 1, 0, 0)

  -- Dual-averaging update of the step size and of its running average.
  let eta      = 1 / (fromIntegral m + tau0 daParams)
      hm       = (1 - eta) * h
               + eta * (delta daParams - alpha / fromIntegral nalpha)
      zeta     = fromIntegral m ** negate (kappa daParams)
      logEm    = mu daParams - sqrt (fromIntegral m) / gamma daParams * hm
      logEbarM = (1 - zeta) * log eAvg + zeta * logEm
      (hNext, eNext, eAvgNext)
        | m <= mAdapt daParams = (hm, exp logEm, exp logEbarM)
        | otherwise            = (h, eAvg, eAvg)

  return (eNext, eAvgNext, hNext, nextPosition)

-- | A single iteration of NUTS.
--
--   Draws a fresh momentum and slice variable, then repeatedly doubles the
--   trajectory in a random direction with the fixed step size @e@ until a
--   U-turn (or a huge energy error) is detected. Returns the next position
--   chosen among the candidates.
nutsKernel
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Double             -- ^ leapfrog step size
  -> Parameters         -- ^ current position
  -> Gen (PrimState m)
  -> m Parameters
nutsKernel lTarget glTarget e t g = do
  r0 <- replicateM (length t) (normal 0 1 g)
  z0 <- exponential 1 g
  -- Log slice variable: log u = log p(t, r0) - z0 with z0 ~ Exp(1), i.e.
  -- u ~ Uniform(0, p(t, r0)). It is computed directly in log space. Going
  -- through 'exp' and back with 'log' underflows to -Infinity once the log
  -- joint density drops below about -745, which would reject every state.
  let logu = (lTarget t - 0.5 * innerProduct r0 r0) - z0

      -- State: (thetaMinus, thetaPlus, rMinus, rPlus, current sample,
      --         depth j, in-slice count n, keep-going flag s)
      go (tn, tp, rn, rp, tm, j, n, s)
        | s == 1 = do
            vj <- symmetricCategorical [-1, 1] g
            z  <- uniform g

            (tnn, rnn, tpp, rpp, t1, n1, s1) <-
              if vj == -1
                then do
                  (tnn', rnn', _, _, t1', n1', s1') <-
                    buildTree lTarget glTarget g tn rn logu vj j e
                  return (tnn', rnn', tp, rp, t1', n1', s1')
                else do
                  (_, _, tpp', rpp', t1', n1', s1') <-
                    buildTree lTarget glTarget g tp rp logu vj j e
                  return (tn, rn, tpp', rpp', t1', n1', s1')

            -- Move to the new subtree's candidate with probability
            -- min 1 (n1 / n), provided the new subtree did not stop.
            let accept = s1 == 1 && (min 1 (fi n1 / fi n :: Double)) > z
                n2 = n + n1
                s2 = s1 * stopCriterion tnn tpp rnn rpp
                t2 | accept    = t1
                   | otherwise = tm

            go (tnn, tpp, rnn, rpp, t2, succ j, n2, s2)

        | otherwise = return tm

  go (t, t, r0, r0, t, 0, 1, 1)

-- | Build the 'tree' of candidate states.
--
--   Recursively builds a subtree of @2^j@ leapfrog steps from @(t, r)@ in
--   direction @v@ (-1 backwards, 1 forwards) with step size @e@. Returns
--   @(thetaMinus, rMinus, thetaPlus, rPlus, candidate, n, s)@, where @n@
--   counts the states inside the slice and @s@ is 0 once building stops.
buildTree
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Gen (PrimState m)
  -> Parameters         -- ^ starting position
  -> Parameters         -- ^ starting momentum
  -> Double             -- ^ log slice variable
  -> Double             -- ^ direction, -1 or 1
  -> Int                -- ^ tree depth j
  -> Double             -- ^ step size
  -> m StateSpecs
buildTree lTarget glTarget _ t r logu v 0 e = do
  let (t0, r0) = leapfrog glTarget (t, r) (v * e)
      -- Log joint density, computed directly rather than as log (exp ...)
      -- so that it cannot underflow to -Infinity.
      joint    = lTarget t0 - 0.5 * innerProduct r0 r0
      n        = indicate (logu < joint)
      -- Stop when the simulation error is huge (a 1000-unit energy margin).
      s        = indicate (logu - 1000 < joint)
  return (t0, r0, t0, r0, t0, n, s)

buildTree lTarget glTarget g t r logu v j e = do
  z <- uniform g
  (tn, rn, tp, rp, t0, n0, s0) <-
    buildTree lTarget glTarget g t r logu v (pred j) e

  if s0 == 1
    then do
      (tnn, rnn, tpp, rpp, t1, n1, s1) <-
        if v == -1
          then do
            (tnn', rnn', _, _, t1', n1', s1') <-
              buildTree lTarget glTarget g tn rn logu v (pred j) e
            return (tnn', rnn', tp, rp, t1', n1', s1')
          else do
            (_, _, tpp', rpp', t1', n1', s1') <-
              buildTree lTarget glTarget g tp rp logu v (pred j) e
            return (tn, rn, tpp', rpp', t1', n1', s1')

      -- Take the second half's candidate with probability n1 / (n0 + n1);
      -- the 'max ... 1' guards against 0 / 0.
      let accept = (fi n1 / max (fi (n0 + n1)) 1) > (z :: Double)
          n2 = n0 + n1
          s2 = s0 * s1 * stopCriterion tnn tpp rnn rpp
          t2 | accept    = t1
             | otherwise = t0

      return (tnn, rnn, tpp, rpp, t2, n2, s2)
    else return (tn, rn, tp, rp, t0, n0, s0)

-- | Determine whether or not to stop doubling the tree of candidate states.
--
--   Returns 1 while the momenta at both ends still have a non-negative
--   projection onto the span @tp - tn@ (no U-turn yet), and 0 otherwise.
stopCriterion :: (Integral a, Num b, Ord b) => [b] -> [b] -> [b] -> [b] -> a
stopCriterion tn tp rn rp =
    indicate (positionDifference `innerProduct` rn >= 0)
  * indicate (positionDifference `innerProduct` rp >= 0)
  where
    positionDifference = tp .- tn

-- | Build the tree of candidate states under dual averaging.
--
--   Like 'buildTree', but additionally returns the sum of the per-leaf
--   acceptance statistics @min 1 (p(leaf) / p(t0, r0))@ and the number of
--   leaves contributing to it, for use in step size adaptation.
buildTreeDualAvg
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Gen (PrimState m)
  -> Parameters         -- ^ starting position
  -> Parameters         -- ^ starting momentum
  -> Double             -- ^ log slice variable
  -> Double             -- ^ direction, -1 or 1
  -> Int                -- ^ tree depth j
  -> Double             -- ^ step size
  -> Parameters         -- ^ position at the start of the NUTS iteration
  -> Parameters         -- ^ momentum at the start of the NUTS iteration
  -> m ([Double], [Double], [Double], [Double], [Double], Int, Int, Double, Int)
buildTreeDualAvg lTarget glTarget _ t r logu v 0 e t0 r0 = do
  let (t1, r1) = leapfrog glTarget (t, r) (v * e)
      -- Log joint density, computed directly rather than as log (exp ...)
      -- so that it cannot underflow to -Infinity.
      joint    = lTarget t1 - 0.5 * innerProduct r1 r1
      n        = indicate (logu < joint)
      -- Stop when the simulation error is huge (a 1000-unit energy margin).
      s        = indicate (logu - 1000 < joint)
      a        = min 1 (acceptanceRatio lTarget t0 t1 r0 r1)
  return (t1, r1, t1, r1, t1, n, s, a, 1)

buildTreeDualAvg lTarget glTarget g t r logu v j e t0 r0 = do
  z <- uniform g
  (tn, rn, tp, rp, t1, n1, s1, a1, na1) <-
    buildTreeDualAvg lTarget glTarget g t r logu v (pred j) e t0 r0

  if s1 == 1
    then do
      (tnn, rnn, tpp, rpp, t2, n2, s2, a2, na2) <-
        if v == -1
          then do
            (tnn', rnn', _, _, t1', n1', s1', a1', na1') <-
              buildTreeDualAvg lTarget glTarget g tn rn logu v (pred j) e t0 r0
            return (tnn', rnn', tp, rp, t1', n1', s1', a1', na1')
          else do
            (_, _, tpp', rpp', t1', n1', s1', a1', na1') <-
              buildTreeDualAvg lTarget glTarget g tp rp logu v (pred j) e t0 r0
            return (tn, rn, tpp', rpp', t1', n1', s1', a1', na1')

      -- Take the second half's candidate with probability n2 / (n1 + n2);
      -- the 'max ... 1' guards against 0 / 0.
      let p = fi n2 / max (fi (n1 + n2)) 1
          accept = p > (z :: Double)
          n3 = n1 + n2
          a3 = a1 + a2
          na3 = na1 + na2
          s3 = s1 * s2 * stopCriterion tnn tpp rnn rpp
          t3 | accept    = t2
             | otherwise = t1

      return (tnn, rnn, tpp, rpp, t3, n3, s3, a3, na3)
    else return (tn, rn, tp, rp, t1, n1, s1, a1, na1)

-- | Heuristic for initializing step size.
--
--   Starting from a step size of 1, repeatedly doubles (or halves) it until
--   the acceptance ratio of a single leapfrog step from the initial state
--   crosses 0.5. It gives up after at most 10 adjustments.
findReasonableEpsilon
  :: PrimMonad m
  => Density            -- ^ log target density
  -> Gradient           -- ^ gradient of the log target density
  -> Parameters         -- ^ initial position
  -> Gen (PrimState m)
  -> m Double
findReasonableEpsilon lTarget glTarget t0 g = do
  r0 <- replicateM (length t0) (normal 0 1 g)
  -- Acceptance ratio of ONE leapfrog step of size e from (t0, r0). Every
  -- trial restarts from the initial state with the candidate step size;
  -- chaining steps from the previous trial (with the previous step size)
  -- would measure a multi-step trajectory instead.
  let ratio e =
        let (t1, r1) = leapfrog glTarget (t0, r0) e
        in  acceptanceRatio lTarget t0 t1 r0 r1

      -- +1: the step size is too small, keep doubling;
      -- -1: the step size is too large, keep halving.
      a :: Int
      a = 2 * indicate (ratio 1.0 > 0.5) - 1

      go :: Int -> Double -> Double
      go j e
        | j <= 0                   = e -- no need to shrink this excessively
        | ratio e ^^ a > 2 ^^ (-a) = go (pred j) (2 ^^ a * e)
        | otherwise                = e

  return $ go 10 1.0

-- | Simulate a single step of Hamiltonian dynamics.
--
--   A leapfrog step: half-step the momentum, full-step the position, then
--   half-step the momentum again using the gradient at the new position.
leapfrog :: Gradient -> Particle -> Double -> Particle
leapfrog glTarget (t, r) e = (tf, rf)
  where
    rm = adjustMomentum glTarget e t r
    tf = adjustPosition e rm t
    rf = adjustMomentum glTarget e tf rm

-- | Adjust momentum by a half step: @r + (e / 2) * grad t@.
adjustMomentum :: Fractional c => (t -> [c]) -> c -> t -> [c] -> [c]
adjustMomentum glTarget e t r = r .+ ((e / 2) .* glTarget t)

-- | Adjust position by a full step: @t + e * r@.
adjustPosition :: Num c => c -> [c] -> [c] -> [c]
adjustPosition e r t = t .+ (e .* r)

-- | The MH acceptance ratio for a given proposal.
--
--   Mathematically @p(t1, r1) / p(t0, r0)@. It is evaluated as the
--   exponential of a difference of log joint densities, which avoids the
--   0/0 and Infinity/Infinity that arise when the individual densities
--   underflow or overflow.
acceptanceRatio :: Floating a => (t -> a) -> t -> t -> [a] -> [a] -> a
acceptanceRatio lTarget t0 t1 r0 r1 =
  exp (logAuxilliaryTarget lTarget t1 r1 - logAuxilliaryTarget lTarget t0 r0)

-- | The (unnormalized) joint density of position and momentum,
--   @exp (L t - r.r / 2)@, i.e. the exponentiated negative Hamiltonian.
auxilliaryTarget :: Floating a => (t -> a) -> t -> [a] -> a
auxilliaryTarget lTarget t r = exp (logAuxilliaryTarget lTarget t r)

-- | The log of 'auxilliaryTarget', @L t - r.r / 2@ (the negative
--   Hamiltonian), computed without exponentiating.
logAuxilliaryTarget :: Fractional a => (t -> a) -> t -> [a] -> a
logAuxilliaryTarget lTarget t r = lTarget t - 0.5 * innerProduct r r

-- | Simple inner product.
innerProduct :: Num a => [a] -> [a] -> a
innerProduct xs ys = sum $ zipWith (*) xs ys

-- | Vectorized multiplication.
(.*) :: Num b => b -> [b] -> [b]
z .* xs = map (* z) xs

-- | Vectorized subtraction.
(.-) :: Num a => [a] -> [a] -> [a]
xs .- ys = zipWith (-) xs ys

-- | Vectorized addition.
--
--   Element-wise sum; the result has the length of the shorter operand.
(.+) :: Num a => [a] -> [a] -> [a]
(.+) = zipWith (+)

-- | Indicator function: maps 'True' to 1 and 'False' to 0.
indicate :: Integral a => Bool -> a
indicate b = if b then 1 else 0

-- | A symmetric categorical (discrete uniform) distribution.
--
--   Picks one of the given candidates uniformly at random, and calls
--   'error' when the candidate list is empty.
symmetricCategorical :: PrimMonad m => [a] -> Gen (PrimState m) -> m a
symmetricCategorical zs g = case zs of
  [] -> error "symmetricCategorical: no candidates"
  _  -> fmap (zs !!) (uniformR (0, length zs - 1) g)

-- | Shorthand alias for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral