{-# LANGUAGE PatternGuards #-}

module Report (
  report,
  report_tests )
where

import Control.Monad ( filterM )
import qualified Data.ByteString.Char8 as BS ( pack )
import Data.Char ( toLower )
import Data.List ( (\\) )
import Data.Maybe ( catMaybes, listToMaybe, mapMaybe )
import Data.String.Utils ( join, split, strip )
import Database.HDBC (
  IConnection,
  execute,
  prepare,
  sFetchAllRows')
import Database.HDBC.Sqlite3 ( connectSqlite3 )
import System.Console.CmdArgs.Default ( Default(..) )
import Test.Tasty ( TestTree, testGroup )
import Test.Tasty.HUnit ( (@?=), testCase )

import Configuration ( Configuration(..) )
import DNS ( lookup_mxs, normalize )
import MxList ( MxList(..) )


-- Type synonyms to make the signatures below a little more clear.
type Domain = String
type Address = String
type Goto = String

-- | A forward from an 'Address' to a single 'Goto' destination. An
--   address with several comma-separated destinations in the database
--   becomes several 'Forward's.
data Forward = Forward Address Goto deriving (Show)


-- | Run a parameterless @query@ and return every row it produces.
--   NULL columns come back as 'Nothing'.
fetch_rows :: IConnection a => a -> String -> IO [[Maybe String]]
fetch_rows conn query = do
  stmt <- prepare conn query
  -- We really want executeRaw here, but there's a bug: it will tell
  -- us we can't fetch rows from the statement since it hasn't been
  -- executed yet!
  _ <- execute stmt []
  sFetchAllRows' stmt


-- | Return the domains produced by @query@. For each row, the first
--   non-NULL column is taken as the domain; rows that are entirely
--   NULL are skipped.
get_domain_list :: IConnection a => a -> String -> IO [Domain]
get_domain_list conn query =
  fmap (mapMaybe (listToMaybe . catMaybes)) (fetch_rows conn query)


-- | Return the forwards produced by @query@. The first column of each
--   row is the address and the second is a comma-separated list of
--   destinations; each (whitespace-stripped) destination becomes its
--   own 'Forward'.
get_forward_list :: IConnection a => a -> String -> IO [Forward]
get_forward_list conn query =
  fmap (concatMap row_to_forwards) (fetch_rows conn query)
  where
    -- Columns are matched by position. Dropping the NULLs first would
    -- shift a later column into the address/goto slot. A row with a
    -- NULL address or goto yields nothing, and so do empty entries in
    -- the goto list (e.g. from a trailing comma), which would otherwise
    -- show up as bogus remote forwards.
    row_to_forwards :: [Maybe String] -> [Forward]
    row_to_forwards (Just addr : Just gotos : _) =
      [ Forward addr g | g <- map strip (split "," gotos), not (null g) ]
    row_to_forwards _ = []


-- | Keep only the forwards whose destination is not in one of our
--   local @domains@. A destination that does not contain exactly one
--   \"\@\" is assumed to be remote. Domain names are compared
--   case-insensitively, since DNS names are case-insensitive.
find_remote_forwards :: [Domain] -> [Forward] -> [Forward]
find_remote_forwards domains = filter is_remote
  where
    lowercase :: String -> String
    lowercase = map toLower

    local_domains :: [Domain]
    local_domains = map lowercase domains

    is_remote :: Forward -> Bool
    is_remote (Forward _ goto) =
      case split "@" goto of
        [_, dp] -> lowercase dp `notElem` local_domains
        _       -> True -- Assume it's remote if something is wrong


-- | Render a forward as \"address -> goto\".
format_forward :: Forward -> String
format_forward (Forward addr goto) = addr ++ " -> " ++ goto


-- | If the MX records for a forward's address domain are exactly those
--   contained in the 'MxList', then we exclude that forward from the
--   report. Splitting on the '@' is a lazy way of obtaining the domain,
--   but if it's good enough for determining that a forward is remote,
--   then it's good enough for this. Forwards whose domain can't be
--   determined are kept.
--
--   The comparison goes both ways: a domain that also has an MX record
--   outside the exclude list is still reported.
filter_by_mx :: MxList -> [Forward] -> IO [Forward]
-- This special case is necessary! Otherwise if we have an empty
-- exclude list and a domain that has no MX record, it will be
-- excluded.
filter_by_mx (MxList []) = return
filter_by_mx (MxList mxs) = filterM should_report
  where
    -- The exclude list is the same for every forward, so normalize it
    -- once instead of once per forward.
    norm_mxs = map (normalize . BS.pack) mxs

    -- True (keep the forward) unless the domain's MX records and the
    -- exclude list name the same hosts. '(\\)' is used in both
    -- directions, so repeated entries count.
    should_report :: Forward -> IO Bool
    should_report (Forward addr _) =
      case split "@" addr of
        [_, domain_part] -> do
          fw_mxs <- lookup_mxs (BS.pack domain_part)
          let same_hosts = null (norm_mxs \\ fw_mxs)
                           && null (fw_mxs \\ norm_mxs)
          return (not same_hosts)
        _ -> return True -- Report it if we can't figure out the domain.
-- | Given a connection and a 'Configuration', produces the report as
--   a 'String': one \"address -> goto\" line, newline-terminated, for
--   each remote forward that isn't excluded by the configured MX list.
--   When there is nothing to report the result is the empty string
--   rather than a lone newline.
--
report :: IConnection a => Configuration -> a -> IO String
report cfg conn = do
  domains <- get_domain_list conn (domain_query cfg)
  forwards <- get_forward_list conn (forward_query cfg)

  -- Drop the local forwards first. That check is pure, and doing it
  -- before 'filter_by_mx' saves an MX lookup for every forward that
  -- would be discarded anyway; both are order-preserving filters, so
  -- the result is the same either way.
  let remote_forwards = find_remote_forwards domains forwards
  reported <- filter_by_mx (exclude_mx cfg) remote_forwards

  let forward_strings = map format_forward reported
  -- Only terminate the report with a newline if there is something in
  -- it; an empty report stays empty.
  return $ if null forward_strings
           then ""
           else join "\n" forward_strings ++ "\n"


-- * Tests

-- | All of the tests for this module.
report_tests :: TestTree
report_tests =
  testGroup "Report Tests" [ test_example1 ]


-- | Run the report with the default 'Configuration' against the
--   fixture database and compare it to the known list of remote
--   forwards.
test_example1 :: TestTree
test_example1 =
  testCase desc $ do
    conn <- connectSqlite3 "test/fixtures/postfixadmin.sqlite3"
    let cfg = def :: Configuration
    actual <- report cfg conn
    actual @?= expected
  where
    desc = "all remote forwards are found"
    expected = "user1@example.com -> user1@example.net\n" ++
               "user2@example.com -> user1@example.org\n" ++
               "user2@example.com -> user2@example.org\n" ++
               "user2@example.com -> user3@example.org\n"