Power up the OrcJIT
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad
6 import Control.Monad.Trans.Class
7 import Control.Monad.Trans.Reader
8 import Control.Monad.IO.Class
9 import Data.String
10 import qualified Data.Map as Map
11 import qualified Data.Text.Lazy.IO as Text
12 import LLVM.AST.AddrSpace
13 import LLVM.AST.Constant
14 import LLVM.AST.Float
15 import LLVM.AST.FloatingPointPredicate hiding (False, True)
16 import LLVM.AST.Operand
17 import LLVM.AST.Type as Type
18 import LLVM.Context
19 import LLVM.IRBuilder
20 import LLVM.Module
21 import LLVM.OrcJIT
22 import LLVM.OrcJIT.CompileLayer
23 import LLVM.PassManager
24 import LLVM.Pretty
25 import LLVM.Target
26 import System.IO
27 import System.IO.Error
28 import Text.Read (readMaybe)
29
30 data JITEnv = JITEnv
31   { jitEnvContext :: Context
32   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
33   , jitEnvModuleKey :: ModuleKey
34   }
35
36 main :: IO ()
37 main = do
38   withContext $ \ctx -> withHostTargetMachine $ \tm -> do
39     withExecutionSession $ \exSession ->
40       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
41         withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
42           withIRCompileLayer linkingLayer tm $ \compLayer -> do
43             withModuleKey exSession $ \mdlKey -> do
44               let env = JITEnv ctx compLayer mdlKey
45               ast <- runReaderT (buildModuleT "main" repl) env
46               return ()
47
48 -- This can eventually be used to resolve external functions, e.g. a stdlib call
49 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
50 symResolver sym = undefined
51
52 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
53 repl = do
54   liftIO $ hPutStr stderr "ready> "
55   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
56   case mline of
57     Nothing -> return ()
58     Just l -> do
59       case readMaybe l of
60         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
61         Just ast -> do
62           anon <- isAnonExpr <$> hoist (buildAST ast)
63           def <- mostRecentDef
64           
65           ast <- moduleSoFar "main"
66           ctx <- lift $ asks jitEnvContext
67           env <- lift ask
68           liftIO $ withModuleFromAST ctx ast $ \mdl -> do
69             Text.hPutStrLn stderr $ ppll def
70             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
71             -- this returns true if the module was modified
72             withPassManager spec $ flip runPassManager mdl
73             when anon (jit env mdl >>= hPrint stderr)
74
75           when anon (removeDef def)
76       repl
77   where
78     eofHandler e
79       | isEOFError e = return Nothing
80       | otherwise = ioError e
81     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
82     isAnonExpr _ = False
83
84 jit :: JITEnv -> Module -> IO Double
85 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
86   withModule compLayer mdlKey mdl $ do
87     return 0
88
89 type Binds = Map.Map String Operand
90
91 buildAST :: AST -> ModuleBuilder Operand
92 buildAST (Function (Prototype nameStr paramStrs) body) = do
93   let n = fromString nameStr
94   function n params Type.double $ \ops -> do
95     let binds = Map.fromList (zip paramStrs ops)
96     flip runReaderT binds $ buildExpr body >>= ret
97   where params = zip (repeat Type.double) (map fromString paramStrs)
98
99 buildAST (Extern (Prototype nameStr params)) =
100   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
101
102 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
103   const $ flip runReaderT mempty $ buildExpr x >>= ret
104
105 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
106 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
107 buildExpr (Var n) = do
108   binds <- ask
109   case binds Map.!? n of
110     Just x -> pure x
111     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
112
113 buildExpr (BinOp op a b) = do
114   opA <- buildExpr a
115   opB <- buildExpr b
116   tmp <- instr opA opB
117   if isCmp
118     then uitofp tmp Type.double
119     else return tmp
120   where isCmp
121           | Cmp _ <- op = True
122           | otherwise = False
123         instr = case op of
124                   K.Add -> fadd
125                   K.Sub -> fsub
126                   K.Mul -> fmul
127                   K.Cmp LT -> fcmp OLT
128                   K.Cmp GT -> fcmp OGT
129                   K.Cmp EQ -> fcmp OEQ
130
131 buildExpr (Call callee params) = do
132   paramOps <- mapM buildExpr params
133   let nam = fromString callee
134       -- get a pointer to the function
135       typ = FunctionType Type.double (replicate (length params) Type.double) False
136       ptrTyp = Type.PointerType typ (AddrSpace 0)
137       ref = GlobalReference ptrTyp nam
138   call (ConstantOperand ref) (zip paramOps (repeat []))