module Main (main)
where

import Control.Monad (when)
import Data.List ((\\), intercalate)
import qualified Data.List as List (sort)
import Data.Maybe (catMaybes, isNothing)
import System.Exit (ExitCode( ExitFailure ), exitWith)
import System.IO (stderr, hPutStrLn)
import Text.Read (readMaybe)

import Cidr (
  Cidr(),
  combine_all,
  enumerate,
  max_octet1,
  max_octet2,
  max_octet3,
  max_octet4,
  min_octet1,
  min_octet2,
  min_octet3,
  min_octet4 )
import qualified Cidr ( normalize )
import CommandLine(
  Args( Regexed, Reduced, Duped, Diffed, Listed, barriers, normalize, sort ),
  get_args )
import ExitCodes ( exit_invalid_cidr )
import Octet ()


-- | A regular expression that matches a single non-address character,
--   i.e. anything other than a period or a decimal digit.
--
non_addr_char :: String
non_addr_char = "[^\\.0-9]"


-- | Add a 'non_addr_char' on either side of the given String. This
--   prevents (for example) the regex '127.0.0.1' from matching
--   '127.0.0.100'.
--
--   NOTE(review): each barrier must consume an actual character, so an
--   address at the very beginning or end of a line is not matched.
--   Changing that alters the generated regex, so confirm intent first.
--
add_barriers :: String -> String
add_barriers x = non_addr_char ++ x ++ non_addr_char


-- | The magic happens here. We take a CIDR as an argument, and
--   return the equivalent regular expression. We do this as follows:
--
--   1. Compute the minimum possible value of each octet.
--   2. Compute the maximum possible value of each octet.
--   3. Generate a regex matching every value between those min and
--      max values.
--   4. Join the regexes from step 3 with regexes matching periods.
--   5. Stick an address boundary on either side of the result if
--      use_barriers is True.
--
cidr_to_regex :: Bool -> Cidr -> String
cidr_to_regex use_barriers cidr =
  barrier_func (intercalate "\\." [range1, range2, range3, range4])
  where
    barrier_func = if use_barriers then add_barriers else id

    range1 = numeric_range min1 max1
    range2 = numeric_range min2 max2
    range3 = numeric_range min3 max3
    range4 = numeric_range min4 max4

    min1 = fromEnum (min_octet1 cidr)
    min2 = fromEnum (min_octet2 cidr)
    min3 = fromEnum (min_octet3 cidr)
    min4 = fromEnum (min_octet4 cidr)

    max1 = fromEnum (max_octet1 cidr)
    max2 = fromEnum (max_octet2 cidr)
    max3 = fromEnum (max_octet3 cidr)
    max4 = fromEnum (max_octet4 cidr)


-- | Take a list of Strings, and return a regular expression matching
--   any of them.
--
--   Note that an empty list produces @\"()\"@, which matches the
--   empty string.
--
alternate :: [String] -> String
alternate terms = "(" ++ (intercalate "|" terms) ++ ")"


-- | Take two Ints as parameters, and return a regex matching any
--   integer between them (inclusive). The arguments may be given in
--   either order.
--
--   IMPORTANT: we match from max to min so that if e.g. the last
--   octet is '255', we want '255' to match before '2' in the regex
--   (255|254|...|3|2|1) which does not happen if we use
--   (1|2|3|...|254|255).
--
numeric_range :: Int -> Int -> String
numeric_range x y =
  alternate (map show $ reverse [lower..upper])
  where
    lower = min x y
    upper = max x y


-- | Read whitespace-separated CIDRs from stdin, abort with
--   'exit_invalid_cidr' if any of them fail to parse, and otherwise
--   run the mode selected on the command line.
--
main :: IO ()
main = do
  args <- get_args

  -- This reads stdin.
  input <- getContents

  let cidr_strings = words input
  let cidrs = map readMaybe cidr_strings :: [Maybe Cidr]

  when (any isNothing cidrs) $ do
    hPutStrLn stderr "ERROR: not valid CIDR notation:"

    -- Echo each unparseable entry to stderr.
    let pairs = zip cidr_strings cidrs
    let print_pair :: (String, Maybe Cidr) -> IO ()
        print_pair (x, Nothing) = hPutStrLn stderr (" * " ++ x)
        print_pair (_, _) = return ()

    mapM_ print_pair pairs
    exitWith (ExitFailure exit_invalid_cidr)

  -- Filter out only the valid ones.
  let valid_cidrs = catMaybes cidrs

  -- Every mode works from the combined set; compute it once. Laziness
  -- means modes that do not need it never pay for it.
  let combined_cidrs = combine_all valid_cidrs

  case args of
    Regexed{} -> do
      -- NOTE(review): with no valid input this prints "()", which
      -- matches the empty string (i.e. every line). Confirm intent.
      let regexes = map (cidr_to_regex (barriers args)) combined_cidrs
      putStrLn $ alternate regexes

    Reduced{} -> do
      -- Optionally normalize each combined CIDR, then optionally sort.
      -- Normalizing before sorting guarantees that, when both options
      -- are given, the order we sort by is the order of what we print.
      let nrml_func = if normalize args then Cidr.normalize else id
      let sort_func :: [Cidr] -> [Cidr]
          sort_func = if sort args then List.sort else id
      mapM_ print (sort_func $ map nrml_func combined_cidrs)

    Duped{} ->
      -- Print the input CIDRs that do not survive combine_all.
      mapM_ print (valid_cidrs \\ combined_cidrs)

    Diffed{} -> do
      -- A diff from the input to the combined set: '-' lines are input
      -- CIDRs removed by combine_all, '+' lines are CIDRs it introduced.
      let deletions = valid_cidrs \\ combined_cidrs
      let additions = combined_cidrs \\ valid_cidrs
      mapM_ (putStrLn . ('-' :) . show) deletions
      mapM_ (putStrLn . ('+' :) . show) additions

    Listed{} ->
      -- Print every address covered by the combined CIDRs.
      mapM_ print (concatMap enumerate combined_cidrs)