]> gitweb.michael.orlitzky.com - hath.git/blobdiff - src/Main.hs
Bump dns dependency to 1.*, and update DNS module.
[hath.git] / src / Main.hs
index 12511b61463b6d04a721434cbb7e7d8b09db15e8..5afadcdfbc80e80fb1b4a2213afacce6fc836809 100644 (file)
@@ -1,32 +1,39 @@
+module Main
+where
+
+import Control.Concurrent.ParallelIO.Global ( stopGlobalPool )
 import Control.Monad (unless, when)
 import Control.Monad (unless, when)
+import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
 import Data.List ((\\), intercalate)
 import Data.Maybe (catMaybes, isNothing)
 import Data.String.Utils (splitWs)
 import System.Exit (ExitCode(..), exitSuccess, exitWith)
 import System.IO (stderr, hPutStrLn)
 import Data.List ((\\), intercalate)
 import Data.Maybe (catMaybes, isNothing)
 import Data.String.Utils (splitWs)
 import System.Exit (ExitCode(..), exitSuccess, exitWith)
 import System.IO (stderr, hPutStrLn)
+import Text.Read (readMaybe)
+
+import Cidr (
+  Cidr(..),
+  combine_all,
+  enumerate,
+  max_octet1,
+  max_octet2,
+  max_octet3,
+  max_octet4,
+  min_octet1,
+  min_octet2,
+  min_octet3,
+  min_octet4 )
+import CommandLine (
+  help_set,
+  help_text,
+  input_function,
+  Mode(..),
+  parse_errors,
+  parse_mode )
+import DNS (Domain, PTRResult, lookup_ptrs)
+import ExitCodes ( exit_args_parse_failed, exit_invalid_cidr )
+import Octet ()
 
 
-import Cidr (Cidr(..),
-             cidr_from_string,
-             combine_all,
-             max_octet1,
-             max_octet2,
-             max_octet3,
-             max_octet4,
-             min_octet1,
-             min_octet2,
-             min_octet3,
-             min_octet4 )
-
-import CommandLine (help_set,
-                    help_text,
-                    input_function,
-                    Mode(..),
-                    parse_errors,
-                    parse_mode)
-
-import ExitCodes
-import Octet
-    
 
 -- | A regular expression that matches a non-address character.
 non_addr_char :: String
 
 -- | A regular expression that matches a non-address character.
 non_addr_char :: String
@@ -49,6 +56,7 @@ addr_barrier x = non_addr_char ++ x ++ non_addr_char
 --      max values.
 --   4. Join the regexes from step 3 with regexes matching periods.
 --   5. Stick an address boundary on either side of the result.
 --      max values.
 --   4. Join the regexes from step 3 with regexes matching periods.
 --   5. Stick an address boundary on either side of the result.
+--
 cidr_to_regex :: Cidr.Cidr -> String
 cidr_to_regex cidr =
     addr_barrier (intercalate "\\." [range1, range2, range3, range4])
 cidr_to_regex :: Cidr.Cidr -> String
 cidr_to_regex cidr =
     addr_barrier (intercalate "\\." [range1, range2, range3, range4])
@@ -57,14 +65,14 @@ cidr_to_regex cidr =
       range2 = numeric_range min2 max2
       range3 = numeric_range min3 max3
       range4 = numeric_range min4 max4
       range2 = numeric_range min2 max2
       range3 = numeric_range min3 max3
       range4 = numeric_range min4 max4
-      min1   = octet_to_int (min_octet1 cidr)
-      min2   = octet_to_int (min_octet2 cidr)
-      min3   = octet_to_int (min_octet3 cidr)
-      min4   = octet_to_int (min_octet4 cidr)
-      max1   = octet_to_int (max_octet1 cidr)
-      max2   = octet_to_int (max_octet2 cidr)
-      max3   = octet_to_int (max_octet3 cidr)
-      max4   = octet_to_int (max_octet4 cidr)
+      min1   = fromEnum (min_octet1 cidr)
+      min2   = fromEnum (min_octet2 cidr)
+      min3   = fromEnum (min_octet3 cidr)
+      min4   = fromEnum (min_octet4 cidr)
+      max1   = fromEnum (max_octet1 cidr)
+      max2   = fromEnum (max_octet2 cidr)
+      max3   = fromEnum (max_octet3 cidr)
+      max4   = fromEnum (max_octet4 cidr)
 
 
 
 
 
 
@@ -106,7 +114,7 @@ main = do
   input <- inputfunc
 
   let cidr_strings = splitWs input
   input <- inputfunc
 
   let cidr_strings = splitWs input
-  let cidrs = map cidr_from_string cidr_strings
+  let cidrs = map readMaybe cidr_strings
 
   when (any isNothing cidrs) $ do
     putStrLn "Error: not valid CIDR notation."
 
   when (any isNothing cidrs) $ do
     putStrLn "Error: not valid CIDR notation."
@@ -122,20 +130,41 @@ main = do
     Regex -> do
       let regexes = map cidr_to_regex valid_cidrs
       putStrLn $ alternate regexes
     Regex -> do
       let regexes = map cidr_to_regex valid_cidrs
       putStrLn $ alternate regexes
-    Reduce -> do
-      _ <- mapM print (combine_all valid_cidrs)
-      return ()
-    Dupe -> do
-       _ <- mapM print dupes
-       return ()
+    Reduce ->
+      mapM_ print (combine_all valid_cidrs)
+    Dupe ->
+       mapM_ print dupes
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
     Diff -> do
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
     Diff -> do
-       _ <- mapM putStrLn deletions
-       _ <- mapM putStrLn additions
-       return ()
+       mapM_ putStrLn deletions
+       mapM_ putStrLn additions
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
          deletions = map (\s -> '-' : (show s)) dupes
          newcidrs = (combine_all valid_cidrs) \\ valid_cidrs
          additions = map (\s -> '+' : (show s)) newcidrs
        where
          dupes = valid_cidrs \\ (combine_all valid_cidrs)
          deletions = map (\s -> '-' : (show s)) dupes
          newcidrs = (combine_all valid_cidrs) \\ valid_cidrs
          additions = map (\s -> '+' : (show s)) newcidrs
+    List -> do
+      let combined_cidrs = combine_all valid_cidrs
+      let addrs = concatMap enumerate combined_cidrs
+      mapM_ print addrs
+    Reverse -> do
+      let combined_cidrs = combine_all valid_cidrs
+      let addrs = concatMap enumerate combined_cidrs
+      let addr_bytestrings = map (BS.pack . show) addrs
+      ptrs <- lookup_ptrs addr_bytestrings
+      let pairs = zip addr_bytestrings ptrs
+      mapM_ (putStrLn . show_pair) pairs
+
+  stopGlobalPool
+
+  where
+    show_pair :: (Domain, PTRResult) -> String
+    show_pair (s, eds) =
+      (BS.unpack s) ++ ": " ++ results
+      where
+        space = BS.pack " "
+        results =
+          case eds of
+            Left err -> "ERROR (" ++ (show err) ++ ")"
+            Right ds -> BS.unpack $ BS.intercalate space ds