Add putchard and move to OrcJIT
[kaleidoscope-hs-old.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index 4d8275828510ceed0bbcc99f3b2b95798042b2ae..f096ea6fa5177b8deb5671c570cb97ba79f37192 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -13,7 +13,8 @@ import Foreign.Ptr
 import System.Exit
 import System.IO
 import LLVM.Context
-import LLVM.ExecutionEngine
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
 import LLVM.Module
 import LLVM.PassManager
 import LLVM.IRBuilder
@@ -25,49 +26,73 @@ import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.Typed
 import LLVM.Pretty
+import LLVM.Linking
+import LLVM.Target
+
+import Control.Concurrent.MVar
 
 type ModuleBuilderE = ModuleBuilderT (Either String)
 
-foreign import ccall "dynamic" exprFun :: FunPtr (IO Double) -> IO Double
+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
 
 main :: IO ()
 main = do
   AST.Program asts <- read <$> getContents
   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
+
   case eitherMdl of
     Left err -> die err
-    Right mdl -> withContext $ \ctx -> do
-      hPutStrLn stderr "Before optimisation:"
-      Text.hPutStrLn stderr (ppllvm mdl)
-      withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
-        withModuleFromAST ctx mdl $ \mdl' -> do
-          -- withPassManager defaultCuratedPassSetSpec $ \pm -> do
-          --   runPassManager pm mdl' >>= guard
-          hPutStrLn stderr "After optimisation:"
-          Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
-          withModuleInEngine mcjit mdl' $ \emdl -> do
-          Just f <- getFunction emdl "expr"
-          let f' = castFunPtr f :: FunPtr (IO Double)
-          exprFun f' >>= print
+    Right mdl' -> withContext $ \ctx ->
+      withHostTargetMachine $ \tm -> do
+        -- hPutStrLn stderr "Before optimisation:"
+        -- Text.hPutStrLn stderr (ppllvm mdl')
+
+        withModuleFromAST ctx mdl' $ \mdl -> do
+          let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+          withPassManager spec $ flip runPassManager mdl
+          -- hPutStrLn stderr "After optimisation:"
+          -- Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl
+          jit tm mdl >>= print
+
+jit :: TargetMachine -> Module -> IO Double
+jit tm mdl = do
+  loadLibraryPermanently (Just "stdlib.dylib") >>= guard . not
+  compLayerVar <- newEmptyMVar
+  
+  -- jit time
+  withExecutionSession $ \exSession ->
+    withSymbolResolver exSession (SymbolResolver (symResolver compLayerVar)) $ \symResolverPtr ->
+      withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+        withModuleKey exSession $ \mdlKey ->
+          withIRCompileLayer linkingLayer tm $ \compLayer -> do
+            putMVar compLayerVar compLayer
+
+            withModule compLayer mdlKey mdl $ do
+              mangled <- mangleSymbol compLayer "expr"
+              Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+              mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
+
+  where symResolver clv sym = do
+          cl <- readMVar clv
+          ms <- findSymbol cl sym False
+          case ms of
+            Right s -> return (return s)
+            _ -> do
+              addr <- getSymbolAddressInProcess sym
+              return $ return (JITSymbol addr (JITSymbolFlags False False True True))
 
 evalProg :: AST.Program -> IO (Maybe Double)
 evalProg (AST.Program asts) = do
   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
   case eitherMdl of
     Left _ -> return Nothing
-    Right mdl -> withContext $ \ctx ->
-      withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
-        withModuleFromAST ctx mdl $ \mdl' ->
-          withModuleInEngine mcjit mdl' $ \emdl -> do
-            Just f <- getFunction emdl "expr"
-            let f' = castFunPtr f :: FunPtr (IO Double)
-            Just <$> exprFun f'
+    Right mdl' -> withContext $ \ctx ->
+      withHostTargetMachine $ \tm ->
+        withModuleFromAST ctx mdl' (fmap Just . jit tm)
 
 -- | Builds up programs at the top-level of an LLVM Module
 -- >>> evalProg (read "31 - 5")
 -- Just 26.0
--- >>> evalProg (read "extern pow(x e); pow(3,2)")
--- Just 9.0
 buildAST :: AST.AST -> ModuleBuilderE Operand
 buildAST (AST.Function nameStr paramStrs body) = do
   let n = fromString nameStr