commit 4ff236f5b6f666717ec41e568dc95ee90c019491
parent d95eab8907be34b357607b88ba1034bfc061f986
Author: Jared Tobin <jared@jtobin.ca>
Date: Sun, 11 Dec 2016 19:38:52 +1300
Add 'psample' for unequal prob. sampling.
* Minor version bump (0.3.0).
Diffstat:
9 files changed, 101 insertions(+), 35 deletions(-)
diff --git a/.ghci b/.ghci
@@ -0,0 +1,2 @@
+:set prompt "> "
+:set -fno-warn-type-defaults
diff --git a/.gitignore b/.gitignore
@@ -1,3 +1,4 @@
.stack-work
*swp
+tags
diff --git a/CHANGELOG b/CHANGELOG
@@ -0,0 +1,6 @@
+# Changelog
+
+- 0.3.0 (2016-12-11)
+ * Add a 'psample' function for unequal probability sampling.
+
+
diff --git a/README.md b/README.md
@@ -13,6 +13,9 @@ Exports variations on two simple functions for sampling from arbitrary
* *sample*, for sampling without replacement
* *resample*, for sampling with replacement (i.e. a bootstrap)
+Each variation can be prefixed with 'p' to sample from a container of values
+weighted by probability.
+
## Usage
*sampling* uses the PRNG provided by
@@ -43,13 +46,7 @@ Sample five elements from a Map, with replacement:
> resampleIO 5 (Map.fromList [(1, "apple"), (2, "orange"), (3, "pear")])
["apple", "apple", "pear", "orange", "pear"]
-## Development
-
-On the todo list:
-
-* Performance improvements
-* A *psample* function to go with *presample*
-
## Etc.
PRs and issues welcome.
+
diff --git a/lib/Numeric/Sampling.hs b/lib/Numeric/Sampling.hs
@@ -13,6 +13,10 @@ module Numeric.Sampling (
, resample
, resampleIO
+ -- * Unequal probability, without replacement
+ , psample
+ , psampleIO
+
-- * Unequal probability, with replacement
, presample
, presampleIO
@@ -21,17 +25,19 @@ module Numeric.Sampling (
, module System.Random.MWC
) where
-import qualified Control.Foldl as F
-import Control.Monad.Primitive (PrimMonad, PrimState)
-import qualified Data.Foldable as Foldable
+import qualified Control.Foldl as F
+import Control.Monad.Primitive (PrimMonad, PrimState)
+import qualified Data.Foldable as Foldable
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
-import Data.Function (on)
-import Data.List (sortBy)
-import qualified Data.Vector as V (toList)
-import Numeric.Sampling.Internal
-import System.Random.MWC
+import Data.Function (on)
+import Data.List (sortBy)
+import Data.Monoid
+import qualified Data.Sequence as S
+import qualified Data.Vector as V (toList)
+import Numeric.Sampling.Internal
+import System.Random.MWC
-- | (/O(n)/) Sample uniformly, without replacement.
--
@@ -42,7 +48,9 @@ sample
=> Int -> f a -> Gen (PrimState m) -> m (Maybe [a])
sample n xs gen
| n < 0 = return Nothing
- | otherwise = fmap (fmap V.toList) (F.foldM (randomN n gen) xs)
+ | otherwise = do
+ collected <- F.foldM (randomN n gen) xs
+ return $ fmap V.toList collected
{-# INLINABLE sample #-}
-- | (/O(n)/) 'sample' specialized to IO.
@@ -62,12 +70,52 @@ resample n xs = presample n weighted where
{-# INLINABLE resample #-}
-- | (/O(n log n)/) 'resample' specialized to IO.
-resampleIO :: (Foldable f) => Int -> f a -> IO [a]
+resampleIO :: Foldable f => Int -> f a -> IO [a]
resampleIO n xs = do
gen <- createSystemRandom
resample n xs gen
{-# INLINABLE resampleIO #-}
+-- | (/O(n log n)/) Unequal probability sampling.
+--
+-- Returns Nothing if the desired sample size is larger than the collection
+-- being sampled from.
+psample
+ :: (PrimMonad m, Foldable f)
+ => Int -> f (Double, a) -> Gen (PrimState m) -> m (Maybe [a])
+psample n weighted gen = do
+ let sorted = sortProbs weighted
+ computeSample n sorted gen
+ where
+ computeSample
+ :: PrimMonad m
+ => Int -> [(Double, a)] -> Gen (PrimState m) -> m (Maybe [a])
+ computeSample size xs g = go 1 [] size (S.fromList xs) where
+ go !mass !acc j vs
+ | j < 0 = return Nothing
+ | j <= 0 = return (Just acc)
+ | otherwise = do
+ z <- fmap (* mass) (uniform g)
+
+ let cumulative = S.drop 1 $ S.scanl (\s (pr, _) -> s + pr) 0 vs
+ midx = S.findIndexL (>= z) cumulative
+
+ idx = case midx of
+ Nothing -> error "psample: no index found"
+ Just x -> x
+
+ (p, val) = S.index vs idx
+ (l, r) = S.splitAt idx vs
+ deleted = l <> S.drop 1 r
+
+ go (mass - p) (val:acc) (pred j) deleted
+{-# INLINABLE psample #-}
+
+-- | (/O(n log n)/) 'psample' specialized to IO.
+psampleIO :: Foldable f => Int -> f (Double, a) -> IO (Maybe [a])
+psampleIO n weighted = withSystemRandom . asGenIO $ psample n weighted
+{-# INLINABLE psampleIO #-}
+
-- | (/O(n log n)/) Unequal probability resampling.
presample
:: (PrimMonad m, Foldable f)
@@ -90,15 +138,14 @@ presample n weighted gen
case F.fold (F.find ((>= z) . fst)) xs of
Just (_, val) -> go (val:acc) (pred s)
Nothing -> return acc
-
- sortProbs :: (Foldable f, Ord a) => f (a, b) -> [(a, b)]
- sortProbs = sortBy (compare `on` fst) . Foldable.toList
{-# INLINABLE presample #-}
-- | (/O(n log n)/) 'presample' specialized to IO.
presampleIO :: (Foldable f) => Int -> f (Double, a) -> IO [a]
-presampleIO n weighted = do
- gen <- createSystemRandom
- presample n weighted gen
+presampleIO n weighted = withSystemRandom . asGenIO $ presample n weighted
{-# INLINABLE presampleIO #-}
+sortProbs :: (Foldable f, Ord a) => f (a, b) -> [(a, b)]
+sortProbs = sortBy (flip compare `on` fst) . Foldable.toList
+{-# INLINABLE sortProbs #-}
+
diff --git a/sampling.cabal b/sampling.cabal
@@ -1,5 +1,5 @@
name: sampling
-version: 0.2.0
+version: 0.3.0
synopsis: Sample values from collections.
homepage: https://github.com/jtobin/sampling
license: MIT
@@ -18,6 +18,9 @@ description:
* 'sample', for sampling without replacement
.
* 'resample', for sampling with replacement (i.e., a bootstrap)
+ .
+ Each variation can be prefixed with 'p' to sample from a container of values
+ weighted by probability.
Source-repository head
Type: git
@@ -33,18 +36,18 @@ library
exposed-modules:
Numeric.Sampling
build-depends:
- base < 5
+ base > 4 && < 6
+ , containers >= 0.5 && < 1
, foldl >= 1.1 && < 2
, mwc-random >= 0.13 && < 0.14
- , primitive
+ , primitive >= 0.6 && < 1
, vector >= 0.11 && < 0.12
-executable sampling-test
- hs-source-dirs: src
- Main-is: Main.hs
- default-language: Haskell2010
- ghc-options:
- -Wall -O2
+Test-suite resample
+ type: exitcode-stdio-1.0
+ hs-source-dirs: test
+ Main-is: Main.hs
+ default-language: Haskell2010
build-depends:
base
, sampling
diff --git a/stack-travis.yaml b/stack-travis.yaml
@@ -2,7 +2,7 @@ flags: {}
packages:
- '.'
extra-deps: []
-resolver: lts-5.1
-compiler: ghc-7.10.3
+resolver: lts-7.11
+compiler: ghc-8.0.1
system-ghc: false
install-ghc: true
diff --git a/stack.yaml b/stack.yaml
@@ -1,4 +1,4 @@
-resolver: lts-5.1
+resolver: lts-7.11
packages: ['.']
extra-deps: []
flags: {}
diff --git a/test/Main.hs b/test/Main.hs
@@ -0,0 +1,10 @@
+{-# OPTIONS_GHC -fno-warn-type-defaults #-}
+
+module Main where
+
+import Numeric.Sampling
+
+main :: IO ()
+main = do
+ foo <- resampleIO 100 ([1..100000] :: [Int])
+ print foo