1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
4 import AST as K -- K for Kaleidoscope
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
24 import LLVM.OrcJIT.CompileLayer
25 import LLVM.PassManager
29 import System.IO.Error
30 import Text.Read (readMaybe)
-- | Wrap a raw code pointer to a nullary native function returning a double
-- as a callable Haskell action (a ccall "dynamic" import). 'jit' uses it to
-- invoke the JIT-compiled @__anon_expr@ entry point.
foreign import ccall "dynamic"
  mkFun :: FunPtr (IO Double) -> IO Double
-- JIT state shared with the REPL. It is built once during setup
-- (`JITEnv ctx compLayer mdlKey`) and read back through 'ReaderT'.
35 { jitEnvContext :: Context
-- ^ LLVM context in which the accumulated pure AST is turned into a native module.
36 , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
-- ^ IR compile layer (stacked on the object-linking layer) used to add, compile
-- and look up JIT'd code.
37 , jitEnvModuleKey :: ModuleKey
-- ^ Key under which modules are added to, and symbols looked up in, the compile layer.
-- Build the JIT stack as nested continuations, in this order:
--   * an LLVM context and the default host target machine;
--   * an ORC execution session;
--   * a symbol resolver wrapping 'symResolver';
--   * an object-linking layer that uses that resolver;
--   * an IR compile layer on top of it;
--   * a single module key.
-- Each with* helper presumably releases its resource when its continuation
-- returns (confirm against the llvm-hs docs).
42 withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
43 withExecutionSession $ \exSession ->
44 withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
-- The resolver lookup ignores its argument (presumably the module key) and
-- always hands back the same resolver.
45 withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
46 withIRCompileLayer linkingLayer tm $ \compLayer ->
47 withModuleKey exSession $ \mdlKey -> do
48 let env = JITEnv ctx compLayer mdlKey
-- Run the REPL inside a module builder for the module named "main", with the
-- JIT environment available through ReaderT. The finished module is bound to
-- _ast, which is not used in the lines visible here.
49 _ast <- runReaderT (buildModuleT "main" repl) env
-- | Symbol resolver handed to the object-linking layer. It is presumably
-- consulted for symbols that the JIT-compiled code references but does not
-- define itself, e.g. functions declared with @extern@.
--
-- External lookup (e.g. resolving a stdlib call) is not implemented yet.
-- Every request is therefore reported back as a resolution failure ('Left').
-- This is the contract the 'Either' result type promises. The previous
-- 'undefined' threw a message-less Haskell exception from inside an LLVM
-- callback.
symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
symResolver _ =
  pure (Left (JITSymbolError "symResolver: external symbol resolution is not implemented"))
-- REPL step; any looping/recursion lives in lines not shown here. It does
-- the following:
--   * prompt on stderr and read a line from stdin;
--   * parse it;
--   * lower the parsed item into the module being built;
--   * for a top-level expression, optimize the accumulated module, run it
--     in the JIT and print the result to stderr.
56 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
58 liftIO $ hPutStr stderr "ready> "
-- EOF on stdin becomes Nothing (via eofHandler) instead of an exception.
59 mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
64 Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
-- buildAST runs in ModuleBuilder; 'hoist' adapts it to this monad stack.
-- anon is True only when the generated definition is __anon_expr, the
-- function emitted for a top-level expression.
66 anon <- isAnonExpr <$> hoist (buildAST ast)
-- Snapshot of every definition emitted into the "main" module so far.
69 llvmAst <- moduleSoFar "main"
70 ctx <- lift $ asks jitEnvContext
-- Materialize the pure AST as an LLVM module in ctx. Then print IR, optimize,
-- and, for top-level expressions, run it.
72 liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
-- Pretty-print `def` (bound in lines not visible here) as LLVM IR to stderr.
73 Text.hPutStrLn stderr $ ppll def
74 let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
75 -- this returns true if the module was modified
-- NOTE(review): that Bool is discarded here. Wrapping the call in `void`
-- would make this explicit and silence -Wunused-do-bind.
76 withPassManager spec $ flip runPassManager mdl
-- Only top-level expressions are executed; the resulting Double is printed to stderr.
77 when anon (jit env mdl >>= hPrint stderr)
-- NOTE(review): the anonymous definition is removed after it has run,
-- presumably so that the next top-level expression can reuse the
-- __anon_expr name. Confirm.
79 when anon (removeDef def)
-- eofHandler (header not visible here): EOF yields Nothing; any other IO
-- error is rethrown.
83 | isEOFError e = return Nothing
84 | otherwise = ioError e
-- Recognizes a reference to the global named __anon_expr.
85 isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
-- | Add @mdl@ to the environment's compile layer under its module key, look
-- up the @__anon_expr@ function, call it, and return the 'Double' it
-- produces. The call happens inside 'withModule', so the compiled code is
-- only used while the module is added.
--
-- Throws a user 'IOError' in two cases, with a message naming the cause:
--
-- * the symbol lookup reports an error (previously an opaque do-pattern
--   match failure);
-- * the lookup yields a null address (previously a call through a null
--   function pointer).
jit :: JITEnv -> Module -> IO Double
jit JITEnv{jitEnvCompileLayer = compLayer, jitEnvModuleKey = mdlKey} mdl =
  withModule compLayer mdlKey mdl $ do
    mangled <- mangleSymbol compLayer "__anon_expr"
    lookupResult <- findSymbolIn compLayer mdlKey mangled False
    case lookupResult of
      Left err ->
        ioError (userError ("jit: failed to resolve __anon_expr: " <> show err))
      Right (JITSymbol addr _)
        | addr == 0 ->
            ioError (userError "jit: __anon_expr resolved to a null address")
        | otherwise ->
            mkFun (castPtrToFunPtr (wordPtrToPtr addr))
-- | Lexical environment used while lowering expressions. It maps each
-- variable name in scope (currently only function parameters) to the LLVM
-- operand carrying its value.
type Binds = Map.Map String Operand
-- | Emit one top-level Kaleidoscope item into the module being built.
-- Returns the operand that refers to the emitted global.
--
-- * A function definition becomes an LLVM function whose parameters and
--   result are all doubles. Its body is lowered with the parameter names
--   bound to the function's argument operands.
-- * An @extern@ becomes a declaration taking one double per parameter and
--   returning a double.
-- * A bare top-level expression is wrapped in a nullary function named
--   @__anon_expr@ and lowered with an empty environment.
buildAST :: AST -> ModuleBuilder Operand
buildAST ast = case ast of
  Function (Prototype fnName argNames) body ->
    let argSpecs = [(Type.double, fromString argName) | argName <- argNames]
        emitBody argOps =
          runReaderT (buildExpr body >>= ret) (Map.fromList (zip argNames argOps))
     in function (fromString fnName) argSpecs Type.double emitBody
  Extern (Prototype fnName argNames) ->
    extern (fromString fnName) (map (const Type.double) argNames) Type.double
  TopLevelExpr expr ->
    function "__anon_expr" [] Type.double $ \_ ->
      runReaderT (buildExpr expr >>= ret) Map.empty
-- Lower a Kaleidoscope expression to LLVM IR at the current insertion point
-- and return the operand holding its value; every value is a double.
-- Variable names are resolved through the Binds environment read from ReaderT.
111 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
-- Numeric literals become double constants; no instruction is emitted.
112 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
113 buildExpr (Var n) = do
115 case binds Map.!? n of
-- An unbound name aborts via 'error' rather than a recoverable failure.
117 Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
119 buildExpr (BinOp op a b) = do
-- NOTE(review): this presumably widens an i1 comparison result to 0.0/1.0.
-- Confirm against the lines not shown.
124 then uitofp tmp Type.double
137 buildExpr (Call callee params) = do
138 paramOps <- mapM buildExpr params
139 let nam = fromString callee
140 -- get a pointer to the function
-- The callee's type is rebuilt from the call site: `length params` double
-- arguments, a double result, not variadic.
-- NOTE(review): this arity is not checked against the callee's actual declaration.
141 typ = FunctionType Type.double (replicate (length params) Type.double) False
142 ptrTyp = Type.PointerType typ (AddrSpace 0)
143 ref = GlobalReference ptrTyp nam
-- Each argument is passed with an empty attribute list.
144 call (ConstantOperand ref) (zip paramOps (repeat []))
-- mdo (RecursiveDo) lets condBr below refer to thenB/elseB before those
-- blocks are created.
146 buildExpr (If cond thenE elseE) = mdo
147 _ifB <- block `named` "if"
149 -- since everything is a double, false == 0
150 let zero = ConstantOperand (Float (Double 0))
151 condV <- buildExpr cond
-- NOTE(review): ONE is LLVM's *ordered* not-equal predicate, so a NaN
-- condition should take the else branch. Confirm this is intended.
152 cmp <- fcmp ONE zero condV `named` "cmp"
154 condBr cmp thenB elseB
156 thenB <- block `named` "then"
157 thenOp <- buildExpr thenE
160 elseB <- block `named` "else"
161 elseOp <- buildExpr elseE
164 mergeB <- block `named` "ifcont"
-- NOTE(review): the phi's incoming blocks are thenB/elseB, the blocks opened
-- at the start of each branch. If buildExpr thenE/elseE opens further blocks
-- (e.g. a nested if), ifcont is reached from that branch's last block
-- instead. In that case the phi should name the block current at the end of
-- each branch (e.g. captured with currentBlock). Verify against the lines
-- not shown.
165 phi [(thenOp, thenB), (elseOp, elseB)]