1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
4 import AST as K -- K for Kaleidoscope
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
25 import LLVM.OrcJIT.CompileLayer
26 import LLVM.PassManager
30 import System.IO.Error
31 import Text.Read (readMaybe)
-- | Wrap a raw function pointer as a callable Haskell action. The pointer
-- must address a function that takes no arguments and returns a C double;
-- 'jit' uses this to run the JIT-compiled @__anon_expr@.
foreign import ccall "dynamic"
  mkFun :: FunPtr (IO Double) -> IO Double
36 { jitEnvContext :: Context
37 , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
38 , jitEnvModuleKey :: ModuleKey
43 loadLibraryPermanently (Just "stdlib.dylib")
44 withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
45 withExecutionSession $ \exSession ->
46 withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
47 withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
48 withIRCompileLayer linkingLayer tm $ \compLayer ->
49 withModuleKey exSession $ \mdlKey -> do
50 let env = JITEnv ctx compLayer mdlKey
51 _ast <- runReaderT (buildModuleT "main" repl) env
-- | Resolve symbols that JIT-compiled code references but does not define
-- itself, such as functions declared with @extern@ (e.g. a stdlib call).
--
-- The lookup searches the running process; per LLVM's semantics this also
-- covers libraries loaded with 'loadLibraryPermanently'. An unknown symbol
-- yields a 'JITSymbolError' instead of crashing the REPL.
symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
symResolver sym = do
  addr <- getSymbolAddressInProcess sym
  -- LLVM reports a symbol it cannot find as address 0
  pure $
    if addr == 0
      then Left (JITSymbolError "symResolver: symbol not found in process")
      else Right (JITSymbol addr (defaultJITSymbolFlags {jitSymbolExported = True}))
58 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
60 liftIO $ hPutStr stderr "ready> "
61 mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
66 Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
68 anon <- isAnonExpr <$> hoist (buildAST ast)
71 llvmAst <- moduleSoFar "main"
72 ctx <- lift $ asks jitEnvContext
74 liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
75 Text.hPutStrLn stderr $ ppll def
76 let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
77 -- this returns true if the module was modified
78 withPassManager spec $ flip runPassManager mdl
79 when anon (jit env mdl >>= hPrint stderr)
81 when anon (removeDef def)
85 | isEOFError e = return Nothing
86 | otherwise = ioError e
87 isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
-- | JIT-compile @mdl@ (which must define @__anon_expr@) under the
-- environment's module key, then call @__anon_expr@ and return its result.
--
-- If the symbol lookup reports an error, or resolves to a null address, this
-- fails in 'IO' with a descriptive message. It does not die on an opaque
-- pattern-match failure, and it does not jump through a null function pointer.
jit :: JITEnv -> Module -> IO Double
jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
  withModule compLayer mdlKey mdl $ do
    mangled <- mangleSymbol compLayer "__anon_expr"
    found <- findSymbolIn compLayer mdlKey mangled False
    case found of
      Left err ->
        fail ("jit: couldn't resolve __anon_expr: " <> show err)
      Right (JITSymbol fPtr _)
        | fPtr == 0 -> fail "jit: __anon_expr resolved to a null address"
        | otherwise -> mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
-- | Names in scope while lowering an expression: each Kaleidoscope variable
-- (a function parameter or a for-loop variable) maps to the LLVM operand that
-- currently holds its value.
97 type Binds = Map.Map String Operand
-- | Lower one top-level Kaleidoscope item into the module being built and
-- return a reference to the emitted LLVM function or declaration.
--
-- Every Kaleidoscope value is a double, so all parameter and return types
-- are 'Type.double'. Top-level expressions become a nullary function named
-- @__anon_expr@.
buildAST :: AST -> ModuleBuilder Operand
buildAST (Function (Prototype nameStr paramStrs) body) =
  function (fromString nameStr) paramDecls Type.double emitBody
  where
    paramDecls = [(Type.double, fromString p) | p <- paramStrs]
    -- make each source-level parameter name refer to its LLVM argument
    emitBody args =
      runReaderT (buildExpr body >>= ret) (Map.fromList (zip paramStrs args))
buildAST (Extern (Prototype nameStr argNames)) =
  extern (fromString nameStr) (Type.double <$ argNames) Type.double
buildAST (TopLevelExpr expr) =
  function "__anon_expr" [] Type.double $ \_ ->
    runReaderT (buildExpr expr >>= ret) mempty
113 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
-- Numeric literals lower to constant doubles; no instruction is emitted.
buildExpr (Num x) = return . ConstantOperand . Float $ Double x
115 buildExpr (Var n) = do
117 case binds Map.!? n of
119 Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
121 buildExpr (BinOp op a b) = do
126 then uitofp tmp Type.double
-- Calls: every Kaleidoscope function takes doubles and returns a double, so
-- the callee's LLVM type is rebuilt from the argument count alone. The call
-- goes through a global reference to the function by name.
-- NOTE(review): the argument count is not checked against the callee's
-- declaration here.
buildExpr (Call callee params) = do
  args <- traverse buildExpr params
  let calleeTy = FunctionType Type.double (Type.double <$ params) False
      calleePtrTy = Type.PointerType calleeTy (AddrSpace 0)
      calleeRef = ConstantOperand (GlobalReference calleePtrTy (fromString callee))
  call calleeRef [(arg, []) | arg <- args]
148 buildExpr (If cond thenE elseE) = mdo
149 _ifB <- block `named` "if"
151 -- since everything is a double, false == 0
152 let zero = ConstantOperand (Float (Double 0))
153 condV <- buildExpr cond
154 cmp <- fcmp ONE zero condV `named` "cmp"
156 condBr cmp thenB elseB
158 thenB <- block `named` "then"
159 thenOp <- buildExpr thenE
162 elseB <- block `named` "else"
163 elseOp <- buildExpr elseE
166 mergeB <- block `named` "ifcont"
167 phi [(thenOp, thenB), (elseOp, elseB)]
-- For loops: @for name = initE, cond, step in body@.
--
-- Control flow: the incoming block branches to "preheader". The preheader
-- evaluates the start value and tests the condition once, skipping the loop
-- entirely if it fails. "loop" carries the loop variable as a phi, runs the
-- body, steps the variable and re-tests. "after" is where execution
-- continues.
--
-- NOTE: the re-test evaluates @cond@ with @name@ still bound to the
-- pre-increment value, so the body runs one more time than the condition
-- alone suggests. This is kept unchanged because it is observable language
-- behaviour.
--
-- The loop expression itself always evaluates to 0.
buildExpr (For name initE cond mStep body) = mdo
  -- Terminate the block we are coming from with an explicit branch. Opening
  -- a new block without one would leave the preheader with no incoming edge
  -- whenever instructions (or an enclosing loop) precede this one.
  br preheaderB
  preheaderB <- block `named` "preheader"

  initV <- buildExpr initE `named` "init"

  -- build the condition expression with 'i' in the bindings
  initCondV <- withReaderT (Map.insert name initV) $
    (buildExpr cond >>= fcmp ONE zero) `named` "initcond"

  -- initE/cond may have opened blocks of their own (e.g. a nested if), so
  -- the phi must name the block that actually branches into the loop
  entryEndB <- currentBlock
  -- skip the loop if we don't meet the condition with the init
  condBr initCondV loopB afterB

  loopB <- block `named` "loop"
  i <- phi [(initV, entryEndB), (nextVar, loopEndB)] `named` "i"

  -- build the body expression with 'i' in the bindings; its value is unused
  _ <- withReaderT (Map.insert name i) $ buildExpr body `named` "body"

  -- default to 1 if there's no step defined
  stepV <- case mStep of
    Just step -> buildExpr step
    Nothing -> return $ ConstantOperand (Float (Double 1))

  nextVar <- fadd i stepV `named` "nextvar"

  -- again we need 'i' in the bindings
  condV <- withReaderT (Map.insert name i) $
    (buildExpr cond >>= fcmp ONE zero) `named` "cond"

  -- body/step/cond may also have opened blocks, so the back edge comes from
  -- the current block, which is not necessarily "loop"
  loopEndB <- currentBlock
  condBr condV loopB afterB

  afterB <- block `named` "after"
  -- since a for loop doesn't really have a value, return 0
  return zero
  where
    zero = ConstantOperand (Float (Double 0))