Add default optimisation passes
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 module Main where
4
5 import qualified AST
6 import qualified Data.Map as Map
7 import qualified Data.Text.Lazy.IO as Text
8 import           Data.String
9 import Foreign.Ptr
10 import System.IO
11 import LLVM.Context
12 import LLVM.ExecutionEngine
13 import LLVM.Module
14 import LLVM.PassManager
15 import LLVM.IRBuilder
16 import LLVM.AST.AddrSpace
17 import LLVM.AST.Constant
18 import LLVM.AST.Float
19 import LLVM.AST.Operand
20 import LLVM.AST.Type as Type
21 import LLVM.Pretty
22 import           Control.Monad
23
24 foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float
25
26 main :: IO ()
27 main = do
28   AST.Program asts <- read <$> getContents
29   let mdl = buildModule "main" $ mapM buildAST asts
30   withContext $ \ctx ->
31     withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
32       withModuleFromAST ctx mdl $ \mdl' ->
33         withPassManager defaultCuratedPassSetSpec $ \pm -> do
34           runPassManager pm mdl' >>= guard
35           Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
36           withModuleInEngine mcjit mdl' $ \emdl -> do
37             Just f <- getFunction emdl "expr"
38             let f' = castFunPtr f :: FunPtr (IO Float)
39             exprFun f' >>= print
40
41 buildAST :: AST.AST -> ModuleBuilder Operand
42 buildAST (AST.Function nameStr paramStrs body) = do
43   let name = fromString nameStr
44   function name params float $ \binds -> do
45     let bindMap = Map.fromList (zip paramStrs binds)
46     buildExpr bindMap body >>= ret
47   where params = zip (repeat float) (map fromString paramStrs)
48 buildAST (AST.Eval e) =
49   function "expr" [] float $ \_ -> buildExpr mempty e >>= ret
50
51 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilder Operand
52 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
53 buildExpr binds (AST.Var name) = pure $ binds Map.! name
54
55 buildExpr binds (AST.Call nameStr params) = do
56   paramOps <- mapM (buildExpr binds) params
57   let name = fromString nameStr
58       -- get a pointer to the function
59       typ = FunctionType float (replicate (length params) float) False
60       ptrTyp = Type.PointerType typ (AddrSpace 0)
61       ref = GlobalReference ptrTyp name
62   call (ConstantOperand ref) (zip paramOps (repeat []))
63
64 buildExpr binds (AST.BinOp op a b) = do
65   va <- buildExpr binds a
66   vb <- buildExpr binds b
67   let instr = case op of
68                 AST.Add -> fadd
69                 AST.Sub -> fsub
70                 AST.Mul -> fmul
71   instr va vb