Detect and remove __anon_exprs in repl
[kaleidoscope-hs.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index ca6629ab447cd64b7c8748eec07ffa7fe6b34e9d..4dafa02b5f1acc3e2828c35a570993fba1abcfd0 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -2,38 +2,94 @@
 
 import AST as K -- K for Kaleidoscope
 import Utils
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Control.Monad.IO.Class
+import Data.String
+import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.IRBuilder
+import LLVM.Module
+import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Pretty
+import LLVM.Target
 import System.IO
 import System.IO
+import System.IO.Error
 import Text.Read (readMaybe)
 
 import Text.Read (readMaybe)
 
-main = buildModuleT "main" repl
+main :: IO ()
+main = do
+  withContext $ \ctx -> withHostTargetMachine $ \tm -> do
+    ast <- runReaderT (buildModuleT "main" repl) ctx
+    return ()
 
 
-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT Context IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
 repl = do
   liftIO $ hPutStr stderr "ready> "
-  ast <- liftIO $ readMaybe <$> getLine
-  case ast of
+  mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
+  case mline of
+    Nothing -> return ()
+    Just l -> do
+      case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
-    Just x -> do
-      hoist $ buildAST x
-      mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+        Just ast -> do
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+          
+          ast <- moduleSoFar "main"
+          ctx <- lift ask
+          liftIO $ withModuleFromAST ctx ast $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
+            when anon (jit mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
       repl
   where
+    eofHandler e
+      | isEOFError e = return Nothing
+      | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: Module -> IO Double
+jit _mdl = putStrLn "Working on it!" >> return 0
+
+type Binds = Map.Map String Operand
 
 buildAST :: AST -> ModuleBuilder Operand
 
 buildAST :: AST -> ModuleBuilder Operand
+buildAST (Function (Prototype nameStr paramStrs) body) = do
+  let n = fromString nameStr
+  function n params Type.double $ \ops -> do
+    let binds = Map.fromList (zip paramStrs ops)
+    flip runReaderT binds $ buildExpr body >>= ret
+  where params = zip (repeat Type.double) (map fromString paramStrs)
+
+buildAST (Extern (Prototype nameStr params)) =
+  extern (fromString nameStr) (replicate (length params) Type.double) Type.double
+
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
-  const $ buildExpr x >>= ret
+  const $ flip runReaderT mempty $ buildExpr x >>= ret
 
 
-buildExpr :: Expr -> IRBuilderT ModuleBuilder Operand
+buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
+buildExpr (Var n) = do
+  binds <- ask
+  case binds Map.!? n of
+    Just x -> pure x
+    Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
+
 buildExpr (BinOp op a b) = do
   opA <- buildExpr a
   opB <- buildExpr b
 buildExpr (BinOp op a b) = do
   opA <- buildExpr a
   opB <- buildExpr b
@@ -51,3 +107,12 @@ buildExpr (BinOp op a b) = do
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
                   K.Cmp LT -> fcmp OLT
                   K.Cmp GT -> fcmp OGT
                   K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))