1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
4 import AST as K -- K for Kaleidoscope
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
25 import LLVM.OrcJIT.CompileLayer
26 import LLVM.PassManager
31 import System.IO.Error
32 import Text.Read (readMaybe)
-- | FFI \"dynamic\" import: wraps a pointer to a native function that takes no
-- arguments and returns a double, exposing it as a callable 'IO' action.
34 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
37 { jitEnvContext :: Context
38 , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
39 , jitEnvModuleKey :: ModuleKey
44 loadLibraryPermanently (Just "stdlib.dylib")
45 withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
46 withExecutionSession $ \exSession ->
47 withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
48 withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
49 withIRCompileLayer linkingLayer tm $ \compLayer ->
50 withModuleKey exSession $ \mdlKey -> do
51 let env = JITEnv ctx compLayer mdlKey
52 _ast <- runReaderT (buildModuleT "main" repl) env
55 -- This can eventually be used to resolve external functions, e.g. a stdlib call
56 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
58 ptr <- getSymbolAddressInProcess sym
59 putStrLn $ "Resolving " <> show sym <> " to 0x" <> showHex ptr ""
60 return (Right (JITSymbol ptr defaultJITSymbolFlags))
62 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
64 liftIO $ hPutStr stderr "ready> "
65 mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
70 Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse"
72 anon <- isAnonExpr <$> hoist (buildAST ast)
75 llvmAst <- moduleSoFar "main"
76 ctx <- lift $ asks jitEnvContext
78 liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
79 Text.hPutStrLn stderr $ ppll def
80 let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
81 -- this returns true if the module was modified
82 withPassManager spec $ flip runPassManager mdl
83 when anon (jit env mdl >>= hPrint stderr)
85 when anon (removeDef def)
89 | isEOFError e = return Nothing
90 | otherwise = ioError e
91 isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
-- | Run the @__anon_expr@ function of the given module through the JIT.
--
-- Inside 'withModule' (using the compile layer and module key taken from the
-- environment) the @__anon_expr@ symbol is mangled, looked up in that module
-- and, when found, invoked through 'mkFun' as a nullary function returning a
-- 'Double'.
--
-- Throws an 'IOError' (built with 'userError') describing the failure when
-- the symbol lookup does not succeed.
jit :: JITEnv -> Module -> IO Double
jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
  withModule compLayer mdlKey mdl $ do
    mangled <- mangleSymbol compLayer "__anon_expr"
    lookupResult <- findSymbolIn compLayer mdlKey mangled False
    case lookupResult of
      Right (JITSymbol fPtr _) ->
        -- The address is handed back as a word; turn it into a function
        -- pointer before calling it.
        mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
      Left err ->
        -- Report the resolver's error instead of dying on a failed pattern
        -- match in the bind.
        ioError . userError $
          "jit: could not resolve __anon_expr: " ++ show err
-- | Variable scope used while building expressions: maps a variable name to
-- the operand that holds its value.
101 type Binds = Map.Map String Operand
-- | Lower one top-level Kaleidoscope form into the module being built.
--
-- * 'Function': defines a function whose parameters and result are all
--   doubles; the body is built with every parameter name bound to the
--   corresponding argument operand, and its value is returned.
-- * 'Extern': declares an external function with one double per parameter
--   and a double result.
-- * 'TopLevelExpr': wraps the expression in a nullary function named
--   @__anon_expr@, built with an empty variable scope.
--
-- The result is whatever operand 'function' / 'extern' hands back.
buildAST :: AST -> ModuleBuilder Operand
buildAST (Function (Prototype nameStr paramStrs) body) =
  function (fromString nameStr) typedParams Type.double buildBody
  where
    -- every parameter is a double, named after the source parameter
    typedParams = [(Type.double, fromString p) | p <- paramStrs]
    buildBody ops =
      let scope = Map.fromList (zip paramStrs ops)
      in runReaderT (buildExpr body >>= ret) scope

buildAST (Extern (Prototype nameStr params)) =
  extern (fromString nameStr) (map (const Type.double) params) Type.double

buildAST (TopLevelExpr x) =
  function "__anon_expr" [] Type.double $ \_ ->
    runReaderT (buildExpr x >>= ret) Map.empty
-- | Emit IR for an expression and yield the operand holding its value.
--
-- The reader environment ('Binds') maps the variable names currently in
-- scope to their operands.
buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
-- A numeric literal is just a double constant operand.
buildExpr (Num x) = return . ConstantOperand . Float $ Double x
119 buildExpr (Var n) = do
121 case binds Map.!? n of
123 Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
125 buildExpr (BinOp op a b) = do
130 then uitofp tmp Type.double
-- A call builds its arguments in order, then calls the global named after
-- @callee@. The callee is referenced as a pointer (address space 0) to a
-- function taking one double per argument and returning a double; no
-- parameter attributes are attached to the arguments.
buildExpr (Call callee params) = do
  argOps <- traverse buildExpr params
  call (ConstantOperand calleeRef) [(arg, []) | arg <- argOps]
  where
    calleeTy  = FunctionType Type.double (map (const Type.double) params) False
    calleeRef =
      GlobalReference (Type.PointerType calleeTy (AddrSpace 0)) (fromString callee)
152 buildExpr (If cond thenE elseE) = mdo
153 _ifB <- block `named` "if"
155 -- since everything is a double, false == 0
156 let zero = ConstantOperand (Float (Double 0))
157 condV <- buildExpr cond
158 cmp <- fcmp ONE zero condV `named` "cmp"
160 condBr cmp thenB elseB
162 thenB <- block `named` "then"
163 thenOp <- buildExpr thenE
166 elseB <- block `named` "else"
167 elseOp <- buildExpr elseE
170 mergeB <- block `named` "ifcont"
171 phi [(thenOp, thenB), (elseOp, elseB)]
-- A for loop over a double-valued loop variable.
--
-- Control-flow shape emitted here:
--
-- >   <current block>  init, initcond              ; br preheader
-- >   preheader        condBr initcond loop after
-- >   loop             i = phi [init, preheader] [nextvar, latch]
-- >                    body, step, nextvar, cond   ; br latch
-- >   latch            condBr cond loop after
-- >   after
--
-- The sub-expressions may open basic blocks of their own (the 'If' and 'For'
-- cases call 'block'), so the block that is current when control is about to
-- enter @loop@ need not be the one that was current before they were built.
-- Both edges into @loop@ therefore go through a dedicated, freshly created
-- block whose name is known, and exactly those blocks are the predecessors
-- listed in the phi. The block that was current on entry is explicitly
-- branched into the loop as well.
--
-- The loop expression itself always evaluates to the constant 0.
buildExpr (For name init cond mStep body) = mdo
  let zero = ConstantOperand (Float (Double 0))
      one  = ConstantOperand (Float (Double 1))

  -- Evaluate the initial value and the entry guard in whatever block is
  -- current, with the loop variable bound to the initial value.
  initV <- buildExpr init `named` "init"
  initCondV <- withReaderT (Map.insert name initV) $
    (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
  br preheaderB

  -- skip the loop if we don't meet the condition with the init
  preheaderB <- block `named` "preheader"
  condBr initCondV loopB afterB

  loopB <- block `named` "loop"
  i <- phi [(initV, preheaderB), (nextVar, latchB)] `named` "i"

  -- build the body expression with 'i' in the bindings; its value is unused
  _ <- withReaderT (Map.insert name i) $ buildExpr body `named` "body"

  -- default to 1 if there's no step defined
  stepV <- maybe (pure one) buildExpr mStep
  nextVar <- fadd i stepV `named` "nextvar"

  -- The continuation test is built with the loop variable still bound to
  -- 'i' (the value used by this iteration), not to 'nextVar'.
  condV <- withReaderT (Map.insert name i) $
    (buildExpr cond >>= fcmp ONE zero) `named` "cond"
  br latchB

  -- single, known predecessor for the back edge
  latchB <- block `named` "latch"
  condBr condV loopB afterB

  afterB <- block `named` "after"
  -- since a for loop doesn't really have a value, return 0
  return zero