]> gitweb.michael.orlitzky.com - list-remote-forwards.git/blobdiff - src/Report.hs
Get things in shape finally:
[list-remote-forwards.git] / src / Report.hs
index b15da98b251b67d822e1ae655fcd9ca27563066b..5b3853350efc3a07c581ad20b354741c3d127cc5 100644 (file)
@@ -5,35 +5,44 @@ module Report (
   report_tests )
 where
 
-import Control.Monad ( filterM )
-import qualified Data.ByteString.Char8 as BS ( pack )
+import Data.Map ( mapKeys )
+import qualified Data.Map as Map ( lookup )
 import Data.Maybe ( catMaybes, listToMaybe )
-import Data.String.Utils ( join, split, strip )
+import Data.Set ( fromList, isSubsetOf )
+import qualified Data.Set as Set ( map )
+import Data.String.Utils ( join )
 import Database.HDBC (
   IConnection,
   execute,
   prepare,
   sFetchAllRows')
 import Database.HDBC.Sqlite3 ( connectSqlite3 )
-import Data.List ( (\\) )
 import System.Console.CmdArgs.Default ( Default(..) )
 import Test.Tasty ( TestTree, testGroup )
 import Test.Tasty.HUnit ( (@?=), testCase )
 
 import Configuration ( Configuration(..) )
-import DNS ( lookup_mxs, normalize )
+import DNS ( MxSetMap, mx_set_map, normalize_string_domain )
+import Forward (
+  Forward(..),
+  address_domain,
+  dropby_goto_domains,
+  pretty_print,
+  strings_to_forwards )
 import MxList ( MxList(..) )
 
--- Type synonyms to make the signatures below a little more clear.
+-- | Type synonym to make the signatures below a little more clear.
+--   WARNING: Also defined in the "Forward" module.
 type Domain = String
-type Address = String
-type Goto = String
 
-data Forward =
-  Forward Address Goto
-  deriving (Show)
-
-get_domain_list :: IConnection a => a -> String -> IO [Domain]
+-- | Given a connection @conn@ and a @query@, return a list of domains
+--   found by executing @query@ on @conn. The @query@ is assumed to
+--   return only one column, containing domains.
+--
+get_domain_list :: IConnection a
+                => a -- ^ A database connection
+                -> String -- ^ The @query@ to execute
+                -> IO [Domain] -- ^ The list of domains returned from @query@
 get_domain_list conn query = do
   stmt <- prepare conn query
 
@@ -54,7 +63,18 @@ get_domain_list conn query = do
   return domains
 
 
-get_forward_list :: IConnection a => a -> String -> IO [Forward]
+
+
+-- | Given a connection @conn@ and a @query@, return a list of
+--   forwards found by executing @query@ on @conn. The @query@ is
+--   assumed to return two columns, the first containing addresses and
+--   the second containing a comma-separated list of gotos (as a
+--   string).
+--
+get_forward_list :: IConnection a
+                 => a  -- ^ A database connection
+                 -> String -- ^ The @query@ to execute
+                 -> IO [Forward]  -- ^ A list of forwards returned from @query@
 get_forward_list conn query = do
   stmt <- prepare conn query
 
@@ -67,56 +87,133 @@ get_forward_list conn query = do
   rows <- sFetchAllRows' stmt
 
   -- forwards :: [Forward]
-  let forwards = concatMap (row_to_forwards . catMaybes) rows
+  let forwards = concatMap (strings_to_forwards . catMaybes) rows
 
   return forwards
-  where
-    row_to_forwards :: [String] -> [Forward]
-    row_to_forwards (addr:gotos:_) =
-      [Forward addr (strip g) | g <- split "," gotos]
-    row_to_forwards _ = []
 
 
 
-find_remote_forwards :: [Domain] -> [Forward] -> [Forward]
-find_remote_forwards domains forwards =
-  filter is_remote forwards
-  where
-    is_remote :: Forward -> Bool
-    is_remote (Forward _ goto) =
-      let parts = split "@" goto
-      in
-        case parts of
-          (_:dp:[]) -> not $ dp `elem` domains
-          _        -> True -- Assume it's remote if something is wrong
-
-
-format_forward :: Forward -> String
-format_forward (Forward addr goto) =
-  addr ++ " -> " ++ goto
-
-
--- If the MX records for a domain are exactly those contained in the
--- MxList, then we exclude that domain from the report. Splitting on
--- the '@' is a lazy way of obtaining the domain, but if it's good
--- enough for determining that a forward is remote, then it's good
--- enough for this.
-filter_by_mx :: MxList -> [Forward] -> IO [Forward]
--- This special case is necessary! Otherwise if we have an empty
--- exclude list and a domain that has no MX record, it will be
--- excluded.
-filter_by_mx (MxList [])  = return
-filter_by_mx (MxList mxs) =
-  filterM all_mxs_excluded
+-- | A filter function to remove specific 'Forward's from a list (of
+--   forwards).  Its intended usage is to ignore a 'Forward' if its
+--   'Address' has MX records that are all contained in the given
+--   list. This could be useful if, for example, one MX has strict
+--   spam filtering and remote forwards are not a problem for domains
+--   with that MX.
+--
+--   If the MX records for a domain are contained in the 'MxList',
+--   then we exclude that domain from the report.
+--
+--   For performance reasons, we want to have precomputed the MX
+--   records for all of the address domains in our list of
+--   forwards. We do this so we don't look up the MX records twice for
+--   two addresses within the same domain. We could just as well do
+--   this within this function, but by taking the @domain_mxs@ as a
+--   parameter, we allow ourselves to be a pure function.
+--
+--   If the domain of a forward address can't be determined, it won't
+--   be dropped! This is intentional: the existence of a forward
+--   address without a domain part probably indicates a configuration
+--   error somewhere, and we should report it.
+--
+--   The empty @MxList []@ special case is necessary! Otherwise if we
+--   have an empty exclude list and a domain that has no MX record, it
+--   will be excluded.
+--
+--   ==== __Examples__
+--
+--   Our single forward should be dropped from the list, because its
+--   MX record list, ["mx.example.com"], is contained in the list of
+--   excluded MXs:
+--
+--   >>> import qualified Data.Map as Map ( fromList )
+--   >>> import qualified Data.Set as Set ( fromList )
+--   >>> import Forward ( fwd )
+--   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
+--   >>> let mx_set = Set.fromList ["mx.example.com"]
+--   >>> let example_mx_pairs = [("example.com.", mx_set)]
+--   >>> let mx_map = Map.fromList example_mx_pairs
+--   >>> let droplist = MxList ["mx.example.com", "mx2.example.com"]
+--   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   []
+--
+--   Repeat the previous test with the goto domain, to make sure we're
+--   dropping based on the address and not the goto:
+--
+--   >>> import qualified Data.Map as Map ( fromList )
+--   >>> import qualified Data.Set as Set ( fromList )
+--   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
+--   >>> let mx_set = Set.fromList ["mx.example.net"]
+--   >>> let example_mx_pairs = [("example.net.", mx_set)]
+--   >>> let mx_map = Map.fromList example_mx_pairs
+--   >>> let droplist = MxList ["mx.example.net", "mx2.example.net"]
+--   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   ["user1@example.com -> user2@example.net"]
+--
+--   Use weird caps, and optional trailing dot all over the place to
+--   make sure everything is handled normalized:
+--
+--   >>> import qualified Data.Set as Set ( fromList )
+--   >>> import Forward ( fwd )
+--   >>> let fwds = [fwd "user1@exAmPle.com." "user2@examPle.net"]
+--   >>> let mx_set = Set.fromList ["mx.EXAMPLE.com"]
+--   >>> let example_mx_pairs = [("Example.com", mx_set)]
+--   >>> let mx_map = Map.fromList example_mx_pairs
+--   >>> let droplist = MxList ["mx.EXAMple.com", "mx2.example.COM"]
+--   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   []
+--
+--   This time it shouldn't be dropped, because ["mx.example.com"] is
+--   not contained in ["nope.example.com"]:
+--
+--   >>> import qualified Data.Map as Map ( fromList )
+--   >>> import qualified Data.Set as Set ( fromList )
+--   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
+--   >>> let mx_set = Set.fromList ["mx.example.com"]
+--   >>> let example_mx_pairs = [("example.com.", mx_set)]
+--   >>> let mx_map = Map.fromList example_mx_pairs
+--   >>> let droplist = MxList ["nope.example.com"]
+--   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   ["user1@example.com -> user2@example.net"]
+--
+--   Now we check that if a forward has two MXes, one of which appears
+--   in the list of excluded MXes, it doesn't get dropped:
+--
+--   >>> import qualified Data.Map as Map ( fromList )
+--   >>> import qualified Data.Set as Set ( fromList )
+--   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
+--   >>> let mx_set = Set.fromList ["mx1.example.com", "mx2.example.com"]
+--   >>> let example_mx_pairs = [("example.com.", mx_set)]
+--   >>> let mx_map = Map.fromList example_mx_pairs
+--   >>> let droplist = MxList ["mx1.example.com"]
+--   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   ["user1@example.com -> user2@example.net"]
+--
+dropby_mxlist :: MxList -> MxSetMap -> [Forward] -> [Forward]
+dropby_mxlist (MxList []) _ = id
+dropby_mxlist (MxList mxs) domain_mx_map =
+  filter (not . is_bad)
   where
-    all_mxs_excluded :: Forward -> IO Bool
-    all_mxs_excluded (Forward addr _) =
-      case (split "@" addr) of
-        (_:domain_part:[]) -> do
-          fw_mxs <- lookup_mxs (BS.pack domain_part)
-          let norm_mxs = map (normalize . BS.pack) mxs
-          if (norm_mxs \\ fw_mxs) == [] then return False else return True
-        _ -> return True -- Report it if we can't figure out the domain.
+    -- If we don't normalize these first, comparison (isSubsetOf)
+    -- doesn't work so great.
+    mx_set = fromList (map normalize_string_domain mxs)
+
+    -- We perform a lookup using a normalized key, so we'd better
+    -- normalize the keys in the map first!
+    normal_mxmap = mapKeys normalize_string_domain domain_mx_map
+
+    is_bad :: Forward -> Bool
+    is_bad f =
+      case (address_domain f) of
+        Nothing -> False -- Do **NOT** drop these.
+        Just d  -> case (Map.lookup (normalize_string_domain d) normal_mxmap) of
+                     Nothing -> False -- No domain MX? Don't drop.
+
+                     -- We need to normalize the set of MXes for the
+                     -- domain, too.
+                     Just dmxs ->
+                       let ndmxs = (Set.map normalize_string_domain dmxs)
+                       in
+                         ndmxs `isSubsetOf` mx_set
 
 
 -- | Given a connection and a 'Configuration', produces the report as
@@ -127,11 +224,28 @@ report cfg conn = do
   domains <- get_domain_list conn (domain_query cfg)
   forwards <- get_forward_list conn (forward_query cfg)
 
-  valid_forwards <- filter_by_mx (exclude_mx cfg) forwards
-  let remote_forwards = find_remote_forwards domains valid_forwards
-  let forward_strings = map format_forward remote_forwards
+  -- valid_forwards are those not excluded based on their address's MXes.
+  --
+  -- WARNING: Don't do MX lookups if the exclude list is empty! It
+  -- wastes a ton of time!
+  --
+  -- Don't ask why, but this doesn't work if you factor out the
+  -- "return" below.
+  --
+  let exclude_mx_list = exclude_mx cfg
+  valid_forwards <- if null (get_mxs exclude_mx_list)
+                    then return forwards
+                    else do
+                      domain_mxs <- mx_set_map domains
+                      return $ dropby_mxlist exclude_mx_list domain_mxs forwards
+
+  let remote_forwards = dropby_goto_domains domains valid_forwards
+  let forward_strings = map pretty_print remote_forwards
 
-  return $ (join "\n" forward_strings) ++ "\n"
+  -- Don't append the final newline if there's nothing to report.
+  return $ if (null forward_strings)
+           then ""
+           else (join "\n" forward_strings) ++ "\n"
 
 
 
@@ -154,4 +268,5 @@ test_example1 =
     expected = "user1@example.com -> user1@example.net\n" ++
                "user2@example.com -> user1@example.org\n" ++
                "user2@example.com -> user2@example.org\n" ++
-               "user2@example.com -> user3@example.org\n"
+               "user2@example.com -> user3@example.org\n" ++
+               "user7@example.com -> user8@example.net\n"