Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index bff1c08c965c037080cc6aba052a60c40f1dc999..48d93a224de17958d9f4586bc1aa079e43aee5c5 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -2,11 +2,14 @@
 
 import AST as K -- K for Kaleidoscope
 import Utils
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import Foreign.Ptr
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
@@ -16,6 +19,8 @@ import LLVM.AST.Type as Type
 import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.Module
 import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
 import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Target
@@ -23,17 +28,31 @@ import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
 import System.IO.Error
 import Text.Read (readMaybe)
 
+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
+
+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
 main = do
 main :: IO ()
 main = do
-  mdl' <- buildModuleT "main" repl
-  withContext $ \ctx -> withHostTargetMachine $ \tm ->
-    withModuleFromAST ctx mdl' $ \mdl -> do
-      let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
-      -- this returns true if the module was modified
-      withPassManager spec $ flip runPassManager mdl
-      Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
+  withContext $ \ctx -> withHostTargetMachine $ \tm -> do
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer -> do
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              ast <- runReaderT (buildModuleT "main" repl) env
+              return ()
+
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = undefined
 
 
-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
@@ -43,13 +62,34 @@ repl = do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+          
+          ast <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $ do
+    mangled <- mangleSymbol compLayer "__anon_expr"
+    Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+    mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
 
 type Binds = Map.Map String Operand
 
 
 type Binds = Map.Map String Operand