gitweb.michael.orlitzky.com - hath.git/blob - src/Main.hs
Rewrite command-line parsing to use cmdargs.
[hath.git] / src / Main.hs
1 module Main
2 where
3
4 import Control.Concurrent.ParallelIO.Global ( stopGlobalPool )
5 import Control.Monad (when)
6 import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
7 import Data.List ((\\), intercalate)
8 import Data.Maybe (catMaybes, isNothing)
9 import Data.String.Utils (splitWs)
10 import System.Exit (ExitCode(..), exitWith)
11 import System.IO (stderr, hPutStrLn)
12 import Text.Read (readMaybe)
13
14 import Cidr (
15 Cidr(..),
16 combine_all,
17 enumerate,
18 max_octet1,
19 max_octet2,
20 max_octet3,
21 max_octet4,
22 min_octet1,
23 min_octet2,
24 min_octet3,
25 min_octet4 )
26 import CommandLine (Args(..), get_args)
27 import DNS (Domain, PTRResult, lookup_ptrs)
28 import ExitCodes ( exit_invalid_cidr )
29 import Octet ()
30
31
-- | A regular expression (a single bracket expression) that matches
--   one non-address character, i.e. anything other than a decimal
--   digit or a period.
--
--   The period is deliberately left unescaped. Inside a POSIX bracket
--   expression a backslash is an ordinary character, so escaping the
--   period there would make POSIX engines (such as @grep -E@) also
--   refuse to match a literal backslash. An unescaped period is
--   already literal inside brackets, in both POSIX and Perl-style
--   engines.
non_addr_char :: String
non_addr_char = "[^.0-9]"
35
36
-- | Surround the given regex with address barriers. The result only
--   matches an address that is not immediately preceded or followed
--   by another address character. This prevents (for example) the
--   regex '127.0.0.1' from matching '127.0.0.100'.
--
--   Each barrier also accepts the start or end of the line. A bare
--   'non_addr_char' on either side would demand an actual character
--   there, so an address at the very beginning or end of a line (or
--   alone on its line) could never match.
add_barriers :: String -> String
add_barriers x = leading ++ x ++ trailing
  where
    -- Start of line, or a single non-address character.
    leading  = "(^|" ++ non_addr_char ++ ")"
    -- A single non-address character, or end of line.
    trailing = "(" ++ non_addr_char ++ "|$)"
42
43
-- | The magic happens here. Convert a 'Cidr' into a regular
--   expression matching exactly the dotted-quad addresses it
--   contains. This works as follows:
--
--   1. Find the smallest and largest value each octet can take.
--   2. Turn each (min, max) pair into a regex matching every value
--      in between (see 'numeric_range').
--   3. Join the four octet regexes with escaped periods.
--   4. If the Bool argument is 'True', wrap the result in address
--      barriers (see 'add_barriers'); otherwise return it as-is.
--
cidr_to_regex :: Bool -> Cidr.Cidr -> String
cidr_to_regex use_barriers cidr
  | use_barriers = add_barriers dotted_quad
  | otherwise    = dotted_quad
  where
    dotted_quad = intercalate "\\." octet_regexes

    octet_regexes =
      [ octet_regex min_octet1 max_octet1
      , octet_regex min_octet2 max_octet2
      , octet_regex min_octet3 max_octet3
      , octet_regex min_octet4 max_octet4
      ]

    -- Regex for one octet position, given accessors for its lower
    -- and upper bound within this CIDR.
    octet_regex lo hi =
      numeric_range (fromEnum (lo cidr)) (fromEnum (hi cidr))
71
72
73
-- | Combine several regexes into one parenthesized alternation that
--   matches whatever any one of them matches. An empty list yields
--   the empty group @()@.
alternate :: [String] -> String
alternate terms = concat ["(", intercalate "|" terms, ")"]
78
79
-- | Build a regex matching the decimal representation of every
--   integer between the two arguments, inclusive. The bounds may be
--   given in either order; values are listed in ascending order.
numeric_range :: Int -> Int -> String
numeric_range x y =
  alternate [ show n | n <- [min x y .. max x y] ]
88
89
-- | Entry point. Parse the command line, then read whitespace-separated
--   CIDRs from stdin. Depending on the selected mode, print one of:
--   a regex matching them, the combined CIDRs, the redundant input
--   CIDRs, a +/- diff, every enclosed address, or the PTR lookup
--   results for every enclosed address. If any input token is not
--   valid CIDR notation, list the offenders on stderr and exit with
--   'exit_invalid_cidr'.
main :: IO ()
main = do
  args <- get_args

  -- All of stdin, split on whitespace; every word should be a CIDR.
  input <- getContents
  let cidr_strings = splitWs input
      parsed = map readMaybe cidr_strings

  -- Refuse to go on if anything failed to parse, naming each offender.
  when (any isNothing parsed) $ do
    hPutStrLn stderr "ERROR: not valid CIDR notation:"
    sequence_ [ hPutStrLn stderr (" * " ++ bad)
              | (bad, Nothing) <- zip cidr_strings parsed ]
    exitWith (ExitFailure exit_invalid_cidr)

  -- Past this point every token parsed successfully.
  let valid_cidrs = catMaybes parsed
      combined = combine_all valid_cidrs
      -- Input CIDRs that disappear once the list is combined.
      redundant = valid_cidrs \\ combined
      enclosed_addrs = concatMap enumerate combined

  case args of
    Regexed{} ->
      putStrLn (alternate (map (cidr_to_regex (barriers args)) combined))
    Reduced{} ->
      mapM_ print combined
    Duped{} ->
      mapM_ print redundant
    Diffed{} -> do
      mapM_ (putStrLn . ('-' :) . show) redundant
      mapM_ (putStrLn . ('+' :) . show) (combined \\ valid_cidrs)
    Listed{} ->
      mapM_ print enclosed_addrs
    Reversed{} -> do
      let addr_bytestrings = map (BS.pack . show) enclosed_addrs
      ptrs <- lookup_ptrs addr_bytestrings
      mapM_ (putStrLn . show_pair) (zip addr_bytestrings ptrs)

  stopGlobalPool

  where
    -- Render one reverse-lookup result as "address: name1 name2 ...",
    -- or "address: ERROR (...)" when the lookup failed.
    show_pair :: (Domain, PTRResult) -> String
    show_pair (addr, result) =
      BS.unpack addr ++ ": " ++ either describe_error join_names result
      where
        describe_error err = "ERROR (" ++ show err ++ ")"
        join_names names = BS.unpack (BS.intercalate (BS.pack " ") names)