module Main where

-- NOTE(review): 'parallel' is no longer used by the code below (Reverse
-- mode now prints sequentially); drop it if nothing else in this module
-- needs it.
import Control.Concurrent.ParallelIO.Global ( parallel, stopGlobalPool )
import Control.Monad (unless, when)
import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
import Data.List ((\\), intercalate)
import Data.Maybe (catMaybes, isNothing)
import Data.String.Utils (splitWs)
import System.Exit (ExitCode(..), exitSuccess, exitWith)
import System.IO (stderr, hPutStrLn)
import Text.Read (readMaybe)

import Cidr
  ( Cidr(..)
  , combine_all
  , enumerate
  , max_octet1
  , max_octet2
  , max_octet3
  , max_octet4
  , min_octet1
  , min_octet2
  , min_octet3
  , min_octet4
  )
import CommandLine
  ( help_set
  , help_text
  , input_function
  , Mode(..)
  , parse_errors
  , parse_mode
  )
import DNS (Domain, lookup_ptrs)
import ExitCodes ( exit_args_parse_failed, exit_invalid_cidr )
import Octet ()


-- | A regular expression that matches a single non-address character:
--   anything other than a period or a decimal digit.
non_addr_char :: String
non_addr_char = "[^\\.0-9]"


-- | Add non_addr_chars on either side of the given String. This
--   prevents (for example) the regex '127.0.0.1' from matching
--   '127.0.0.100'.
--
--   NOTE(review): each barrier must consume one character, so the
--   resulting regex cannot match an address at the very start or end
--   of the text being searched (e.g. a line that is only an address,
--   when matching line by line). Something like @(^|[^.0-9])@ would
--   lift that restriction, but it changes the program's output.
addr_barrier :: String -> String
addr_barrier x = non_addr_char ++ x ++ non_addr_char


-- | The magic happens here. We take a CIDR as an argument, and
--   return the equivalent regular expression. We do this as follows:
--
--   1. Compute the minimum possible value of each octet.
--   2. Compute the maximum possible value of each octet.
--   3. Generate a regex matching every value between those min and
--      max values.
--   4. Join the regexes from step 3 with regexes matching periods.
--   5. Stick an address boundary on either side of the result.
cidr_to_regex :: Cidr.Cidr -> String
cidr_to_regex cidr =
  addr_barrier (intercalate "\\." [range1, range2, range3, range4])
  where
    range1 = octet_range (min_octet1 cidr) (max_octet1 cidr)
    range2 = octet_range (min_octet2 cidr) (max_octet2 cidr)
    range3 = octet_range (min_octet3 cidr) (max_octet3 cidr)
    range4 = octet_range (min_octet4 cidr) (max_octet4 cidr)

    -- Convert an octet's bounds to Ints and build the matching regex.
    octet_range lo hi = numeric_range (fromEnum lo) (fromEnum hi)


-- | Take a list of Strings, and return a regular expression matching
--   any of them. An empty list yields @"()"@, which matches the
--   empty string.
alternate :: [String] -> String
alternate terms = "(" ++ intercalate "|" terms ++ ")"


-- | Take two Ints as parameters, and return a regex matching any
--   integer between them (inclusive). The arguments may be given in
--   either order. Each alternative is produced by 'show', so numbers
--   written with leading zeros are not matched.
numeric_range :: Int -> Int -> String
numeric_range x y = alternate (map show [min x y .. max x y])


-- | Entry point: validate the command line, read whitespace-separated
--   CIDRs from the configured input, and run the selected mode.
main :: IO ()
main = do
  -- First, check for any errors that occurred while parsing
  -- the command line options.
  errors <- CommandLine.parse_errors
  unless (null errors) $ do
    hPutStrLn stderr (concat errors)
    putStrLn CommandLine.help_text
    exitWith (ExitFailure exit_args_parse_failed)

  -- Next, check to see if the 'help' option was passed to the
  -- program. If it was, display the help, and exit successfully.
  help_opt_set <- CommandLine.help_set
  when help_opt_set $ do
    putStrLn CommandLine.help_text
    exitSuccess

  -- The input function we receive here should know what to read.
  inputfunc <- CommandLine.input_function
  input <- inputfunc

  let cidr_strings = splitWs input
  let cidrs = map readMaybe cidr_strings

  -- Abort if any token failed to parse. Like the argument-parsing
  -- errors above, the message goes to stderr so it never mixes with
  -- the program's normal output on stdout.
  when (any isNothing cidrs) $ do
    hPutStrLn stderr "Error: not valid CIDR notation."
    exitWith (ExitFailure exit_invalid_cidr)

  -- Every parse succeeded by this point; catMaybes just unwraps them.
  let valid_cidrs = catMaybes cidrs

  -- Several modes need the combined list (see Cidr.combine_all), so
  -- compute it once. Laziness means it is only evaluated if used.
  let combined_cidrs = combine_all valid_cidrs

  -- Input CIDRs that do not survive combination, and combined CIDRs
  -- that were not in the input.
  let dupes = valid_cidrs \\ combined_cidrs
  let newcidrs = combined_cidrs \\ valid_cidrs

  -- Get the mode of operation.
  mode <- CommandLine.parse_mode

  case mode of
    Regex -> do
      -- NOTE(review): with no input CIDRs this prints "()", a regex
      -- that matches everything.
      let regexes = map cidr_to_regex valid_cidrs
      putStrLn $ alternate regexes

    Reduce -> mapM_ print combined_cidrs

    Dupe -> mapM_ print dupes

    Diff -> do
      mapM_ (putStrLn . ('-' :) . show) dupes
      mapM_ (putStrLn . ('+' :) . show) newcidrs

    List -> mapM_ print (concatMap enumerate combined_cidrs)

    Reverse -> do
      let addrs = concatMap enumerate combined_cidrs
      let addr_bytestrings = map (BS.pack . show) addrs
      ptrs <- lookup_ptrs addr_bytestrings
      let pairs = zip addr_bytestrings ptrs
      -- Print sequentially: running the putStrLn actions in parallel
      -- made the output order nondeterministic and let lines from
      -- different threads interleave.
      mapM_ (putStrLn . show_pair) pairs
      -- Presumably lookup_ptrs uses the parallel-io global pool (TODO
      -- confirm in DNS); parallel-io requires stopGlobalPool once all
      -- uses of the pool have finished.
      stopGlobalPool
  where
    -- Render an address and its PTR results as "addr: name1 name2 ...".
    -- A failed lookup (Nothing) leaves the part after ": " empty.
    show_pair :: (Domain, Maybe [Domain]) -> String
    show_pair (addr, mdomains) = BS.unpack addr ++ ": " ++ results
      where
        results =
          maybe "" (BS.unpack . BS.intercalate (BS.pack " ")) mdomains