Add putchard and move to OrcJIT
[kaleidoscope-hs-old.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 module Main where
5
6 import qualified AST
7 import Control.Monad
8 import Control.Monad.Trans.Class
9 import qualified Data.Map as Map
10 import qualified Data.Text.Lazy.IO as Text
11 import Data.String
12 import Foreign.Ptr
13 import System.Exit
14 import System.IO
15 import LLVM.Context
16 import LLVM.OrcJIT
17 import LLVM.OrcJIT.CompileLayer
18 import LLVM.Module
19 import LLVM.PassManager
20 import LLVM.IRBuilder
21 import LLVM.AST.AddrSpace
22 import LLVM.AST.Constant
23 import LLVM.AST.Float
24 import LLVM.AST.FloatingPointPredicate hiding (False, True)
25 import LLVM.AST.Operand
26 import LLVM.AST.Type as Type
27 import LLVM.AST.Typed
28 import LLVM.Pretty
29 import LLVM.Linking
30 import LLVM.Target
31
32 import Control.Concurrent.MVar
33
34 type ModuleBuilderE = ModuleBuilderT (Either String)
35
36 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
37
38 main :: IO ()
39 main = do
40   AST.Program asts <- read <$> getContents
41   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
42
43   case eitherMdl of
44     Left err -> die err
45     Right mdl' -> withContext $ \ctx ->
46       withHostTargetMachine $ \tm -> do
47         -- hPutStrLn stderr "Before optimisation:"
48         -- Text.hPutStrLn stderr (ppllvm mdl')
49
50         withModuleFromAST ctx mdl' $ \mdl -> do
51           let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
52           withPassManager spec $ flip runPassManager mdl
53           -- hPutStrLn stderr "After optimisation:"
54           -- Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl
55           jit tm mdl >>= print
56
57 jit :: TargetMachine -> Module -> IO Double
58 jit tm mdl = do
59   loadLibraryPermanently (Just "stdlib.dylib") >>= guard . not
60   compLayerVar <- newEmptyMVar
61   
62   -- jit time
63   withExecutionSession $ \exSession ->
64     withSymbolResolver exSession (SymbolResolver (symResolver compLayerVar)) $ \symResolverPtr ->
65       withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
66         withModuleKey exSession $ \mdlKey ->
67           withIRCompileLayer linkingLayer tm $ \compLayer -> do
68             putMVar compLayerVar compLayer
69
70             withModule compLayer mdlKey mdl $ do
71               mangled <- mangleSymbol compLayer "expr"
72               Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
73               mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
74
75   where symResolver clv sym = do
76           cl <- readMVar clv
77           ms <- findSymbol cl sym False
78           case ms of
79             Right s -> return (return s)
80             _ -> do
81               addr <- getSymbolAddressInProcess sym
82               return $ return (JITSymbol addr (JITSymbolFlags False False True True))
83
84 evalProg :: AST.Program -> IO (Maybe Double)
85 evalProg (AST.Program asts) = do
86   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
87   case eitherMdl of
88     Left _ -> return Nothing
89     Right mdl' -> withContext $ \ctx ->
90       withHostTargetMachine $ \tm ->
91         withModuleFromAST ctx mdl' (fmap Just . jit tm)
92
93 -- | Builds up programs at the top-level of an LLVM Module
94 -- >>> evalProg (read "31 - 5")
95 -- Just 26.0
96 buildAST :: AST.AST -> ModuleBuilderE Operand
97 buildAST (AST.Function nameStr paramStrs body) = do
98   let n = fromString nameStr
99   function n params Type.double $ \binds -> do
100     let bindMap = Map.fromList (zip paramStrs binds)
101     buildExpr bindMap body >>= ret
102   where params = zip (repeat Type.double) (map fromString paramStrs)
103 buildAST (AST.Extern nameStr params) =
104   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
105 buildAST (AST.Eval e) =
106   function "expr" [] Type.double $ \_ -> buildExpr mempty e >>= ret
107
108 -- | Builds up expressions, which are operands in LLVM IR
109 -- >>> evalProg (read "def foo(x) x * 2; foo(6)")
110 -- Just 12.0
111 -- >>> evalProg (read "if 3 > 2 then 42 else 12")
112 -- Just 42.0
113 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
114 buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Double a))
115 buildExpr binds (AST.Var n) = case binds Map.!? n of
116   Just x -> pure x
117   Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
118
119 buildExpr binds (AST.Call nameStr params) = do
120   paramOps <- mapM (buildExpr binds) params
121   let name = fromString nameStr
122       -- get a pointer to the function
123       typ = FunctionType Type.double (replicate (length params) Type.double) False
124       ptrTyp = Type.PointerType typ (AddrSpace 0)
125       ref = GlobalReference ptrTyp name
126   call (ConstantOperand ref) (zip paramOps (repeat []))
127
128 buildExpr binds (AST.BinOp op a b) = do
129   va <- buildExpr binds a
130   vb <- buildExpr binds b
131   let instr = case op of
132                 AST.Add -> fadd
133                 AST.Sub -> fsub
134                 AST.Mul -> fmul
135                 AST.Cmp GT -> fcmp OGT
136                 AST.Cmp LT -> fcmp OLT
137                 AST.Cmp EQ -> fcmp OEQ
138   instr va vb
139
140 buildExpr binds (AST.If cond thenE elseE) = mdo
141   _ifB <- block `named` "if"
142   condV <- buildExpr binds cond
143   when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean"
144   condBr condV thenB elseB
145
146   thenB <- block `named` "then"
147   thenOp <- buildExpr binds thenE
148   br mergeB
149
150   elseB <- block `named` "else"
151   elseOp <- buildExpr binds elseE
152   br mergeB
153
154   mergeB <- block `named` "ifcont"
155   phi [(thenOp, thenB), (elseOp, elseB)]