Add basic error handling
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 module Main where
4
5 import qualified AST
6 import qualified Data.Map as Map
7 import qualified Data.Text.Lazy.IO as Text
8 import           Data.String
9 import Foreign.Ptr
10 import System.IO
11 import           System.Exit
12 import LLVM.Context
13 import LLVM.ExecutionEngine
14 import LLVM.Module
15 import LLVM.PassManager
16 import LLVM.IRBuilder
17 import LLVM.AST.AddrSpace
18 import LLVM.AST.Constant
19 import LLVM.AST.Float
20 import LLVM.AST.Operand
21 import LLVM.AST.Type as Type
22 import LLVM.Pretty
23 import           Control.Monad
24 import           Control.Monad.Trans.Class
25
26 type ModuleBuilderE = ModuleBuilderT (Either String)
27
28 foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float
29
30 main :: IO ()
31 main = do
32   AST.Program asts <- read <$> getContents
33   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
34   case eitherMdl of
35     Left err -> die err
36     Right mdl -> withContext $ \ctx ->
37       withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
38         withModuleFromAST ctx mdl $ \mdl' ->
39           withPassManager defaultCuratedPassSetSpec $ \pm -> do
40             runPassManager pm mdl' >>= guard
41             Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
42             withModuleInEngine mcjit mdl' $ \emdl -> do
43               Just f <- getFunction emdl "expr"
44               let f' = castFunPtr f :: FunPtr (IO Float)
45               exprFun f' >>= print
46
47 buildAST :: AST.AST -> ModuleBuilderE Operand
48 buildAST (AST.Function nameStr paramStrs body) = do
49   let n = fromString nameStr
50   function n params float $ \binds -> do
51     let bindMap = Map.fromList (zip paramStrs binds)
52     buildExpr bindMap body >>= ret
53   where params = zip (repeat float) (map fromString paramStrs)
54 buildAST (AST.Eval e) =
55   function "expr" [] float $ \_ -> buildExpr mempty e >>= ret
56
57 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
58 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
59 buildExpr binds (AST.Var n) = case binds Map.!? n of
60   Just x -> pure x
61   Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
62
63 buildExpr binds (AST.Call nameStr params) = do
64   paramOps <- mapM (buildExpr binds) params
65   let name = fromString nameStr
66       -- get a pointer to the function
67       typ = FunctionType float (replicate (length params) float) False
68       ptrTyp = Type.PointerType typ (AddrSpace 0)
69       ref = GlobalReference ptrTyp name
70   call (ConstantOperand ref) (zip paramOps (repeat []))
71
72 buildExpr binds (AST.BinOp op a b) = do
73   va <- buildExpr binds a
74   vb <- buildExpr binds b
75   let instr = case op of
76                 AST.Add -> fadd
77                 AST.Sub -> fsub
78                 AST.Mul -> fmul
79   instr va vb