Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad
6 import Control.Monad.Trans.Class
7 import Control.Monad.Trans.Reader
8 import Control.Monad.IO.Class
9 import Data.String
10 import qualified Data.Map as Map
11 import qualified Data.Text.Lazy.IO as Text
12 import Foreign.Ptr
13 import LLVM.AST.AddrSpace
14 import LLVM.AST.Constant
15 import LLVM.AST.Float
16 import LLVM.AST.FloatingPointPredicate hiding (False, True)
17 import LLVM.AST.Operand
18 import LLVM.AST.Type as Type
19 import LLVM.Context
20 import LLVM.IRBuilder
21 import LLVM.Module
22 import LLVM.OrcJIT
23 import LLVM.OrcJIT.CompileLayer
24 import LLVM.PassManager
25 import LLVM.Pretty
26 import LLVM.Target
27 import System.IO
28 import System.IO.Error
29 import Text.Read (readMaybe)
30
31 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
32
33 data JITEnv = JITEnv
34   { jitEnvContext :: Context
35   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
36   , jitEnvModuleKey :: ModuleKey
37   }
38
39 main :: IO ()
40 main = do
41   withContext $ \ctx -> withHostTargetMachine $ \tm -> do
42     withExecutionSession $ \exSession ->
43       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
44         withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
45           withIRCompileLayer linkingLayer tm $ \compLayer -> do
46             withModuleKey exSession $ \mdlKey -> do
47               let env = JITEnv ctx compLayer mdlKey
48               ast <- runReaderT (buildModuleT "main" repl) env
49               return ()
50
51 -- This can eventually be used to resolve external functions, e.g. a stdlib call
52 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
53 symResolver sym = undefined
54
55 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
56 repl = do
57   liftIO $ hPutStr stderr "ready> "
58   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
59   case mline of
60     Nothing -> return ()
61     Just l -> do
62       case readMaybe l of
63         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
64         Just ast -> do
65           anon <- isAnonExpr <$> hoist (buildAST ast)
66           def <- mostRecentDef
67           
68           ast <- moduleSoFar "main"
69           ctx <- lift $ asks jitEnvContext
70           env <- lift ask
71           liftIO $ withModuleFromAST ctx ast $ \mdl -> do
72             Text.hPutStrLn stderr $ ppll def
73             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
74             -- this returns true if the module was modified
75             withPassManager spec $ flip runPassManager mdl
76             when anon (jit env mdl >>= hPrint stderr)
77
78           when anon (removeDef def)
79       repl
80   where
81     eofHandler e
82       | isEOFError e = return Nothing
83       | otherwise = ioError e
84     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
85     isAnonExpr _ = False
86
87 jit :: JITEnv -> Module -> IO Double
88 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
89   withModule compLayer mdlKey mdl $ do
90     mangled <- mangleSymbol compLayer "__anon_expr"
91     Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
92     mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
93
94 type Binds = Map.Map String Operand
95
96 buildAST :: AST -> ModuleBuilder Operand
97 buildAST (Function (Prototype nameStr paramStrs) body) = do
98   let n = fromString nameStr
99   function n params Type.double $ \ops -> do
100     let binds = Map.fromList (zip paramStrs ops)
101     flip runReaderT binds $ buildExpr body >>= ret
102   where params = zip (repeat Type.double) (map fromString paramStrs)
103
104 buildAST (Extern (Prototype nameStr params)) =
105   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
106
107 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
108   const $ flip runReaderT mempty $ buildExpr x >>= ret
109
110 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
111 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
112 buildExpr (Var n) = do
113   binds <- ask
114   case binds Map.!? n of
115     Just x -> pure x
116     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
117
118 buildExpr (BinOp op a b) = do
119   opA <- buildExpr a
120   opB <- buildExpr b
121   tmp <- instr opA opB
122   if isCmp
123     then uitofp tmp Type.double
124     else return tmp
125   where isCmp
126           | Cmp _ <- op = True
127           | otherwise = False
128         instr = case op of
129                   K.Add -> fadd
130                   K.Sub -> fsub
131                   K.Mul -> fmul
132                   K.Cmp LT -> fcmp OLT
133                   K.Cmp GT -> fcmp OGT
134                   K.Cmp EQ -> fcmp OEQ
135
136 buildExpr (Call callee params) = do
137   paramOps <- mapM buildExpr params
138   let nam = fromString callee
139       -- get a pointer to the function
140       typ = FunctionType Type.double (replicate (length params) Type.double) False
141       ptrTyp = Type.PointerType typ (AddrSpace 0)
142       ref = GlobalReference ptrTyp nam
143   call (ConstantOperand ref) (zip paramOps (repeat []))