gitweb.michael.orlitzky.com - spline3.git/blob - src/Main.hs
src/Main.hs: import Args(..) explicitly.
[spline3.git] / src / Main.hs
1 {-# LANGUAGE RecordWildCards, DoAndIfThenElse #-}
2
3 module Main (main)
4 where
5
6 import Control.Monad ( when )
7 import qualified Data.Array.Repa as R
8 import Data.Maybe ( fromJust )
9 import GHC.Conc ( getNumProcessors, setNumCapabilities )
10 import System.IO ( hPutStrLn, stderr )
11 import System.Exit (
12 ExitCode( ExitFailure ),
13 exitSuccess,
14 exitWith )
15
16 import CommandLine (
17 Args(Args, depth, height, input, lower_threshold, output,
18 scale, slice, upper_threshold, width),
19 apply_args )
20 import ExitCodes ( exit_arg_not_positive, exit_arg_out_of_bounds )
21 import Grid ( zoom )
22 import Volumetric (
23 bracket_array,
24 flip_x,
25 flip_y,
26 read_word16s,
27 round_array,
28 swap_bytes,
29 write_values_to_bmp,
30 write_word16s,
31 z_slice )
32
33
-- | Sanity-check the parsed command-line arguments.
--
--   The checks run in order. On the first problem found, an error message
--   is printed to stderr and the program exits with the matching code from
--   "ExitCodes". If every check passes, this returns normally.
--
--   * @scale@, @width@, @height@ and @depth@ must all be positive
--     (exit code 'exit_arg_not_positive').
--   * If a @slice@ is given, it must be a valid zero-based z index,
--     i.e. @0 <= slice < depth@ (exit code 'exit_arg_out_of_bounds').
validate_args :: Args -> IO ()
validate_args Args{..} = do
  when (scale <= 0) $
    bail exit_arg_not_positive "ERROR: scale must be greater than zero."

  when (width <= 0) $
    bail exit_arg_not_positive "ERROR: width must be greater than zero."

  when (height <= 0) $
    bail exit_arg_not_positive "ERROR: height must be greater than zero."

  when (depth <= 0) $
    bail exit_arg_not_positive "ERROR: depth must be greater than zero."

  case slice of
    -- Slices are zero-indexed (z indices 0 .. depth - 1), so a slice
    -- equal to depth is already out of bounds.
    Just s ->
      when (s < 0 || s >= depth) $
        bail exit_arg_out_of_bounds
             "ERROR: slice must be between zero and depth - 1."
    Nothing -> return ()

  where
    -- Report an error on stderr and terminate with the given exit code.
    bail :: Int -> String -> IO ()
    bail code msg = do
      hPutStrLn stderr msg
      exitWith (ExitFailure code)
58
59
-- | Program entry point.
--
--   Parses and validates the command line, enables one capability per
--   processor, and then runs either the 2D (single slice) or the 3D zoom,
--   depending on whether a slice was requested. Exits successfully when
--   the chosen mode finishes.
main :: IO ()
main = do
  cli_args@Args{..} <- apply_args
  -- Terminates the program by itself if any argument is invalid.
  validate_args cli_args

  -- Use every core the machine reports, so that nobody has to pass
  -- +RTS -N<cores> on the command line each time.
  getNumProcessors >>= setNumCapabilities

  let volume_shape :: R.DIM3
      volume_shape = R.Z R.:. depth R.:. height R.:. width
      -- A requested slice means 2D output; otherwise zoom the whole volume.
      run = maybe main3d (const main2d) slice

  run cli_args volume_shape
  exitSuccess
83
84
-- | Zoom the entire input volume and write the result.
--
--   1. Read the volume in @input@ with the given @shape@.
--   2. Pass it through 'swap_bytes', then convert it with 'fromIntegral'.
--   3. Zoom it by @scale@ along all three axes.
--   4. Round the zoomed values with 'round_array' and swap their bytes
--      back.
--   5. Write the result to @output@ with 'write_word16s'.
main3d :: Args -> R.DIM3 -> IO ()
main3d Args{..} shape = do
  raw_words <- read_word16s input shape
  volume <- R.computeUnboxedP
              . R.map fromIntegral
              . R.reshape shape
              $ swap_bytes raw_words
  zoomed <- zoom volume (scale, scale, scale)
  -- Put the bytes back in their original order so the same program can
  -- be used to view both the input and the output data.
  restored <- R.computeUnboxedP . swap_bytes $ round_array zoomed
  write_word16s output restored
98
99
-- | Produce a bitmap from a single z-slice of the input volume.
--
--   1. Read the volume in @input@ with the given @shape@.
--   2. Pass it through 'swap_bytes', 'flip_y' and 'flip_x'.
--   3. Extract the requested @slice@ with 'z_slice'.
--   4. Zoom it with factor @(1, scale, scale)@.
--      NOTE(review): this assumes zoom's factor tuple follows the shape's
--      (depth, height, width) axis order, so depth is not scaled.
--   5. Pass the result through 'bracket_array' with the lower/upper
--      thresholds.
--   6. Write it to @output@ with 'write_values_to_bmp'.
--
--   'main' only dispatches here when @slice@ is 'Just'. Calling this with
--   'Nothing' is a programming error and is reported with an explicit
--   message instead of an opaque 'fromJust' failure.
main2d :: Args -> R.DIM3 -> IO ()
main2d Args{..} shape =
  case slice of
    Nothing ->
      error "main2d: no slice was given (main should have used main3d)."
    Just slice_index -> do
      let zoom_factor = (1 :: Int, scale, scale)
      arr <- read_word16s input shape
      arr_slice <- R.computeUnboxedP
                     $ z_slice slice_index
                     $ flip_x width
                     $ flip_y height
                     $ swap_bytes arr
      -- Give the extracted slice an explicit depth-one 3D shape.
      let arr_slice_3d = R.reshape slice3d arr_slice

      -- If zoom isn't being inlined, we need to extract the slice
      -- beforehand and convert it to the required format.
      dbl_data <- R.computeUnboxedP $ R.map fromIntegral arr_slice_3d
      raw_output <- zoom dbl_data zoom_factor
      -- Take z-slice 0 of the zoomed result (the z zoom factor is 1).
      arr_slice_0 <- R.computeUnboxedP $ z_slice 0 raw_output

      -- Make doubles from the thresholds which are given as Ints.
      let lt = fromIntegral lower_threshold :: Double
      let ut = fromIntegral upper_threshold :: Double

      let arr_bracketed = bracket_array lt ut arr_slice_0
      values <- R.computeUnboxedP $ R.map fromIntegral arr_bracketed
      write_values_to_bmp output values

  where
    -- Shape of a single z-slice, viewed as a depth-one 3D array.
    slice3d :: R.DIM3
    slice3d = R.Z R.:. 1 R.:. height R.:. width