]> gitweb.michael.orlitzky.com - spline3.git/blobdiff - src/Misc.hs
Define an export list in Misc, clean up imports.
[spline3.git] / src / Misc.hs
index 16b0ead151ad1b29190028db01aa98ef34d6b797..09773bf56655b13453ed7dff6e49b160bf70eb11 100644 (file)
@@ -1,14 +1,24 @@
+{-# LANGUAGE BangPatterns #-}
+
 -- | The Misc module contains helper functions that seem out of place
 --   anywhere else.
-module Misc
+--
+module Misc (
+  all_equal,
+  disjoint,
+  factorial,
+  flatten,
+  misc_properties,
+  misc_tests,
+  transpose_xz )
 where
 
-import Data.List (intersect)
-import Test.Framework (Test, testGroup)
-import Test.Framework.Providers.HUnit (testCase)
-import Test.Framework.Providers.QuickCheck2 (testProperty)
-import Test.HUnit
-import Test.QuickCheck
+import qualified Data.Vector as V ( Vector, elem, empty, filter )
+import Test.Framework ( Test, testGroup )
+import Test.Framework.Providers.HUnit ( testCase )
+import Test.Framework.Providers.QuickCheck2 ( testProperty )
+import Test.HUnit ( Assertion, assertEqual )
+import Test.QuickCheck ( Property, (==>) )
 
 
 -- | The standard factorial function. See
@@ -24,11 +34,12 @@ import Test.QuickCheck
 --   24
 --
 factorial :: Int -> Int
-factorial n
-    | n <= 1 = 1
-    | n > 20 = error "integer overflow in factorial function"
-    | otherwise = product [1..n]
-
+factorial !n =
+  go 1 n
+  where
+    go !acc !i
+      | i <= 1    = acc
+      | otherwise = go (acc * i) (i - 1)
 
 -- | Takes a three-dimensional list, and flattens it into a
 --   one-dimensional one.
@@ -58,30 +69,38 @@ transpose_xz m =
 
 -- | Takes a list, and returns True if its elements are pairwise
 --   equal. Returns False otherwise.
+--
+--   Only used in tests.
+--
 all_equal :: (Eq a) => [a] -> Bool
-all_equal xs =
-    all (== first_element) other_elements
-    where
-      first_element  = head xs
-      other_elements = tail xs
+all_equal [] = True -- Vacuously
+all_equal (x:xs) = all (== x) xs
+
 
 
--- | Returns 'True' if the lists xs and ys are disjoint, 'False'
+-- | Returns 'True' if the vectors xs and ys are disjoint, 'False'
 --   otherwise.
 --
 --   Examples:
 --
---   >>> disjoint [1,2,3] [4,5,6]
+--   >>> let xs = Data.Vector.fromList [1,2,3]
+--   >>> let ys = Data.Vector.fromList [4,5,6]
+--   >>> disjoint xs ys
 --   True
 --
---   >>>  disjoint [1,2,3] [3,4,5]
+--   >>> let ys = Data.Vector.fromList [3,4,5]
+--   >>> disjoint xs ys
 --   False
 --
-disjoint :: (Eq a) => [a] -> [a] -> Bool
+--   Only used in tests.
+--
+disjoint :: (Eq a) => V.Vector a -> V.Vector a -> Bool
 disjoint xs ys =
-    intersect xs ys == []
-
-
+  intersect xs ys == V.empty
+  where
+    intersect :: (Eq a) => V.Vector a -> V.Vector a -> V.Vector a
+    intersect ws zs =
+      V.filter (`V.elem` zs) ws
 
 prop_factorial_greater :: Int -> Property
 prop_factorial_greater n =