Set up the LLVM context and optimise the module
[kaleidoscope-hs.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index f32003bdd2dbb48eba325526bf1849444fd1d09d..c3f929214099d738794ea2b6a76abe7be0d3124f 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -2,26 +2,35 @@
 
 import AST as K -- K for Kaleidoscope
 import Utils
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
 main :: IO ()
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
 main :: IO ()
-main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+main = do
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm -> do
+    ast <- runReaderT (buildModuleT "main" repl) ctx
+    return ()
 
 
-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT Context IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
@@ -33,6 +42,14 @@ repl = do
         Just ast -> do
           hoist $ buildAST ast
           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
         Just ast -> do
           hoist $ buildAST ast
           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+
+          ast <- moduleSoFar "main"
+          ctx <- lift ask
+          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
       repl
   where
     eofHandler e
       repl
   where
     eofHandler e
@@ -80,3 +97,12 @@ buildExpr (BinOp op a b) = do
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))