89fd91a39cf663048b77497b879e9d01903e5751
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 module Main where
5
6 import qualified AST
7 import Control.Monad
8 import Control.Monad.Trans.Class
9 import qualified Data.Map as Map
10 import qualified Data.Text.Lazy.IO as Text
11 import Data.String
12 import Foreign.Ptr
13 import System.Exit
14 import System.IO
15 import LLVM.Context
16 import LLVM.ExecutionEngine
17 import LLVM.Module
18 import LLVM.PassManager
19 import LLVM.IRBuilder
20 import LLVM.AST.AddrSpace
21 import LLVM.AST.Constant
22 import LLVM.AST.Float
23 import LLVM.AST.FloatingPointPredicate hiding (False, True)
24 import LLVM.AST.Operand
25 import LLVM.AST.Type as Type
26 import LLVM.Pretty
27
28 import Debug.Trace
29
30 type ModuleBuilderE = ModuleBuilderT (Either String)
31
32 foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float
33
34 main :: IO ()
35 main = do
36   AST.Program asts <- read <$> getContents
37   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
38   case eitherMdl of
39     Left err -> die err
40     Right mdl -> withContext $ \ctx -> do
41       hPutStrLn stderr "Before optimisation:"
42       Text.hPutStrLn stderr (ppllvm mdl)
43       withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
44         withModuleFromAST ctx mdl $ \mdl' ->
45           withPassManager defaultCuratedPassSetSpec $ \pm -> do
46             runPassManager pm mdl' >>= guard
47             hPutStrLn stderr "After optimisation:"
48             Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
49             withModuleInEngine mcjit mdl' $ \emdl -> do
50               Just f <- getFunction emdl "expr"
51               let f' = castFunPtr f :: FunPtr (IO Float)
52               exprFun f' >>= print
53
54 buildAST :: AST.AST -> ModuleBuilderE Operand
55 buildAST (AST.Function nameStr paramStrs body) = do
56   let n = fromString nameStr
57   function n params float $ \binds -> do
58     let bindMap = Map.fromList (zip paramStrs binds)
59     buildExpr bindMap body >>= ret
60   where params = zip (repeat float) (map fromString paramStrs)
61 buildAST (AST.Eval e) =
62   function "expr" [] float $ \_ -> buildExpr mempty e >>= ret
63
64 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
65 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
66 buildExpr binds (AST.Var n) = case binds Map.!? n of
67   Just x -> pure x
68   Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
69
70 buildExpr binds (AST.Call nameStr params) = do
71   paramOps <- mapM (buildExpr binds) params
72   let name = fromString nameStr
73       -- get a pointer to the function
74       typ = FunctionType float (replicate (length params) float) False
75       ptrTyp = Type.PointerType typ (AddrSpace 0)
76       ref = GlobalReference ptrTyp name
77   call (ConstantOperand ref) (zip paramOps (repeat []))
78
79 buildExpr binds (AST.BinOp op a b) = do
80   va <- buildExpr binds a
81   vb <- buildExpr binds b
82   let instr = case op of
83                 AST.Add -> fadd
84                 AST.Sub -> fsub
85                 AST.Mul -> fmul
86                 AST.Cmp GT -> fcmp OGT
87                 AST.Cmp LT -> fcmp OLT
88                 AST.Cmp EQ -> fcmp OEQ
89   instr va vb
90
91 buildExpr binds (AST.If cond thenE elseE) = mdo
92   _ifB <- block `named` "if"
93   condV <- buildExpr binds cond
94   condBr condV thenB elseB
95
96   thenB <- block `named` "then"
97   thenOp <- buildExpr binds thenE
98   br mergeB
99
100   elseB <- block `named` "else"
101   elseOp <- buildExpr binds elseE
102   br mergeB
103
104   mergeB <- block `named` "ifcont"
105   traceShowId <$> phi [(thenOp, thenB), (elseOp, elseB)]