Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad.Trans.Class
6 import Control.Monad.Trans.Reader
7 import Control.Monad.IO.Class
8 import Data.String
9 import qualified Data.Map as Map
10 import qualified Data.Text.Lazy.IO as Text
11 import LLVM.AST.AddrSpace
12 import LLVM.AST.Constant
13 import LLVM.AST.Float
14 import LLVM.AST.FloatingPointPredicate hiding (False, True)
15 import LLVM.AST.Operand
16 import LLVM.AST.Type as Type
17 import LLVM.Context
18 import LLVM.IRBuilder
19 import LLVM.Module
20 import LLVM.PassManager
21 import LLVM.Pretty
22 import LLVM.Target
23 import System.IO
24 import System.IO.Error
25 import Text.Read (readMaybe)
26
27 main :: IO ()
28 main = do
29   withContext $ \ctx -> withHostTargetMachineDefault $ \tm -> do
30     ast <- runReaderT (buildModuleT "main" repl) ctx
31     return ()
32
33 repl :: ModuleBuilderT (ReaderT Context IO) ()
34 repl = do
35   liftIO $ hPutStr stderr "ready> "
36   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
37   case mline of
38     Nothing -> return ()
39     Just l -> do
40       case readMaybe l of
41         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
42         Just ast -> do
43           hoist $ buildAST ast
44           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
45
46           ast <- moduleSoFar "main"
47           ctx <- lift ask
48           liftIO $ withModuleFromAST ctx ast $ \mdl -> do
49             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
50             -- this returns true if the module was modified
51             withPassManager spec $ flip runPassManager mdl
52             Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
53       repl
54   where
55     eofHandler e
56       | isEOFError e = return Nothing
57       | otherwise = ioError e
58
59 type Binds = Map.Map String Operand
60
61 buildAST :: AST -> ModuleBuilder Operand
62 buildAST (Function (Prototype nameStr paramStrs) body) = do
63   let n = fromString nameStr
64   function n params Type.double $ \ops -> do
65     let binds = Map.fromList (zip paramStrs ops)
66     flip runReaderT binds $ buildExpr body >>= ret
67   where params = zip (repeat Type.double) (map fromString paramStrs)
68
69 buildAST (Extern (Prototype nameStr params)) =
70   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
71
72 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
73   const $ flip runReaderT mempty $ buildExpr x >>= ret
74
75 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
76 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
77 buildExpr (Var n) = do
78   binds <- ask
79   case binds Map.!? n of
80     Just x -> pure x
81     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
82
83 buildExpr (BinOp op a b) = do
84   opA <- buildExpr a
85   opB <- buildExpr b
86   tmp <- instr opA opB
87   if isCmp
88     then uitofp tmp Type.double
89     else return tmp
90   where isCmp
91           | Cmp _ <- op = True
92           | otherwise = False
93         instr = case op of
94                   K.Add -> fadd
95                   K.Sub -> fsub
96                   K.Mul -> fmul
97                   K.Cmp LT -> fcmp OLT
98                   K.Cmp GT -> fcmp OGT
99                   K.Cmp EQ -> fcmp OEQ
100
101 buildExpr (Call callee params) = do
102   paramOps <- mapM buildExpr params
103   let nam = fromString callee
104       -- get a pointer to the function
105       typ = FunctionType Type.double (replicate (length params) Type.double) False
106       ptrTyp = Type.PointerType typ (AddrSpace 0)
107       ref = GlobalReference ptrTyp nam
108   call (ConstantOperand ref) (zip paramOps (repeat []))