Detect and remove __anon_exprs in repl
authorLuke Lau <luke_lau@icloud.com>
Sun, 2 Jun 2019 16:39:47 +0000 (17:39 +0100)
committerLuke Lau <luke_lau@icloud.com>
Thu, 7 Nov 2019 17:11:07 +0000 (17:11 +0000)
You might have noticed that when entering in top-level expressions you
end up getting duplicate __anon_expr.1, __anon_expr.2s etc. Since we
don't want these anonymous functions to stick around, in this commit
we're detecting them and then removing them once we're done.

We also only want to run the JIT whenever the user has entered a
top-level expression, so we've also sketched that out.

Main.hs
Utils.hs

diff --git a/Main.hs b/Main.hs
index c3f929214099d738794ea2b6a76abe7be0d3124f..bc7c3077f9b6bbdda78273381921826110acf62a 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -2,6 +2,7 @@
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
 import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
@@ -17,6 +18,8 @@ import LLVM.AST.Type as Type
 import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
@@ -24,13 +27,29 @@ import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
-main = do
-  withContext $ \ctx -> withHostTargetMachineDefault $ \tm -> do
-    ast <- runReaderT (buildModuleT "main" repl) ctx
+main =
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer ->
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              _ast <- runReaderT (buildModuleT "main" repl) env
               return ()
 
-repl :: ModuleBuilderT (ReaderT Context IO) ()
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = undefined
+
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
@@ -40,21 +59,32 @@ repl = do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
           
-          ast <- moduleSoFar "main"
-          ctx <- lift ask
-          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+          llvmAst <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
             -- this returns true if the module was modified
             withPassManager spec $ flip runPassManager mdl
-            Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $
+    return 0
 
 type Binds = Map.Map String Operand
 
index 3bd3b37e01e54274cf4841b52e18d8d4b2cea97e..cd325a40b3c3c296549e0ea11ad9947b1c9d53a4 100644 (file)
--- a/Utils.hs
+++ b/Utils.hs
@@ -6,6 +6,7 @@ module Utils where
 import Control.Monad.Trans.State
 import Data.ByteString.Short (ShortByteString)
 import Data.Functor.Identity
+import Data.List
 import LLVM.AST
 import LLVM.IRBuilder.Module
 import LLVM.IRBuilder.Internal.SnocList
@@ -16,6 +17,13 @@ moduleSoFar nm = do
   let ds = getSnocList (builderDefs s)
   return $ defaultModule { moduleName = nm, moduleDefinitions = ds }
 
+removeDef :: MonadModuleBuilder m => Definition -> m ()
+removeDef def = liftModuleState (modify update)
+  where
+    update (ModuleBuilderState defs typeDefs) =
+      let newDefs = SnocList (delete def (getSnocList defs))
+      in ModuleBuilderState newDefs typeDefs
+
 mostRecentDef :: Monad m => ModuleBuilderT m Definition
 mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get