Add JIT tutorial-1
authorLuke Lau <luke_lau@icloud.com>
Mon, 3 Jun 2019 14:48:04 +0000 (15:48 +0100)
committerLuke Lau <luke_lau@icloud.com>
Mon, 3 Jun 2019 23:15:37 +0000 (00:15 +0100)
We have LLVM IR now, but our computers still can't run it. We could
compile our code "offline" and write it to a file, but LLVM also
provides frameworks for JITing: Just-in-time compilation. This is where
the code is compiled just before it is run, and we will be using it make
an interactive REPL.

Note that JITs are not the same as interpreters: An interpreter reads
the program and directly computes the result. A JIT reads the program
and generates more code for the computer to run, which then computes the
result.

The current LLVM Kaleidoscope tutorial uses the old MC JIT framework:
This tutorial will be using the fancy new OrcJIT framework. It's a bit
more complicated but provides a lot more flexibility.

Main.hs
Utils.hs

diff --git a/Main.hs b/Main.hs
index 2a5a7e0fbc6c490b0d90af2da2e54a843138341f..48d93a224de17958d9f4586bc1aa079e43aee5c5 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -2,27 +2,57 @@
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import Foreign.Ptr
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
+import LLVM.PassManager
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
+
+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
-main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+main = do
+  withContext $ \ctx -> withHostTargetMachine $ \tm -> do
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer -> do
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              ast <- runReaderT (buildModuleT "main" repl) env
+              return ()
+
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = undefined
 
-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
@@ -32,13 +62,34 @@ repl = do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+          
+          ast <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $ do
+    mangled <- mangleSymbol compLayer "__anon_expr"
+    Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+    mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
 
 type Binds = Map.Map String Operand
 
index 2a83c3591276ea116a969301fd4f24ae772af949..cd325a40b3c3c296549e0ea11ad9947b1c9d53a4 100644 (file)
--- a/Utils.hs
+++ b/Utils.hs
@@ -4,11 +4,26 @@ Shoving away gross stuff into this one module.
 module Utils where
 
 import Control.Monad.Trans.State
+import Data.ByteString.Short (ShortByteString)
 import Data.Functor.Identity
+import Data.List
 import LLVM.AST
 import LLVM.IRBuilder.Module
 import LLVM.IRBuilder.Internal.SnocList
 
+moduleSoFar :: MonadModuleBuilder m => ShortByteString -> m Module
+moduleSoFar nm = do
+  s <- liftModuleState get
+  let ds = getSnocList (builderDefs s)
+  return $ defaultModule { moduleName = nm, moduleDefinitions = ds }
+
+removeDef :: MonadModuleBuilder m => Definition -> m ()
+removeDef def = liftModuleState (modify update)
+  where
+    update (ModuleBuilderState defs typeDefs) =
+      let newDefs = SnocList (delete def (getSnocList defs))
+      in ModuleBuilderState newDefs typeDefs
+
 mostRecentDef :: Monad m => ModuleBuilderT m Definition
 mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get