Update AST to match Kaleidoscope more closely
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 module Main where
4
5 import qualified AST
6 import qualified Data.Map as Map
7 import qualified Data.Text.Lazy.IO as Text
8 import           Data.String
9 import Foreign.Ptr
10 import System.IO
11 import LLVM.Context
12 import LLVM.ExecutionEngine
13 import LLVM.Module
14 import LLVM.IRBuilder
15 import LLVM.AST.AddrSpace
16 import LLVM.AST.Constant
17 import LLVM.AST.Float
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
20 import LLVM.Pretty
21
22 foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float
23
24 main :: IO ()
25 main = do
26   program <- read <$> getContents
27   let mdl = buildModule "main" $ mapM buildAST program
28   Text.hPutStrLn stderr (ppllvm mdl)
29   withContext $ \ctx ->
30     withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
31       withModuleFromAST ctx mdl $ \mdl' ->
32         withModuleInEngine mcjit mdl' $ \emdl -> do
33           Just f <- getFunction emdl "expr"
34           let f' = castFunPtr f :: FunPtr (IO Float)
35           exprFun f' >>= print
36
37 buildAST :: AST.AST -> ModuleBuilder Operand
38 buildAST (AST.Function nameStr paramStrs body) = do
39   let name = fromString nameStr
40   function name params float $ \binds -> do
41     let bindMap = Map.fromList (zip paramStrs binds)
42     buildExpr bindMap body >>= ret
43   where params = zip (repeat float) (map fromString paramStrs)
44 buildAST (AST.Eval e) =
45   function "expr" [] float $ \_ -> buildExpr mempty e >>= ret
46
47 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilder Operand
48 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
49 buildExpr binds (AST.Var name) = pure $ binds Map.! name
50
51 buildExpr binds (AST.Call nameStr params) = do
52   paramOps <- mapM (buildExpr binds) params
53   let name = fromString nameStr
54       -- get a pointer to the function
55       typ = FunctionType float (replicate (length params) float) False
56       ptrTyp = Type.PointerType typ (AddrSpace 0)
57       ref = GlobalReference ptrTyp name
58   call (ConstantOperand ref) (zip paramOps (repeat []))
59
60 buildExpr binds (AST.BinOp op a b) = do
61   va <- buildExpr binds a
62   vb <- buildExpr binds b
63   let instr = case op of
64                 AST.Add -> fadd
65                 AST.Sub -> fsub
66                 AST.Mul -> fmul
67   instr va vb