328aaf3c72c56d9b1d13f4f1de4376cf07b5681d
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 module Main where
5
6 import qualified AST
7 import Control.Monad
8 import Control.Monad.Trans.Class
9 import qualified Data.Map as Map
10 import qualified Data.Text.Lazy.IO as Text
11 import Data.String
12 import Foreign.Ptr
13 import System.Exit
14 import System.IO
15 import LLVM.Context
16 import LLVM.ExecutionEngine
17 import LLVM.Module
18 import LLVM.PassManager
19 import LLVM.IRBuilder
20 import LLVM.AST.AddrSpace
21 import LLVM.AST.Constant
22 import LLVM.AST.Float
23 import LLVM.AST.FloatingPointPredicate hiding (False, True)
24 import LLVM.AST.Operand
25 import LLVM.AST.Type as Type
26 import LLVM.AST.Typed
27 import LLVM.Pretty
28
29 import Debug.Trace
30
31 type ModuleBuilderE = ModuleBuilderT (Either String)
32
33 foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float
34
35 main :: IO ()
36 main = do
37   AST.Program asts <- read <$> getContents
38   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
39   case eitherMdl of
40     Left err -> die err
41     Right mdl -> withContext $ \ctx -> do
42       hPutStrLn stderr "Before optimisation:"
43       Text.hPutStrLn stderr (ppllvm mdl)
44       withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
45         withModuleFromAST ctx mdl $ \mdl' ->
46           withPassManager defaultCuratedPassSetSpec $ \pm -> do
47             runPassManager pm mdl' >>= guard
48             hPutStrLn stderr "After optimisation:"
49             Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
50             withModuleInEngine mcjit mdl' $ \emdl -> do
51               Just f <- getFunction emdl "expr"
52               let f' = castFunPtr f :: FunPtr (IO Float)
53               exprFun f' >>= print
54
55 buildAST :: AST.AST -> ModuleBuilderE Operand
56 buildAST (AST.Function nameStr paramStrs body) = do
57   let n = fromString nameStr
58   function n params float $ \binds -> do
59     let bindMap = Map.fromList (zip paramStrs binds)
60     buildExpr bindMap body >>= ret
61   where params = zip (repeat float) (map fromString paramStrs)
62 buildAST (AST.Eval e) =
63   function "expr" [] float $ \_ -> buildExpr mempty e >>= ret
64
65 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
66 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
67 buildExpr binds (AST.Var n) = case binds Map.!? n of
68   Just x -> pure x
69   Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
70
71 buildExpr binds (AST.Call nameStr params) = do
72   paramOps <- mapM (buildExpr binds) params
73   let name = fromString nameStr
74       -- get a pointer to the function
75       typ = FunctionType float (replicate (length params) float) False
76       ptrTyp = Type.PointerType typ (AddrSpace 0)
77       ref = GlobalReference ptrTyp name
78   call (ConstantOperand ref) (zip paramOps (repeat []))
79
80 buildExpr binds (AST.BinOp op a b) = do
81   va <- buildExpr binds a
82   vb <- buildExpr binds b
83   let instr = case op of
84                 AST.Add -> fadd
85                 AST.Sub -> fsub
86                 AST.Mul -> fmul
87                 AST.Cmp GT -> fcmp OGT
88                 AST.Cmp LT -> fcmp OLT
89                 AST.Cmp EQ -> fcmp OEQ
90   instr va vb
91
92 buildExpr binds (AST.If cond thenE elseE) = mdo
93   _ifB <- block `named` "if"
94
95   condV <- buildExpr binds cond
96   when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean"
97   condBr condV thenB elseB
98
99   thenB <- block `named` "then"
100   thenOp <- buildExpr binds thenE
101   br mergeB
102
103   elseB <- block `named` "else"
104   elseOp <- buildExpr binds elseE
105   br mergeB
106
107   mergeB <- block `named` "ifcont"
108   traceShowId <$> phi [(thenOp, thenB), (elseOp, elseB)]