Set up the LLVM context and optimise the module
[kaleidoscope-hs.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index 2eae262967be844fff1886ac2d34318912b17e53..bff1c08c965c037080cc6aba052a60c40f1dc999 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -7,29 +7,49 @@ import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO
+import System.IO.Error
 import Text.Read (readMaybe)
 
 import Text.Read (readMaybe)
 
-main = buildModuleT "main" repl
+main :: IO ()
+main = do
+  mdl' <- buildModuleT "main" repl
+  withContext $ \ctx -> withHostTargetMachine $ \tm ->
+    withModuleFromAST ctx mdl' $ \mdl -> do
+      let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+      -- this returns true if the module was modified
+      withPassManager spec $ flip runPassManager mdl
+      Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
 
 repl :: ModuleBuilderT IO ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
 
 repl :: ModuleBuilderT IO ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
-  ast <- liftIO $ readMaybe <$> getLine
-  case ast of
+  mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
+  case mline of
+    Nothing -> return ()
+    Just l -> do
+      case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
-    Just x -> do
-      hoist $ buildAST x
+        Just ast -> do
+          hoist $ buildAST ast
           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
       repl
   where
           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
       repl
   where
+    eofHandler e
+      | isEOFError e = return Nothing
+      | otherwise = ioError e
 
 type Binds = Map.Map String Operand
 
 
 type Binds = Map.Map String Operand
 
@@ -41,6 +61,9 @@ buildAST (Function (Prototype nameStr paramStrs) body) = do
     flip runReaderT binds $ buildExpr body >>= ret
   where params = zip (repeat Type.double) (map fromString paramStrs)
 
     flip runReaderT binds $ buildExpr body >>= ret
   where params = zip (repeat Type.double) (map fromString paramStrs)
 
+buildAST (Extern (Prototype nameStr params)) =
+  extern (fromString nameStr) (replicate (length params) Type.double) Type.double
+
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
   const $ flip runReaderT mempty $ buildExpr x >>= ret
 
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
   const $ flip runReaderT mempty $ buildExpr x >>= ret
 
@@ -69,3 +92,12 @@ buildExpr (BinOp op a b) = do
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))