From: Luke Lau Date: Mon, 3 Jun 2019 14:48:04 +0000 (+0100) Subject: Add JIT X-Git-Url: https://git.lukelau.me/?a=commitdiff_plain;h=refs%2Fheads%2Ftutorial-1;hp=de8c7223c79f10c69f9916db1f15b34d20938e2c;p=kaleidoscope-hs.git Add JIT We have LLVM IR now, but our computers still can't run it. We could compile our code "offline" and write it to a file, but LLVM also provides frameworks for JITing: Just-in-time compilation. This is where the code is compiled just before it is run, and we will be using it make an interactive REPL. Note that JITs are not the same as interpreters: An interpreter reads the program and directly computes the result. A JIT reads the program and generates more code for the computer to run, which then computes the result. The current LLVM Kaleidoscope tutorial uses the old MC JIT framework: This tutorial will be using the fancy new OrcJIT framework. It's a bit more complicated but provides a lot more flexibility. --- diff --git a/Main.hs b/Main.hs index 2a5a7e0..48d93a2 100644 --- a/Main.hs +++ b/Main.hs @@ -2,27 +2,57 @@ import AST as K -- K for Kaleidoscope import Utils +import Control.Monad +import Control.Monad.Trans.Class import Control.Monad.Trans.Reader import Control.Monad.IO.Class import Data.String import qualified Data.Map as Map import qualified Data.Text.Lazy.IO as Text +import Foreign.Ptr import LLVM.AST.AddrSpace import LLVM.AST.Constant import LLVM.AST.Float import LLVM.AST.FloatingPointPredicate hiding (False, True) import LLVM.AST.Operand import LLVM.AST.Type as Type +import LLVM.Context import LLVM.IRBuilder +import LLVM.Module +import LLVM.OrcJIT +import LLVM.OrcJIT.CompileLayer +import LLVM.PassManager import LLVM.Pretty +import LLVM.Target import System.IO import System.IO.Error import Text.Read (readMaybe) +foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double + +data JITEnv = JITEnv + { jitEnvContext :: Context + , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer + , jitEnvModuleKey :: ModuleKey + } + main :: IO () -main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll +main = do + withContext $ \ctx -> withHostTargetMachine $ \tm -> do + withExecutionSession $ \exSession -> + withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr -> + withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer -> + withIRCompileLayer linkingLayer tm $ \compLayer -> do + withModuleKey exSession $ \mdlKey -> do + let env = JITEnv ctx compLayer mdlKey + ast <- runReaderT (buildModuleT "main" repl) env + return () + +-- This can eventually be used to resolve external functions, e.g. a stdlib call +symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol) +symResolver sym = undefined -repl :: ModuleBuilderT IO () +repl :: ModuleBuilderT (ReaderT JITEnv IO) () repl = do liftIO $ hPutStr stderr "ready> " mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler @@ -32,13 +62,34 @@ repl = do case readMaybe l of Nothing -> liftIO $ hPutStrLn stderr "Couldn't parse" Just ast -> do - hoist $ buildAST ast - mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll + anon <- isAnonExpr <$> hoist (buildAST ast) + def <- mostRecentDef + + ast <- moduleSoFar "main" + ctx <- lift $ asks jitEnvContext + env <- lift ask + liftIO $ withModuleFromAST ctx ast $ \mdl -> do + Text.hPutStrLn stderr $ ppll def + let spec = defaultCuratedPassSetSpec { optLevel = Just 3 } + -- this returns true if the module was modified + withPassManager spec $ flip runPassManager mdl + when anon (jit env mdl >>= hPrint stderr) + + when anon (removeDef def) repl where eofHandler e | isEOFError e = return Nothing | otherwise = ioError e + isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True + isAnonExpr _ = False + +jit :: JITEnv -> Module -> IO Double +jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl = + withModule compLayer mdlKey mdl $ do + mangled <- mangleSymbol compLayer "__anon_expr" + Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False + mkFun (castPtrToFunPtr (wordPtrToPtr fPtr)) type Binds = Map.Map String Operand diff --git a/Utils.hs b/Utils.hs index 2a83c35..cd325a4 100644 --- a/Utils.hs +++ b/Utils.hs @@ -4,11 +4,26 @@ Shoving away gross stuff into this one module. module Utils where import Control.Monad.Trans.State +import Data.ByteString.Short (ShortByteString) import Data.Functor.Identity +import Data.List import LLVM.AST import LLVM.IRBuilder.Module import LLVM.IRBuilder.Internal.SnocList +moduleSoFar :: MonadModuleBuilder m => ShortByteString -> m Module +moduleSoFar nm = do + s <- liftModuleState get + let ds = getSnocList (builderDefs s) + return $ defaultModule { moduleName = nm, moduleDefinitions = ds } + +removeDef :: MonadModuleBuilder m => Definition -> m () +removeDef def = liftModuleState (modify update) + where + update (ModuleBuilderState defs typeDefs) = + let newDefs = SnocList (delete def (getSnocList defs)) + in ModuleBuilderState newDefs typeDefs + mostRecentDef :: Monad m => ModuleBuilderT m Definition mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get