]> gitweb.michael.orlitzky.com - hath.git/blobdiff - src/Main.hs
Rewrite command-line parsing to use cmdargs.
[hath.git] / src / Main.hs
index dd7eefe963e5aedb5c0b19d486f1cc510d470fac..45705be673f087d76d8e64ea61f16c086ac98632 100644 (file)
@@ -1,32 +1,33 @@
+module Main
+where
+
+import Control.Concurrent.ParallelIO.Global ( stopGlobalPool )
 import Control.Monad (when)
 import Control.Monad (when)
-import Data.List ((\\), intercalate, intersperse)
+import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
+import Data.List ((\\), intercalate)
 import Data.Maybe (catMaybes, isNothing)
 import Data.String.Utils (splitWs)
 import System.Exit (ExitCode(..), exitWith)
 import System.IO (stderr, hPutStrLn)
 import Data.Maybe (catMaybes, isNothing)
 import Data.String.Utils (splitWs)
 import System.Exit (ExitCode(..), exitWith)
 import System.IO (stderr, hPutStrLn)
+import Text.Read (readMaybe)
+
+import Cidr (
+  Cidr(..),
+  combine_all,
+  enumerate,
+  max_octet1,
+  max_octet2,
+  max_octet3,
+  max_octet4,
+  min_octet1,
+  min_octet2,
+  min_octet3,
+  min_octet4 )
+import CommandLine (Args(..), get_args)
+import DNS (Domain, PTRResult, lookup_ptrs)
+import ExitCodes ( exit_invalid_cidr )
+import Octet ()
 
 
-import Cidr (Cidr(..),
-             cidr_from_string,
-             combine_all,
-             max_octet1,
-             max_octet2,
-             max_octet3,
-             max_octet4,
-             min_octet1,
-             min_octet2,
-             min_octet3,
-             min_octet4 )
-
-import CommandLine (help_set,
-                    help_text,
-                    input_function,
-                    Mode(..),
-                    parse_errors,
-                    parse_mode)
-
-import ExitCodes
-import Octet
-    
 
 -- | A regular expression that matches a non-address character.
 non_addr_char :: String
 
 -- | A regular expression that matches a non-address character.
 non_addr_char :: String
@@ -36,8 +37,8 @@ non_addr_char = "[^\\.0-9]"
 -- | Add non_addr_chars on either side of the given String. This
 --   prevents (for example) the regex '127.0.0.1' from matching
 --   '127.0.0.100'.
 -- | Add non_addr_chars on either side of the given String. This
 --   prevents (for example) the regex '127.0.0.1' from matching
 --   '127.0.0.100'.
-addr_barrier :: String -> String
-addr_barrier x = non_addr_char ++ x ++ non_addr_char
+add_barriers :: String -> String
+add_barriers x = non_addr_char ++ x ++ non_addr_char
 
 
 -- | The magic happens here. We take a CIDR String as an argument, and
 
 
 -- | The magic happens here. We take a CIDR String as an argument, and
@@ -49,29 +50,31 @@ addr_barrier x = non_addr_char ++ x ++ non_addr_char
 --      max values.
 --   4. Join the regexes from step 3 with regexes matching periods.
 --   5. Stick an address boundary on either side of the result.
 --      max values.
 --   4. Join the regexes from step 3 with regexes matching periods.
 --   5. Stick an address boundary on either side of the result.
-cidr_to_regex :: Cidr.Cidr -> String
-cidr_to_regex cidr =
-    addr_barrier (intercalate "\\." [range1, range2, range3, range4])
+--
+cidr_to_regex :: Bool -> Cidr.Cidr -> String
+cidr_to_regex use_barriers cidr =
+    let f = if use_barriers then add_barriers else id in
+      f (intercalate "\\." [range1, range2, range3, range4])
     where
       range1 = numeric_range min1 max1
       range2 = numeric_range min2 max2
       range3 = numeric_range min3 max3
       range4 = numeric_range min4 max4
     where
       range1 = numeric_range min1 max1
       range2 = numeric_range min2 max2
       range3 = numeric_range min3 max3
       range4 = numeric_range min4 max4
-      min1   = octet_to_int (min_octet1 cidr)
-      min2   = octet_to_int (min_octet2 cidr)
-      min3   = octet_to_int (min_octet3 cidr)
-      min4   = octet_to_int (min_octet4 cidr)
-      max1   = octet_to_int (max_octet1 cidr)
-      max2   = octet_to_int (max_octet2 cidr)
-      max3   = octet_to_int (max_octet3 cidr)
-      max4   = octet_to_int (max_octet4 cidr)
+      min1   = fromEnum (min_octet1 cidr)
+      min2   = fromEnum (min_octet2 cidr)
+      min3   = fromEnum (min_octet3 cidr)
+      min4   = fromEnum (min_octet4 cidr)
+      max1   = fromEnum (max_octet1 cidr)
+      max2   = fromEnum (max_octet2 cidr)
+      max3   = fromEnum (max_octet3 cidr)
+      max4   = fromEnum (max_octet4 cidr)
 
 
 
 -- | Take a list of Strings, and return a regular expression matching
 --   any of them.
 alternate :: [String] -> String
 
 
 
 -- | Take a list of Strings, and return a regular expression matching
 --   any of them.
 alternate :: [String] -> String
-alternate terms = "(" ++ (concat (intersperse "|" terms)) ++ ")"
+alternate terms = "(" ++ (intercalate "|" terms) ++ ")"
 
 
 -- | Take two Ints as parameters, and return a regex matching any
 
 
 -- | Take two Ints as parameters, and return a regex matching any
@@ -86,56 +89,68 @@ numeric_range x y =
 
 main :: IO ()
 main = do
 
 main :: IO ()
 main = do
-  -- First, check for any errors that occurred while parsing
-  -- the command line options.
-  errors <- CommandLine.parse_errors
-  when ((not . null) errors) $ do
-    hPutStrLn stderr (concat errors)
-    putStrLn CommandLine.help_text
-    exitWith (ExitFailure exit_args_parse_failed)
-
-  -- Next, check to see if the 'help' option was passed to the
-  -- program. If it was, display the help, and exit successfully.
-  help_opt_set <- CommandLine.help_set
-  when help_opt_set $ do
-    putStrLn CommandLine.help_text
-    exitWith ExitSuccess
-
-  -- The input function we receive here should know what to read.
-  inputfunc <- (CommandLine.input_function)
-  input <- inputfunc
+  args <- get_args
+
+  -- This reads stdin.
+  input <- getContents
 
   let cidr_strings = splitWs input
 
   let cidr_strings = splitWs input
-  let cidrs = map cidr_from_string cidr_strings
+  let cidrs = map readMaybe cidr_strings
 
   when (any isNothing cidrs) $ do
 
   when (any isNothing cidrs) $ do
-    putStrLn "Error: not valid CIDR notation."
+    hPutStrLn stderr "ERROR: not valid CIDR notation:"
+
+    -- Output the bad lines, safely.
+    let pairs = zip cidr_strings cidrs
+    let print_pair (x, Nothing) = hPutStrLn stderr ("  * " ++ x)
+        print_pair (_, _) = return ()
+
+    mapM_ print_pair pairs
     exitWith (ExitFailure exit_invalid_cidr)
 
   -- Filter out only the valid ones.
   let valid_cidrs = catMaybes cidrs
 
     exitWith (ExitFailure exit_invalid_cidr)
 
   -- Filter out only the valid ones.
   let valid_cidrs = catMaybes cidrs
 
-  -- Get the mode of operation.
-  mode <- CommandLine.parse_mode
-
-  case mode of
-    Regex -> do
-      let regexes = map cidr_to_regex valid_cidrs
+  case args of
+    Regexed{} -> do
+      let cidrs' = combine_all valid_cidrs
+      let regexes = map (cidr_to_regex (barriers args)) cidrs'
       putStrLn $ alternate regexes
       putStrLn $ alternate regexes
-    Reduce -> do
-      _ <- mapM (putStrLn . show) (combine_all valid_cidrs)
-      return ()
-    Dupe -> do
-       _ <- mapM (putStrLn . show) dupes
-       return ()
+    Reduced{} ->
+      mapM_ print (combine_all valid_cidrs)
+    Duped{} ->
+       mapM_ print dupes
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
-    Diff -> do
-       _ <- mapM putStrLn deletions
-       _ <- mapM putStrLn additions
-       return ()
+    Diffed{} -> do
+       mapM_ putStrLn deletions
+       mapM_ putStrLn additions
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
-         deletions = map (\s -> "-" ++ (show s)) dupes
+         deletions = map (\s -> '-' : (show s)) dupes
          newcidrs = (combine_all valid_cidrs) \\ valid_cidrs
          newcidrs = (combine_all valid_cidrs) \\ valid_cidrs
-         additions = map (\s -> "+" ++ (show s)) newcidrs
+         additions = map (\s -> '+' : (show s)) newcidrs
+    Listed{} -> do
+      let combined_cidrs = combine_all valid_cidrs
+      let addrs = concatMap enumerate combined_cidrs
+      mapM_ print addrs
+    Reversed{} -> do
+      let combined_cidrs = combine_all valid_cidrs
+      let addrs = concatMap enumerate combined_cidrs
+      let addr_bytestrings = map (BS.pack . show) addrs
+      ptrs <- lookup_ptrs addr_bytestrings
+      let pairs = zip addr_bytestrings ptrs
+      mapM_ (putStrLn . show_pair) pairs
+
+  stopGlobalPool
+
+  where
+    show_pair :: (Domain, PTRResult) -> String
+    show_pair (s, eds) =
+      (BS.unpack s) ++ ": " ++ results
+      where
+        space = BS.pack " "
+        results =
+          case eds of
+            Left err -> "ERROR (" ++ (show err) ++ ")"
+            Right ds -> BS.unpack $ BS.intercalate space ds