1 {-# LANGUAGE OverloadedStrings #-}
3 import AST as K -- K for Kaleidoscope
6 import Control.Monad.Trans.Class
7 import Control.Monad.Trans.Reader
8 import Control.Monad.IO.Class
10 import qualified Data.Map as Map
11 import qualified Data.Text.Lazy.IO as Text
13 import LLVM.AST.AddrSpace
14 import LLVM.AST.Constant
16 import LLVM.AST.FloatingPointPredicate hiding (False, True)
17 import LLVM.AST.Operand
18 import LLVM.AST.Type as Type
23 import LLVM.OrcJIT.CompileLayer
24 import LLVM.PassManager
28 import System.IO.Error
29 import Text.Read (readMaybe)
-- | Wrap a raw pointer to compiled code with C signature @double (void)@ as
-- a callable Haskell action. 'jit' uses it to invoke the JIT-compiled
-- @__anon_expr@ function.
foreign import ccall "dynamic"
  mkFun :: FunPtr (IO Double) -> IO Double
34 { jitEnvContext :: Context
35 , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
36 , jitEnvModuleKey :: ModuleKey
41 withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
42 withExecutionSession $ \exSession ->
43 withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
44 withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
45 withIRCompileLayer linkingLayer tm $ \compLayer ->
46 withModuleKey exSession $ \mdlKey -> do
47 let env = JITEnv ctx compLayer mdlKey
48 _ast <- runReaderT (buildModuleT "main" repl) env
-- | Resolve symbols that a JIT-compiled module references but does not
-- define itself, such as functions declared with @extern@.
--
-- No external symbols are resolvable yet, so every request is answered with
-- a 'JITSymbolError' naming the symbol. This can eventually be used to
-- resolve external functions, e.g. a stdlib call.
--
-- Why not @undefined@: this function is handed to 'withSymbolResolver' and
-- (presumably, via a foreign callback) invoked from inside LLVM. A Haskell
-- exception escaping such a callback aborts the whole process instead of
-- surfacing as an ordinary error, so failure is reported as a value.
symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
symResolver sym =
  pure . Left . JITSymbolError . fromString $
    "symResolver: cannot resolve external symbol " ++ show sym
55 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
57 liftIO $ hPutStr stderr "ready> "
58 mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
63 Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
65 anon <- isAnonExpr <$> hoist (buildAST ast)
68 llvmAst <- moduleSoFar "main"
69 ctx <- lift $ asks jitEnvContext
71 liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
72 Text.hPutStrLn stderr $ ppll def
73 let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
74 -- this returns true if the module was modified
75 withPassManager spec $ flip runPassManager mdl
76 when anon (jit env mdl >>= hPrint stderr)
78 when anon (removeDef def)
82 | isEOFError e = return Nothing
83 | otherwise = ioError e
84 isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
-- | Add @mdl@ to the JIT's compile layer under the environment's module key,
-- look up its @__anon_expr@ function, run it, and return the result.
--
-- If the symbol cannot be found, this fails in 'IO' (via 'fail', i.e. an
-- 'IOError') with a message that includes LLVM's 'JITSymbolError'. The
-- previous version failed with an uninformative pattern-match error.
jit :: JITEnv -> Module -> IO Double
jit JITEnv{jitEnvCompileLayer = compLayer, jitEnvModuleKey = mdlKey} mdl =
  withModule compLayer mdlKey mdl $ do
    mangled <- mangleSymbol compLayer "__anon_expr"
    -- 'False': do not restrict the search to exported symbols only.
    lookupResult <- findSymbolIn compLayer mdlKey mangled False
    case lookupResult of
      Right (JITSymbol addr _) ->
        -- The address points at compiled code of type double(void); cast it
        -- to a FunPtr and call it through the "dynamic" FFI stub.
        mkFun (castPtrToFunPtr (wordPtrToPtr addr))
      Left err ->
        fail ("jit: could not find __anon_expr in the JIT module: " ++ show err)
-- | Names visible while generating code for an expression, mapped to the
-- LLVM operand holding each value. 'buildAST' fills it from a function's
-- parameters, and a top-level expression starts with an empty map.
type Binds = Map.Map String Operand
-- | Emit the LLVM IR for one top-level Kaleidoscope item into the module
-- under construction, returning the operand that refers to it.
--
-- Every value in the language is a double, so each parameter and each
-- result is typed 'Type.double':
--
--   * a 'Function' becomes a definition whose parameters are bound by name
--     in 'Binds' while its body is generated;
--   * an 'Extern' becomes a declaration with one double per parameter;
--   * a 'TopLevelExpr' becomes a zero-argument function named
--     @__anon_expr@, generated with no variables in scope.
buildAST :: AST -> ModuleBuilder Operand
buildAST (Function (Prototype nameStr paramStrs) body) =
  function (fromString nameStr) typedParams Type.double $ \argOps ->
    runReaderT (buildExpr body >>= ret) (Map.fromList (zip paramStrs argOps))
  where
    typedParams = [(Type.double, fromString p) | p <- paramStrs]
buildAST (Extern (Prototype nameStr params)) =
  extern (fromString nameStr) (map (const Type.double) params) Type.double
buildAST (TopLevelExpr x) =
  function "__anon_expr" [] Type.double $ \_ ->
    runReaderT (buildExpr x >>= ret) mempty
110 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
111 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
112 buildExpr (Var n) = do
114 case binds Map.!? n of
116 Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
118 buildExpr (BinOp op a b) = do
123 then uitofp tmp Type.double
136 buildExpr (Call callee params) = do
137 paramOps <- mapM buildExpr params
138 let nam = fromString callee
139 -- get a pointer to the function
140 typ = FunctionType Type.double (replicate (length params) Type.double) False
141 ptrTyp = Type.PointerType typ (AddrSpace 0)
142 ref = GlobalReference ptrTyp nam
143 call (ConstantOperand ref) (zip paramOps (repeat []))