1 {-# LANGUAGE OverloadedStrings #-}
3 import AST as K -- K for Kaleidoscope
6 import Control.Monad.Trans.Class
7 import Control.Monad.Trans.Reader
8 import Control.Monad.IO.Class
10 import qualified Data.Map as Map
11 import qualified Data.Text.Lazy.IO as Text
12 import LLVM.AST.AddrSpace
13 import LLVM.AST.Constant
15 import LLVM.AST.FloatingPointPredicate hiding (False, True)
16 import LLVM.AST.Operand
17 import LLVM.AST.Type as Type
22 import LLVM.OrcJIT.CompileLayer
23 import LLVM.PassManager
27 import System.IO.Error
28 import Text.Read (readMaybe)
31 { jitEnvContext :: Context
32 , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
33 , jitEnvModuleKey :: ModuleKey
38 withContext $ \ctx -> withHostTargetMachine $ \tm -> do
39 withExecutionSession $ \exSession ->
40 withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
41 withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
42 withIRCompileLayer linkingLayer tm $ \compLayer -> do
43 withModuleKey exSession $ \mdlKey -> do
44 let env = JITEnv ctx compLayer mdlKey
45 ast <- runReaderT (buildModuleT "main" repl) env
-- | Symbol resolver handed to the JIT's object-linking layer (see the
-- 'SymbolResolver' wrapping in @main@).
--
-- Resolving external functions (e.g. a stdlib call from an @extern@
-- declaration) is intended to live here eventually, but no strategy is
-- implemented yet. Rather than crashing with an uninformative
-- 'undefined', every lookup is reported as a failure through the 'Left'
-- channel that the signature already provides. How that error surfaces
-- (e.g. via @findSymbol@ in 'jit') is up to the JIT; confirm the
-- downstream handling when implementing real resolution.
symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
symResolver _ =
  pure . Left $
    JITSymbolError "symResolver: resolving external symbols is not implemented"
52 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
54 liftIO $ hPutStr stderr "ready> "
55 mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
60 Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
62 anon <- isAnonExpr <$> hoist (buildAST ast)
65 ast <- moduleSoFar "main"
66 ctx <- lift $ asks jitEnvContext
68 liftIO $ withModuleFromAST ctx ast $ \mdl -> do
69 Text.hPutStrLn stderr $ ppll def
70 let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
71 -- this returns true if the module was modified
72 withPassManager spec $ flip runPassManager mdl
73 when anon (jit env mdl >>= hPrint stderr)
75 when anon (removeDef def)
79 | isEOFError e = return Nothing
80 | otherwise = ioError e
81 isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
84 jit :: JITEnv -> Module -> IO Double
85 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
86 withModule compLayer mdlKey mdl $ do
-- | Variables visible while compiling an expression: each source-level
-- name is mapped to the LLVM operand that carries its value.
type Binds = Map.Map String Operand

-- | Lower one top-level Kaleidoscope item into the module under
-- construction and return an operand referring to the emitted global.
--
-- * 'Function': a definition returning @double@ whose parameters are all
--   @double@; the body is compiled with each parameter name bound to the
--   corresponding incoming argument operand.
-- * 'Extern': a declaration taking one @double@ per prototype parameter
--   and returning @double@.
-- * 'TopLevelExpr': a parameterless function named @__anon_expr@, the
--   name the REPL checks for when deciding whether to run the result.
buildAST :: AST -> ModuleBuilder Operand
buildAST (Function (Prototype nameStr paramStrs) body) =
  function (fromString nameStr) paramSpecs Type.double emitBody
  where
    -- Every parameter is a double, named after its source identifier.
    paramSpecs = [(Type.double, fromString p) | p <- paramStrs]
    -- Bind the parameter names to the argument operands, then compile
    -- the body and return its value from the function.
    emitBody argOps =
      runReaderT (buildExpr body >>= ret) (Map.fromList (zip paramStrs argOps))
buildAST (Extern (Prototype nameStr params)) =
  extern (fromString nameStr) (Type.double <$ params) Type.double
buildAST (TopLevelExpr x) =
  function "__anon_expr" [] Type.double $ \_ ->
    runReaderT (buildExpr x >>= ret) Map.empty
105 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
106 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
107 buildExpr (Var n) = do
109 case binds Map.!? n of
111 Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
113 buildExpr (BinOp op a b) = do
118 then uitofp tmp Type.double
131 buildExpr (Call callee params) = do
132 paramOps <- mapM buildExpr params
133 let nam = fromString callee
134 -- get a pointer to the function
135 typ = FunctionType Type.double (replicate (length params) Type.double) False
136 ptrTyp = Type.PointerType typ (AddrSpace 0)
137 ref = GlobalReference ptrTyp nam
138 call (ConstantOperand ref) (zip paramOps (repeat []))