Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 import AST as K -- K for Kaleidoscope
5 import Utils
6 import Control.Monad
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
10 import Data.String
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
13 import Foreign.Ptr
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
16 import LLVM.AST.Float
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
20 import LLVM.Context
21 import LLVM.IRBuilder
22 import LLVM.Module
23 import LLVM.OrcJIT
24 import LLVM.OrcJIT.CompileLayer
25 import LLVM.PassManager
26 import LLVM.Pretty
27 import LLVM.Target
28 import System.IO
29 import System.IO.Error
30 import Text.Read (readMaybe)
31
32 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
33
34 data JITEnv = JITEnv
35   { jitEnvContext :: Context
36   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
37   , jitEnvModuleKey :: ModuleKey
38   }
39
40 main :: IO ()
41 main =
42   withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
43     withExecutionSession $ \exSession ->
44       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
45         withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
46           withIRCompileLayer linkingLayer tm $ \compLayer ->
47             withModuleKey exSession $ \mdlKey -> do
48               let env = JITEnv ctx compLayer mdlKey
49               _ast <- runReaderT (buildModuleT "main" repl) env
50               return ()
51
52 -- This can eventually be used to resolve external functions, e.g. a stdlib call
53 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
54 symResolver sym = undefined
55
56 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
57 repl = do
58   liftIO $ hPutStr stderr "ready> "
59   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
60   case mline of
61     Nothing -> return ()
62     Just l -> do
63       case readMaybe l of
64         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
65         Just ast -> do
66           anon <- isAnonExpr <$> hoist (buildAST ast)
67           def <- mostRecentDef
68           
69           llvmAst <- moduleSoFar "main"
70           ctx <- lift $ asks jitEnvContext
71           env <- lift ask
72           liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
73             Text.hPutStrLn stderr $ ppll def
74             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
75             -- this returns true if the module was modified
76             withPassManager spec $ flip runPassManager mdl
77             when anon (jit env mdl >>= hPrint stderr)
78
79           when anon (removeDef def)
80       repl
81   where
82     eofHandler e
83       | isEOFError e = return Nothing
84       | otherwise = ioError e
85     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
86     isAnonExpr _ = False
87
88 jit :: JITEnv -> Module -> IO Double
89 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
90   withModule compLayer mdlKey mdl $ do
91     mangled <- mangleSymbol compLayer "__anon_expr"
92     Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
93     mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
94
95 type Binds = Map.Map String Operand
96
97 buildAST :: AST -> ModuleBuilder Operand
98 buildAST (Function (Prototype nameStr paramStrs) body) = do
99   let n = fromString nameStr
100   function n params Type.double $ \ops -> do
101     let binds = Map.fromList (zip paramStrs ops)
102     flip runReaderT binds $ buildExpr body >>= ret
103   where params = zip (repeat Type.double) (map fromString paramStrs)
104
105 buildAST (Extern (Prototype nameStr params)) =
106   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
107
108 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
109   const $ flip runReaderT mempty $ buildExpr x >>= ret
110
111 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
112 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
113 buildExpr (Var n) = do
114   binds <- ask
115   case binds Map.!? n of
116     Just x -> pure x
117     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
118
119 buildExpr (BinOp op a b) = do
120   opA <- buildExpr a
121   opB <- buildExpr b
122   tmp <- instr opA opB
123   if isCmp
124     then uitofp tmp Type.double
125     else return tmp
126   where isCmp
127           | Cmp _ <- op = True
128           | otherwise = False
129         instr = case op of
130                   K.Add -> fadd
131                   K.Sub -> fsub
132                   K.Mul -> fmul
133                   K.Cmp LT -> fcmp OLT
134                   K.Cmp GT -> fcmp OGT
135                   K.Cmp EQ -> fcmp OEQ
136
137 buildExpr (Call callee params) = do
138   paramOps <- mapM buildExpr params
139   let nam = fromString callee
140       -- get a pointer to the function
141       typ = FunctionType Type.double (replicate (length params) Type.double) False
142       ptrTyp = Type.PointerType typ (AddrSpace 0)
143       ref = GlobalReference ptrTyp nam
144   call (ConstantOperand ref) (zip paramOps (repeat []))
145
146 buildExpr (If cond thenE elseE) = mdo
147   _ifB <- block `named` "if"
148
149   -- since everything is a double, false == 0
150   let zero = ConstantOperand (Float (Double 0))
151   condV <- buildExpr cond
152   cmp <- fcmp ONE zero condV `named` "cmp"
153
154   condBr cmp thenB elseB
155
156   thenB <- block `named` "then"
157   thenOp <- buildExpr thenE
158   br mergeB
159
160   elseB <- block `named` "else"
161   elseOp <- buildExpr elseE
162   br mergeB
163
164   mergeB <- block `named` "ifcont"
165   phi [(thenOp, thenB), (elseOp, elseB)]