X-Git-Url: http://gitweb.michael.orlitzky.com/?a=blobdiff_plain;f=src%2FMain.hs;h=4774c8778e5904a42378e53e2673ef112f3cdef7;hb=2404313e648301064041c12fdab8d2f976c26a64;hp=12511b61463b6d04a721434cbb7e7d8b09db15e8;hpb=3c9316fed6fd100be9a5e1f8d72db6534fb163cd;p=hath.git diff --git a/src/Main.hs b/src/Main.hs index 12511b6..4774c87 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,34 +1,36 @@ -import Control.Monad (unless, when) +module Main +where + +import Control.Monad (when) import Data.List ((\\), intercalate) +import qualified Data.List as List (sort) import Data.Maybe (catMaybes, isNothing) -import Data.String.Utils (splitWs) -import System.Exit (ExitCode(..), exitSuccess, exitWith) +import System.Exit (ExitCode( ExitFailure ), exitWith) import System.IO (stderr, hPutStrLn) +import Text.Read (readMaybe) + +import Cidr ( + Cidr(), + combine_all, + enumerate, + max_octet1, + max_octet2, + max_octet3, + max_octet4, + min_octet1, + min_octet2, + min_octet3, + min_octet4 ) +import qualified Cidr ( normalize ) +import CommandLine( + Args( Regexed, Reduced, Duped, Diffed, Listed, barriers, normalize, sort ), + get_args ) +import ExitCodes ( exit_invalid_cidr ) +import Octet () -import Cidr (Cidr(..), - cidr_from_string, - combine_all, - max_octet1, - max_octet2, - max_octet3, - max_octet4, - min_octet1, - min_octet2, - min_octet3, - min_octet4 ) - -import CommandLine (help_set, - help_text, - input_function, - Mode(..), - parse_errors, - parse_mode) - -import ExitCodes -import Octet - -- | A regular expression that matches a non-address character. +-- non_addr_char :: String non_addr_char = "[^\\.0-9]" @@ -36,8 +38,9 @@ non_addr_char = "[^\\.0-9]" -- | Add non_addr_chars on either side of the given String. This -- prevents (for example) the regex '127.0.0.1' from matching -- '127.0.0.100'. -addr_barrier :: String -> String -addr_barrier x = non_addr_char ++ x ++ non_addr_char +-- +add_barriers :: String -> String +add_barriers x = non_addr_char ++ x ++ non_addr_char -- | The magic happens here. We take a CIDR String as an argument, and @@ -48,37 +51,47 @@ addr_barrier x = non_addr_char ++ x ++ non_addr_char -- 3. Generate a regex matching every value between those min and -- max values. -- 4. Join the regexes from step 3 with regexes matching periods. --- 5. Stick an address boundary on either side of the result. -cidr_to_regex :: Cidr.Cidr -> String -cidr_to_regex cidr = - addr_barrier (intercalate "\\." [range1, range2, range3, range4]) +-- 5. Stick an address boundary on either side of the result if +-- use_barriers is True. +-- +cidr_to_regex :: Bool -> Cidr -> String +cidr_to_regex use_barriers cidr = + let f = if use_barriers then add_barriers else id in + f (intercalate "\\." [range1, range2, range3, range4]) where range1 = numeric_range min1 max1 range2 = numeric_range min2 max2 range3 = numeric_range min3 max3 range4 = numeric_range min4 max4 - min1 = octet_to_int (min_octet1 cidr) - min2 = octet_to_int (min_octet2 cidr) - min3 = octet_to_int (min_octet3 cidr) - min4 = octet_to_int (min_octet4 cidr) - max1 = octet_to_int (max_octet1 cidr) - max2 = octet_to_int (max_octet2 cidr) - max3 = octet_to_int (max_octet3 cidr) - max4 = octet_to_int (max_octet4 cidr) + min1 = fromEnum (min_octet1 cidr) + min2 = fromEnum (min_octet2 cidr) + min3 = fromEnum (min_octet3 cidr) + min4 = fromEnum (min_octet4 cidr) + max1 = fromEnum (max_octet1 cidr) + max2 = fromEnum (max_octet2 cidr) + max3 = fromEnum (max_octet3 cidr) + max4 = fromEnum (max_octet4 cidr) -- | Take a list of Strings, and return a regular expression matching -- any of them. +-- alternate :: [String] -> String alternate terms = "(" ++ (intercalate "|" terms) ++ ")" -- | Take two Ints as parameters, and return a regex matching any -- integer between them (inclusive). +-- +-- IMPORTANT: we match from max to min so that if e.g. the last +-- octet is '255', we want '255' to match before '2' in the regex +-- (255|254|...|3|2|1) which does not happen if we use +-- (1|2|3|...|254|255). +-- numeric_range :: Int -> Int -> String numeric_range x y = - alternate (map show [lower..upper]) + alternate (map show $ reverse [lower..upper]) where lower = minimum [x,y] upper = maximum [x,y] @@ -86,56 +99,53 @@ numeric_range x y = main :: IO () main = do - -- First, check for any errors that occurred while parsing - -- the command line options. - errors <- CommandLine.parse_errors - unless (null errors) $ do - hPutStrLn stderr (concat errors) - putStrLn CommandLine.help_text - exitWith (ExitFailure exit_args_parse_failed) - - -- Next, check to see if the 'help' option was passed to the - -- program. If it was, display the help, and exit successfully. - help_opt_set <- CommandLine.help_set - when help_opt_set $ do - putStrLn CommandLine.help_text - exitSuccess - - -- The input function we receive here should know what to read. - inputfunc <- (CommandLine.input_function) - input <- inputfunc - - let cidr_strings = splitWs input - let cidrs = map cidr_from_string cidr_strings + args <- get_args + + -- This reads stdin. + input <- getContents + + let cidr_strings = words input + let cidrs = map readMaybe cidr_strings :: [Maybe Cidr] when (any isNothing cidrs) $ do - putStrLn "Error: not valid CIDR notation." + hPutStrLn stderr "ERROR: not valid CIDR notation:" + + -- Output the bad lines, safely. + let pairs = zip cidr_strings cidrs + + let print_pair :: (String, Maybe Cidr) -> IO () + print_pair (x, Nothing) = hPutStrLn stderr (" * " ++ x) + print_pair (_, _) = return () + + mapM_ print_pair pairs exitWith (ExitFailure exit_invalid_cidr) -- Filter out only the valid ones. let valid_cidrs = catMaybes cidrs - -- Get the mode of operation. - mode <- CommandLine.parse_mode - - case mode of - Regex -> do - let regexes = map cidr_to_regex valid_cidrs + case args of + Regexed{} -> do + let cidrs' = combine_all valid_cidrs + let regexes = map (cidr_to_regex (barriers args)) cidrs' putStrLn $ alternate regexes - Reduce -> do - _ <- mapM print (combine_all valid_cidrs) - return () - Dupe -> do - _ <- mapM print dupes - return () + Reduced{} -> do + -- Pre-normalize all CIDRs if the user asked for it. + let nrml_func = if (normalize args) then Cidr.normalize else id + let sort_func = if (sort args) then List.sort else id :: [Cidr] -> [Cidr] + mapM_ (print . nrml_func) (sort_func $ combine_all valid_cidrs) + Duped{} -> + mapM_ print dupes where dupes = valid_cidrs \\ (combine_all valid_cidrs) - Diff -> do - _ <- mapM putStrLn deletions - _ <- mapM putStrLn additions - return () + Diffed{} -> do + mapM_ putStrLn deletions + mapM_ putStrLn additions where dupes = valid_cidrs \\ (combine_all valid_cidrs) deletions = map (\s -> '-' : (show s)) dupes newcidrs = (combine_all valid_cidrs) \\ valid_cidrs additions = map (\s -> '+' : (show s)) newcidrs + Listed{} -> do + let combined_cidrs = combine_all valid_cidrs + let addrs = concatMap enumerate combined_cidrs + mapM_ print addrs