Introduce a NormalDomain newtype to ensure comparisons are made safely.
authorMichael Orlitzky <michael@orlitzky.com>
Fri, 28 Nov 2014 16:02:11 +0000 (11:02 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Fri, 28 Nov 2014 16:02:11 +0000 (11:02 -0500)
doc/TODO [deleted file]
src/DNS.hs
src/Forward.hs
src/Report.hs

diff --git a/doc/TODO b/doc/TODO
deleted file mode 100644 (file)
index 0e57166..0000000
--- a/doc/TODO
+++ /dev/null
@@ -1,5 +0,0 @@
-1. Use the type system to prevent all the bugs that arise from
-  (de)normalization of domain names. We should use newtype wrappers
-  around string domain names that we can only create by using a
-  normalize_foo constructor. That way we can make sure that we never
-  do a comparison on a denormalized name.
index 8d94de51e9c7772d33adcb6cf5ec3e40c9479105..556c68675eecb4e4a6293117143e5358255f28bc 100644 (file)
@@ -1,7 +1,8 @@
 module DNS (
   MxSetMap,
+  NormalDomain,
   mx_set_map,
-  normalize_string_domain )
+  normalize_string )
 where
 
 import qualified Data.ByteString.Char8 as BS ( pack, unpack )
@@ -18,13 +19,31 @@ import Network.DNS (
   normalize,
   withResolver )
 
--- | A map from domain names (represented as 'String's) to sets of
---   mail exchanger names (also represented as 'String's).
+-- | A type-safe wrapper around a domain name (represented as a
+--   string) that ensures we've created it by calling
+--   'normalize_string'. This prevents us from making
+--   comparisons on un-normalized 'Domain's or 'String's.
+--
+newtype NormalDomain =
+  NormalDomain String
+  deriving ( Eq, Ord, Show )
+
+
+-- | A set of mail exchanger names, represented as 'String's. The use
+--   of 'NormalDomain' prevents us from constructing a set of names
+--   that aren't normalized first.
 --
-type MxSetMap = Map String MxSet
+type MxSet = Set NormalDomain
 
--- | A set of mail exchanger names, represented as 'String's.
-type MxSet = Set String
+
+-- | A map from domain names (represented as 'String's) to sets of
+--   mail exchanger names (also represented as 'String's). The use of
+--   'NormalDomain' in the key prevents us from using keys that aren't
+--   normalized; this is important because we'll be using them for
+--   lookups and want e.g. \"foo.com\" and \"FOO.com\" to look up the
+--   same MX records.
+--
+type MxSetMap = Map NormalDomain MxSet
 
 
 -- | Normalize a domain name string by converting to a 'Domain',
@@ -32,11 +51,11 @@ type MxSet = Set String
 --
 --   ==== __Examples__
 --
---   >>> normalize_string_domain "ExAMplE.com"
---   "example.com."
+--   >>> normalize_string "ExAMplE.com"
+--   NormalDomain "example.com."
 --
-normalize_string_domain :: String -> String
-normalize_string_domain = BS.unpack . normalize . BS.pack
+normalize_string :: String -> NormalDomain
+normalize_string = NormalDomain . BS.unpack . normalize . BS.pack
 
 
 -- | Retrieve all MX records for the given domain. This is somewhat
@@ -81,16 +100,20 @@ mx_set_map domains = do
     --   domain name in the first component and a set of its mail
     --   exchangers (as strings) in the second component.
     --
-    make_pair :: Domain -> IO (String, Set String)
+    make_pair :: Domain -> IO (NormalDomain, Set NormalDomain)
     make_pair domain = do
       -- Lookup the @domain@'s MX records.
       mx_list <- lookup_mxs domain
 
-      -- Now convert the MX records *back* to strings.
-      let string_mx_list = map BS.unpack mx_list
+      -- Now convert the MX records *back* to strings, and then to
+      -- NormalDomains
+      let normal_mx_list = map (normalize_string . BS.unpack) mx_list
+
+      -- Convert the list into a set...
+      let normal_mx_set = Set.fromList normal_mx_list
 
-      -- Convert the list into a set
-      let string_mx_set = Set.fromList string_mx_list
+      -- The lookup key.
+      let normal_domain = normalize_string $ BS.unpack domain
 
       -- Finally, construct the pair and return it.
-      return (BS.unpack domain, string_mx_set)
+      return (normal_domain, normal_mx_set)
index 8d977fac299e46017850efdcc625fa99a9001ca3..d4936eda5c2169fbbcdc3130689125cb111195cd 100644 (file)
@@ -11,7 +11,7 @@ where
 
 import Data.String.Utils ( split, strip )
 
-import DNS ( normalize_string_domain )
+import DNS ( NormalDomain, normalize_string )
 
 -- | Type synonym to make the signatures below a little more clear.
 --   WARNING: Also defined in the "Report" module.
@@ -220,27 +220,23 @@ domain_part address =
     parts = split "@" address
 
 
--- | Given a list of 'Domain's @domains@ and a list of 'Forward's
+-- | Given a list of 'NormalDomain's @domains@ and a list of 'Forward's
 --   @forwards@, filter out all elements of @forwards@ that have a
 --   goto domain in the list of @domains@.
 --
 --   ==== __Examples__
 --
---   >>> let ds = ["example.com", "example.net"]
+--   >>> let ds = map normalize_string ["example.com", "example.net"]
 --   >>> let f1 = fwd "a@example.com" "a@example.com"
 --   >>> let f2 = fwd "a@example.com" "a1@example.net"
 --   >>> let f3 = fwd "a@example.com" "a2@example.org"
 --   >>> map pretty_print (dropby_goto_domains ds [f1,f2,f3])
 --   ["a@example.com -> a2@example.org"]
 --
-dropby_goto_domains :: [Domain] -> [Forward] -> [Forward]
-dropby_goto_domains domains =
+dropby_goto_domains :: [NormalDomain] -> [Forward] -> [Forward]
+dropby_goto_domains normal_domains =
   filter (not . is_bad)
   where
-    -- If we don't normalize these first, comparison (i.e. `elem`)
-    -- doesn't work so great.
-    normalized_domains = map normalize_string_domain domains
-
     -- | A 'Forward' is bad if its goto domain appears in the list, or
     --   if we can't figure out its goto domain.
     --
@@ -248,4 +244,5 @@ dropby_goto_domains domains =
     is_bad f =
       case (goto_domain f) of
         Nothing -> True -- Drop these, too.
-        Just d  -> (normalize_string_domain d) `elem` normalized_domains
+        -- Nice, we can't compare unless we normalize @d@!
+        Just d  -> (normalize_string d) `elem` normal_domains
index efeaf65c3fb336d846bbaa2a44f8c9f1401d4e83..82dd07bc8750944fc9e2f8c526005d4f3e9f69a3 100644 (file)
@@ -3,7 +3,6 @@ module Report (
   report_tests )
 where
 
-import Data.Map ( mapKeys )
 import qualified Data.Map as Map ( fromList, lookup )
 import Data.Maybe ( catMaybes, listToMaybe )
 import Data.Set ( isSubsetOf )
@@ -21,7 +20,11 @@ import Test.Tasty ( TestTree, testGroup )
 import Test.Tasty.HUnit ( (@?=), testCase )
 
 import Configuration ( Configuration(..) )
-import DNS ( MxSetMap, mx_set_map, normalize_string_domain )
+import DNS (
+  MxSetMap,
+  NormalDomain,
+  mx_set_map,
+  normalize_string )
 import Forward (
   Forward(..),
   address_domain,
@@ -124,50 +127,41 @@ get_forward_list conn query = do
 --
 --   >>> import Forward ( fwd )
 --   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
---   >>> let mx_set = Set.fromList ["mx.example.com"]
---   >>> let example_mx_pairs = [("example.com.", mx_set)]
+--   >>> let mx_set = Set.fromList [normalize_string "mx.example.com"]
+--   >>> let example_mx_pairs = [(normalize_string "example.com.", mx_set)]
 --   >>> let mx_map = Map.fromList example_mx_pairs
---   >>> let droplist = MxList ["mx.example.com", "mx2.example.com"]
---   >>> dropby_mxlist droplist mx_map fwds
+--   >>> let droplist = ["mx.example.com", "mx2.example.com"]
+--   >>> let normal_droplist = map normalize_string droplist
+--   >>> dropby_mxlist normal_droplist mx_map fwds
 --   []
 --
 --   This time it shouldn't be dropped, because ["mx.example.com"] is
 --   not contained in ["nope.example.com"]:
 --
 --   >>> let fwds = [fwd "user1@example.com" "user2@example.net"]
---   >>> let mx_set = Set.fromList ["mx.example.com"]
---   >>> let example_mx_pairs = [("example.com.", mx_set)]
+--   >>> let mx_set = Set.fromList [normalize_string "mx.example.com"]
+--   >>> let example_mx_pairs = [(normalize_string "example.com.", mx_set)]
 --   >>> let mx_map = Map.fromList example_mx_pairs
---   >>> let droplist = MxList ["nope.example.com"]
---   >>> map pretty_print (dropby_mxlist droplist mx_map fwds)
+--   >>> let droplist = ["nope.example.com"]
+--   >>> let normal_droplist = map normalize_string droplist
+--   >>> map pretty_print (dropby_mxlist normal_droplist mx_map fwds)
 --   ["user1@example.com -> user2@example.net"]
 --
-dropby_mxlist :: MxList -> MxSetMap -> [Forward] -> [Forward]
-dropby_mxlist (MxList []) _ = id
-dropby_mxlist (MxList mxs) domain_mx_map =
+dropby_mxlist :: [NormalDomain] -> MxSetMap -> [Forward] -> [Forward]
+dropby_mxlist [] _ = id
+dropby_mxlist normal_mxs mx_map =
   filter (not . is_bad)
   where
-    -- If we don't normalize these first, comparison (isSubsetOf)
-    -- doesn't work so great.
-    mx_set = Set.fromList (map normalize_string_domain mxs)
-
-    -- We perform a lookup using a normalized key, so we'd better
-    -- normalize the keys in the map first!
-    normal_mxmap = mapKeys normalize_string_domain domain_mx_map
+    mx_set = Set.fromList normal_mxs
 
     is_bad :: Forward -> Bool
     is_bad f =
       case (address_domain f) of
         Nothing -> False -- Do **NOT** drop these.
-        Just d  -> case (Map.lookup (normalize_string_domain d) normal_mxmap) of
+        Just d  -> case (Map.lookup (normalize_string d) mx_map) of
                      Nothing -> False -- No domain MX? Don't drop.
+                     Just dmxs -> dmxs `isSubsetOf` mx_set
 
-                     -- We need to normalize the set of MXes for the
-                     -- domain, too.
-                     Just dmxs ->
-                       let ndmxs = (Set.map normalize_string_domain dmxs)
-                       in
-                         ndmxs `isSubsetOf` mx_set
 
 
 -- | Given a connection and a 'Configuration', produces the report as
@@ -186,14 +180,17 @@ report cfg conn = do
   -- Don't ask why, but this doesn't work if you factor out the
   -- "return" below.
   --
-  let exclude_mx_list = exclude_mx cfg
-  valid_forwards <- if null (get_mxs exclude_mx_list)
+  let exclude_mx_list = map normalize_string (get_mxs $ exclude_mx cfg)
+  valid_forwards <- if (null exclude_mx_list)
                     then return forwards
                     else do
                       domain_mxs <- mx_set_map domains
                       return $ dropby_mxlist exclude_mx_list domain_mxs forwards
 
-  let remote_forwards = dropby_goto_domains domains valid_forwards
+  -- We need to normalize our domain names before we can pass them to
+  -- dropby_goto_domains.
+  let normal_domains = map normalize_string domains
+  let remote_forwards = dropby_goto_domains normal_domains valid_forwards
   let forward_strings = map pretty_print remote_forwards
 
   -- Don't append the final newline if there's nothing to report.
@@ -237,11 +234,12 @@ test_dropby_mxlist_affects_address :: TestTree
 test_dropby_mxlist_affects_address =
   testCase desc $ do
     let fwds = [fwd "user1@example.com" "user2@example.net"]
-    let mx_set = Set.fromList ["mx.example.net"]
-    let example_mx_pairs = [("example.net.", mx_set)]
+    let mx_set = Set.fromList [normalize_string "mx.example.net"]
+    let example_mx_pairs = [(normalize_string "example.net.", mx_set)]
     let mx_map = Map.fromList example_mx_pairs
-    let droplist = MxList ["mx.example.net", "mx2.example.net"]
-    let actual = dropby_mxlist droplist mx_map fwds
+    let droplist = ["mx.example.net", "mx2.example.net"]
+    let normal_droplist = map normalize_string droplist
+    let actual = dropby_mxlist normal_droplist mx_map fwds
     let expected = fwds
     actual @?= expected
   where
@@ -255,11 +253,12 @@ test_dropby_mxlist_compares_normalized :: TestTree
 test_dropby_mxlist_compares_normalized =
   testCase desc $ do
     let fwds = [fwd "user1@exAmPle.com." "user2@examPle.net"]
-    let mx_set = Set.fromList ["mx.EXAMPLE.com"]
-    let example_mx_pairs = [("Example.com", mx_set)]
+    let mx_set = Set.fromList [normalize_string "mx.EXAMPLE.com"]
+    let example_mx_pairs = [(normalize_string "Example.com", mx_set)]
     let mx_map = Map.fromList example_mx_pairs
-    let droplist = MxList ["mx.EXAMple.com", "mx2.example.COM"]
-    let actual = dropby_mxlist droplist mx_map fwds
+    let droplist = ["mx.EXAMple.com", "mx2.example.COM"]
+    let normal_droplist = map normalize_string droplist
+    let actual = dropby_mxlist normal_droplist mx_map fwds
     let expected = []
     actual @?= expected
   where
@@ -275,10 +274,12 @@ test_dropby_mxlist_requires_subset =
   testCase desc $ do
     let fwds = [fwd "user1@example.com" "user2@example.net"]
     let mx_set = Set.fromList ["mx1.example.com", "mx2.example.com"]
-    let example_mx_pairs = [("example.com.", mx_set)]
+    let normal_mx_set = Set.map normalize_string mx_set
+    let example_mx_pairs = [(normalize_string "example.com.", normal_mx_set)]
     let mx_map = Map.fromList example_mx_pairs
-    let droplist = MxList ["mx1.example.com"]
-    let actual = dropby_mxlist droplist mx_map fwds
+    let droplist = ["mx1.example.com"]
+    let normal_droplist = map normalize_string droplist
+    let actual = dropby_mxlist normal_droplist mx_map fwds
     let expected = fwds
     actual @?= expected
   where