4d8275828510ceed0bbcc99f3b2b95798042b2ae
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 module Main where
5
6 import qualified AST
7 import Control.Monad
8 import Control.Monad.Trans.Class
9 import qualified Data.Map as Map
10 import qualified Data.Text.Lazy.IO as Text
11 import Data.String
12 import Foreign.Ptr
13 import System.Exit
14 import System.IO
15 import LLVM.Context
16 import LLVM.ExecutionEngine
17 import LLVM.Module
18 import LLVM.PassManager
19 import LLVM.IRBuilder
20 import LLVM.AST.AddrSpace
21 import LLVM.AST.Constant
22 import LLVM.AST.Float
23 import LLVM.AST.FloatingPointPredicate hiding (False, True)
24 import LLVM.AST.Operand
25 import LLVM.AST.Type as Type
26 import LLVM.AST.Typed
27 import LLVM.Pretty
28
29 type ModuleBuilderE = ModuleBuilderT (Either String)
30
31 foreign import ccall "dynamic" exprFun :: FunPtr (IO Double) -> IO Double
32
33 main :: IO ()
34 main = do
35   AST.Program asts <- read <$> getContents
36   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
37   case eitherMdl of
38     Left err -> die err
39     Right mdl -> withContext $ \ctx -> do
40       hPutStrLn stderr "Before optimisation:"
41       Text.hPutStrLn stderr (ppllvm mdl)
42       withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
43         withModuleFromAST ctx mdl $ \mdl' -> do
44           -- withPassManager defaultCuratedPassSetSpec $ \pm -> do
45           --   runPassManager pm mdl' >>= guard
46           hPutStrLn stderr "After optimisation:"
47           Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
48           withModuleInEngine mcjit mdl' $ \emdl -> do
49           Just f <- getFunction emdl "expr"
50           let f' = castFunPtr f :: FunPtr (IO Double)
51           exprFun f' >>= print
52
53 evalProg :: AST.Program -> IO (Maybe Double)
54 evalProg (AST.Program asts) = do
55   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
56   case eitherMdl of
57     Left _ -> return Nothing
58     Right mdl -> withContext $ \ctx ->
59       withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
60         withModuleFromAST ctx mdl $ \mdl' ->
61           withModuleInEngine mcjit mdl' $ \emdl -> do
62             Just f <- getFunction emdl "expr"
63             let f' = castFunPtr f :: FunPtr (IO Double)
64             Just <$> exprFun f'
65
66 -- | Builds up programs at the top-level of an LLVM Module
67 -- >>> evalProg (read "31 - 5")
68 -- Just 26.0
69 -- >>> evalProg (read "extern pow(x e); pow(3,2)")
70 -- Just 9.0
71 buildAST :: AST.AST -> ModuleBuilderE Operand
72 buildAST (AST.Function nameStr paramStrs body) = do
73   let n = fromString nameStr
74   function n params Type.double $ \binds -> do
75     let bindMap = Map.fromList (zip paramStrs binds)
76     buildExpr bindMap body >>= ret
77   where params = zip (repeat Type.double) (map fromString paramStrs)
78 buildAST (AST.Extern nameStr params) =
79   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
80 buildAST (AST.Eval e) =
81   function "expr" [] Type.double $ \_ -> buildExpr mempty e >>= ret
82
83 -- | Builds up expressions, which are operands in LLVM IR
84 -- >>> evalProg (read "def foo(x) x * 2; foo(6)")
85 -- Just 12.0
86 -- >>> evalProg (read "if 3 > 2 then 42 else 12")
87 -- Just 42.0
88 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
89 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Double a))
90 buildExpr binds (AST.Var n) = case binds Map.!? n of
91   Just x -> pure x
92   Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
93
94 buildExpr binds (AST.Call nameStr params) = do
95   paramOps <- mapM (buildExpr binds) params
96   let name = fromString nameStr
97       -- get a pointer to the function
98       typ = FunctionType Type.double (replicate (length params) Type.double) False
99       ptrTyp = Type.PointerType typ (AddrSpace 0)
100       ref = GlobalReference ptrTyp name
101   call (ConstantOperand ref) (zip paramOps (repeat []))
102
103 buildExpr binds (AST.BinOp op a b) = do
104   va <- buildExpr binds a
105   vb <- buildExpr binds b
106   let instr = case op of
107                 AST.Add -> fadd
108                 AST.Sub -> fsub
109                 AST.Mul -> fmul
110                 AST.Cmp GT -> fcmp OGT
111                 AST.Cmp LT -> fcmp OLT
112                 AST.Cmp EQ -> fcmp OEQ
113   instr va vb
114
115 buildExpr binds (AST.If cond thenE elseE) = mdo
116   _ifB <- block `named` "if"
117   condV <- buildExpr binds cond
118   when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean"
119   condBr condV thenB elseB
120
121   thenB <- block `named` "then"
122   thenOp <- buildExpr binds thenE
123   br mergeB
124
125   elseB <- block `named` "else"
126   elseOp <- buildExpr binds elseE
127   br mergeB
128
129   mergeB <- block `named` "ifcont"
130   phi [(thenOp, thenB), (elseOp, elseB)]