]> gitweb.michael.orlitzky.com - hath.git/blob - src/Main.hs
Switch from test-framework to tasty.
[hath.git] / src / Main.hs
1 module Main
2 where
3
4 import Control.Concurrent.ParallelIO.Global ( stopGlobalPool )
5 import Control.Monad (when)
6 import qualified Data.ByteString.Char8 as BS (intercalate, pack, unpack)
7 import Data.List ((\\), intercalate)
8 import Data.Maybe (catMaybes, isNothing)
9 import Data.String.Utils (splitWs)
10 import Network.DNS.Types ( DNSError (NameError) )
11 import System.Exit (ExitCode(..), exitWith)
12 import System.IO (stderr, hPutStrLn)
13 import Text.Read (readMaybe)
14
15 import Cidr (
16 Cidr(..),
17 combine_all,
18 enumerate,
19 max_octet1,
20 max_octet2,
21 max_octet3,
22 max_octet4,
23 min_octet1,
24 min_octet2,
25 min_octet3,
26 min_octet4 )
27 import CommandLine (Args(..), get_args)
28 import DNS (Domain, PTRResult, lookup_ptrs)
29 import ExitCodes ( exit_invalid_cidr )
30 import Octet ()
31
32
-- | A regular expression that matches a single non-address character,
--   i.e. anything other than a period or a decimal digit.
non_addr_char :: String
non_addr_char = "[^\\.0-9]"


-- | Surround the given regex with address boundaries. Each boundary
--   matches either one 'non_addr_char' or the start/end of the line.
--   This prevents (for example) the regex '127.0.0.1' from matching
--   '127.0.0.100'.
--
--   The @^@ / @$@ alternatives matter. Without them, an address that
--   begins or ends a line (e.g. one address per line) would never
--   match, because there is no character before or after it to
--   satisfy 'non_addr_char'.
add_barriers :: String -> String
add_barriers x =
  "(^|" ++ non_addr_char ++ ")" ++ x ++ "($|" ++ non_addr_char ++ ")"
43
44
-- | Convert a 'Cidr' into its equivalent regular expression.
--
--   The regex is built as follows:
--
--   1. For each of the four octets, read the smallest and largest
--      value the CIDR allows.
--   2. Turn each (min, max) pair into a 'numeric_range' alternation.
--   3. Join the four alternations with escaped periods.
--   4. If @use_barriers@ is True, wrap the result with 'add_barriers'
--      so that it cannot match part of a longer address.
--
cidr_to_regex :: Bool -> Cidr.Cidr -> String
cidr_to_regex use_barriers cidr
  | use_barriers = add_barriers dotted
  | otherwise    = dotted
  where
    -- Apply an octet accessor to this CIDR and convert the result to an Int.
    octet_value accessor = fromEnum (accessor cidr)

    lows  = map octet_value [min_octet1, min_octet2, min_octet3, min_octet4]
    highs = map octet_value [max_octet1, max_octet2, max_octet3, max_octet4]

    dotted = intercalate "\\." (zipWith numeric_range lows highs)
73
74
75
-- | Combine several regular expressions into one parenthesized
--   alternation that matches any of them. For example,
--   @alternate ["a", "b"]@ is @"(a|b)"@. An empty list yields @"()"@.
alternate :: [String] -> String
alternate terms = concat ["(", intercalate "|" terms, ")"]
80
81
-- | Build a regex matching any integer in the closed interval between
--   the two arguments. The arguments may be given in either order.
--
--   IMPORTANT: the alternatives are listed from largest to smallest.
--   If e.g. the last octet is '255', we want '255' to be tried before
--   '2' in the regex (255|254|...|3|2|1). That does not happen with
--   (1|2|3|...|254|255).
--
numeric_range :: Int -> Int -> String
numeric_range x y =
  alternate [ show n | n <- [hi, hi - 1 .. lo] ]
  where
    (lo, hi) = (min x y, max x y)
96
97
-- | Entry point.
--
--   Parses the command-line arguments and reads whitespace-separated
--   CIDRs from stdin. If any token fails to parse as a CIDR, the bad
--   tokens are listed on stderr and the program exits with
--   'exit_invalid_cidr'. Otherwise, depending on the mode in @args@,
--   it prints one of the following:
--
--   * a single regex ('Regexed')
--   * the combined CIDRs ('Reduced')
--   * the input CIDRs removed by combining ('Duped')
--   * a -/+ diff between input and combined CIDRs ('Diffed')
--   * every address in the combined CIDRs ('Listed')
--   * each such address with its 'lookup_ptrs' result ('Reversed')
main :: IO ()
main = do
  args <- get_args

  -- This reads stdin.
  input <- getContents

  let cidr_strings = splitWs input
  let cidrs = map readMaybe cidr_strings

  when (any isNothing cidrs) $ do
    hPutStrLn stderr "ERROR: not valid CIDR notation:"

    -- Output the bad tokens, safely: keep only the strings whose
    -- parse came back as Nothing.
    let bad_strings = [ s | (s, Nothing) <- zip cidr_strings cidrs ]
    mapM_ (\s -> hPutStrLn stderr (" * " ++ s)) bad_strings
    exitWith (ExitFailure exit_invalid_cidr)

  -- Filter out only the valid ones.
  let valid_cidrs = catMaybes cidrs

  -- Every mode works from the combined list. Bind it once (it is lazy,
  -- so it is only computed if a mode needs it) instead of calling
  -- combine_all separately per use, which happened twice in Diffed.
  let combined_cidrs = combine_all valid_cidrs

  case args of
    Regexed{} -> do
      let regexes = map (cidr_to_regex (barriers args)) combined_cidrs
      putStrLn $ alternate regexes
    Reduced{} ->
      mapM_ print combined_cidrs
    Duped{} -> do
      -- Input CIDRs that do not survive combination.
      let dupes = valid_cidrs \\ combined_cidrs
      mapM_ print dupes
    Diffed{} -> do
      -- Deletions: input CIDRs missing from the combined list.
      -- Additions: combined CIDRs that were not in the input.
      let dupes = valid_cidrs \\ combined_cidrs
      let newcidrs = combined_cidrs \\ valid_cidrs
      mapM_ (putStrLn . ('-' :) . show) dupes
      mapM_ (putStrLn . ('+' :) . show) newcidrs
    Listed{} -> do
      let addrs = concatMap enumerate combined_cidrs
      mapM_ print addrs
    Reversed{} -> do
      let addrs = concatMap enumerate combined_cidrs
      let addr_bytestrings = map (BS.pack . show) addrs
      ptrs <- lookup_ptrs addr_bytestrings
      let pairs = zip addr_bytestrings ptrs
      mapM_ (putStrLn . show_pair) pairs

  -- Shut down the parallel-io global pool (presumably the one used by
  -- lookup_ptrs; NOTE(review): confirm against the DNS module).
  stopGlobalPool

  where
    -- | Render one (address, PTR lookup result) pair as a single line
    --   of the form "address: results".
    show_pair :: (Domain, PTRResult) -> String
    show_pair (s, eds) =
      (BS.unpack s) ++ ": " ++ results
      where
        space = BS.pack " "
        results =
          case eds of
            -- NameError simply means "not found" so we output nothing.
            Left NameError -> ""
            Left err -> "ERROR (" ++ (show err) ++ ")"
            Right ds -> BS.unpack $ BS.intercalate space ds