Set up the LLVM context and optimise the module
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad.Trans.Reader
6 import Control.Monad.IO.Class
7 import Data.String
8 import qualified Data.Map as Map
9 import qualified Data.Text.Lazy.IO as Text
10 import LLVM.AST.AddrSpace
11 import LLVM.AST.Constant
12 import LLVM.AST.Float
13 import LLVM.AST.FloatingPointPredicate hiding (False, True)
14 import LLVM.AST.Operand
15 import LLVM.AST.Type as Type
16 import LLVM.Context
17 import LLVM.IRBuilder
18 import LLVM.Module
19 import LLVM.PassManager
20 import LLVM.Pretty
21 import LLVM.Target
22 import System.IO
23 import System.IO.Error
24 import Text.Read (readMaybe)
25
26 main :: IO ()
27 main = do
28   mdl' <- buildModuleT "main" repl
29   withContext $ \ctx -> withHostTargetMachine $ \tm ->
30     withModuleFromAST ctx mdl' $ \mdl -> do
31       let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
32       -- this returns true if the module was modified
33       withPassManager spec $ flip runPassManager mdl
34       Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
35
36 repl :: ModuleBuilderT IO ()
37 repl = do
38   liftIO $ hPutStr stderr "ready> "
39   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
40   case mline of
41     Nothing -> return ()
42     Just l -> do
43       case readMaybe l of
44         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
45         Just ast -> do
46           hoist $ buildAST ast
47           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
48       repl
49   where
50     eofHandler e
51       | isEOFError e = return Nothing
52       | otherwise = ioError e
53
54 type Binds = Map.Map String Operand
55
56 buildAST :: AST -> ModuleBuilder Operand
57 buildAST (Function (Prototype nameStr paramStrs) body) = do
58   let n = fromString nameStr
59   function n params Type.double $ \ops -> do
60     let binds = Map.fromList (zip paramStrs ops)
61     flip runReaderT binds $ buildExpr body >>= ret
62   where params = zip (repeat Type.double) (map fromString paramStrs)
63
64 buildAST (Extern (Prototype nameStr params)) =
65   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
66
67 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
68   const $ flip runReaderT mempty $ buildExpr x >>= ret
69
70 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
71 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
72 buildExpr (Var n) = do
73   binds <- ask
74   case binds Map.!? n of
75     Just x -> pure x
76     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
77
78 buildExpr (BinOp op a b) = do
79   opA <- buildExpr a
80   opB <- buildExpr b
81   tmp <- instr opA opB
82   if isCmp
83     then uitofp tmp Type.double
84     else return tmp
85   where isCmp
86           | Cmp _ <- op = True
87           | otherwise = False
88         instr = case op of
89                   K.Add -> fadd
90                   K.Sub -> fsub
91                   K.Mul -> fmul
92                   K.Cmp LT -> fcmp OLT
93                   K.Cmp GT -> fcmp OGT
94                   K.Cmp EQ -> fcmp OEQ
95
96 buildExpr (Call callee params) = do
97   paramOps <- mapM buildExpr params
98   let nam = fromString callee
99       -- get a pointer to the function
100       typ = FunctionType Type.double (replicate (length params) Type.double) False
101       ptrTyp = Type.PointerType typ (AddrSpace 0)
102       ref = GlobalReference ptrTyp nam
103   call (ConstantOperand ref) (zip paramOps (repeat []))