Add for loops
[kaleidoscope-hs-old.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index 4d8275828510ceed0bbcc99f3b2b95798042b2ae..7fa4499024c94db07b3264d0a08964af7b194c67 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -13,7 +13,8 @@ import Foreign.Ptr
 import System.Exit
 import System.IO
 import LLVM.Context
-import LLVM.ExecutionEngine
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
 import LLVM.Module
 import LLVM.PassManager
 import LLVM.IRBuilder
@@ -25,49 +26,73 @@ import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.Typed
 import LLVM.Pretty
+import LLVM.Linking
+import LLVM.Target
+
+import Control.Concurrent.MVar
 
 type ModuleBuilderE = ModuleBuilderT (Either String)
 
-foreign import ccall "dynamic" exprFun :: FunPtr (IO Double) -> IO Double
+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
 
 main :: IO ()
 main = do
   AST.Program asts <- read <$> getContents
   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
+
   case eitherMdl of
     Left err -> die err
-    Right mdl -> withContext $ \ctx -> do
+    Right mdl' -> withContext $ \ctx ->
+      withHostTargetMachine $ \tm -> do
         hPutStrLn stderr "Before optimisation:"
-      Text.hPutStrLn stderr (ppllvm mdl)
-      withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
-        withModuleFromAST ctx mdl $ \mdl' -> do
-          -- withPassManager defaultCuratedPassSetSpec $ \pm -> do
-          --   runPassManager pm mdl' >>= guard
+        Text.hPutStrLn stderr (ppllvm mdl')
+
+        withModuleFromAST ctx mdl' $ \mdl -> do
+          let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+          withPassManager spec $ flip runPassManager mdl
           hPutStrLn stderr "After optimisation:"
-          Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
-          withModuleInEngine mcjit mdl' $ \emdl -> do
-          Just f <- getFunction emdl "expr"
-          let f' = castFunPtr f :: FunPtr (IO Double)
-          exprFun f' >>= print
+          Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl
+          jit tm mdl >>= print
+
+jit :: TargetMachine -> Module -> IO Double
+jit tm mdl = do
+  loadLibraryPermanently (Just "stdlib.dylib") >>= guard . not
+  compLayerVar <- newEmptyMVar
+  
+  -- jit time
+  withExecutionSession $ \exSession ->
+    withSymbolResolver exSession (SymbolResolver (symResolver compLayerVar)) $ \symResolverPtr ->
+      withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+        withModuleKey exSession $ \mdlKey ->
+          withIRCompileLayer linkingLayer tm $ \compLayer -> do
+            putMVar compLayerVar compLayer
+
+            withModule compLayer mdlKey mdl $ do
+              mangled <- mangleSymbol compLayer "expr"
+              Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+              mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
+
+  where symResolver clv sym = do
+          cl <- readMVar clv
+          ms <- findSymbol cl sym False
+          case ms of
+            Right s -> return (return s)
+            _ -> do
+              addr <- getSymbolAddressInProcess sym
+              return $ return (JITSymbol addr (JITSymbolFlags False False True True))
 
 evalProg :: AST.Program -> IO (Maybe Double)
 evalProg (AST.Program asts) = do
   let eitherMdl = buildModuleT "main" $ mapM buildAST asts
   case eitherMdl of
     Left _ -> return Nothing
-    Right mdl -> withContext $ \ctx ->
-      withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
-        withModuleFromAST ctx mdl $ \mdl' ->
-          withModuleInEngine mcjit mdl' $ \emdl -> do
-            Just f <- getFunction emdl "expr"
-            let f' = castFunPtr f :: FunPtr (IO Double)
-            Just <$> exprFun f'
+    Right mdl' -> withContext $ \ctx ->
+      withHostTargetMachine $ \tm ->
+        withModuleFromAST ctx mdl' (fmap Just . jit tm)
 
 -- | Builds up programs at the top-level of an LLVM Module
 -- >>> evalProg (read "31 - 5")
 -- Just 26.0
--- >>> evalProg (read "extern pow(x e); pow(3,2)")
--- Just 9.0
 buildAST :: AST.AST -> ModuleBuilderE Operand
 buildAST (AST.Function nameStr paramStrs body) = do
   let n = fromString nameStr
@@ -128,3 +153,33 @@ buildExpr binds (AST.If cond thenE elseE) = mdo
 
   mergeB <- block `named` "ifcont"
   phi [(thenOp, thenB), (elseOp, elseB)]
+
+buildExpr binds (AST.For ident start cond mStep body) = mdo
+  startV <- buildExpr binds start
+
+  preheaderB <- block `named` "preheader"
+
+  br loopB
+
+  loopB <- block `named` "loop"
+
+  i <- phi [(startV, preheaderB), (nextVar, loopB)] `named` "i"
+
+  let newBinds = Map.insert ident i binds
+
+  buildExpr newBinds body `named` "body"
+
+  stepV <- case mStep of
+    Just step -> buildExpr newBinds step
+    Nothing -> pure $ ConstantOperand (Float (Double 1))
+
+  nextVar <- fadd i stepV `named` "nextvar"
+
+  condV <- buildExpr newBinds cond `named` "cond"
+
+  condBr condV loopB afterB
+
+  afterB <- block `named` "after"
+
+  return (ConstantOperand (Float (Double 0)))
+