Add standard library
[kaleidoscope-hs.git] / Main.hs
diff --git a/Main.hs b/Main.hs
index 2a5a7e0fbc6c490b0d90af2da2e54a843138341f..64d73e60e0a059864222aec09f8b2ad0803adcfd 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -1,28 +1,65 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecursiveDo #-}
 
 import AST as K -- K for Kaleidoscope
 import Utils
 
 import AST as K -- K for Kaleidoscope
 import Utils
+import Control.Monad
+import Control.Monad.Trans.Class
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
 import Control.Monad.Trans.Reader
 import Control.Monad.IO.Class
 import Data.String
 import qualified Data.Map as Map
 import qualified Data.Text.Lazy.IO as Text
+import Foreign.Ptr
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
 import LLVM.AST.AddrSpace
 import LLVM.AST.Constant
 import LLVM.AST.Float
 import LLVM.AST.FloatingPointPredicate hiding (False, True)
 import LLVM.AST.Operand
 import LLVM.AST.Type as Type
+import LLVM.Context
 import LLVM.IRBuilder
 import LLVM.IRBuilder
+import LLVM.Linking
+import LLVM.Module
+import LLVM.OrcJIT
+import LLVM.OrcJIT.CompileLayer
+import LLVM.PassManager
 import LLVM.Pretty
 import LLVM.Pretty
+import LLVM.Target
+import Numeric
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
 import System.IO
 import System.IO.Error
 import Text.Read (readMaybe)
 
+foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
+
+data JITEnv = JITEnv
+  { jitEnvContext :: Context
+  , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
+  , jitEnvModuleKey :: ModuleKey
+  }
+
 main :: IO ()
 main :: IO ()
-main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+main = do
+  loadLibraryPermanently (Just "stdlib.dylib")
+  withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
+    withExecutionSession $ \exSession ->
+      withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
+        withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
+          withIRCompileLayer linkingLayer tm $ \compLayer ->
+            withModuleKey exSession $ \mdlKey -> do
+              let env = JITEnv ctx compLayer mdlKey
+              _ast <- runReaderT (buildModuleT "main" repl) env
+              return ()
+
+-- This can eventually be used to resolve external functions, e.g. a stdlib call
+symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
+symResolver sym = do
+  ptr <- getSymbolAddressInProcess sym
+  putStrLn $ "Resolving " <> show sym <> " to 0x" <> showHex ptr ""
+  return (Right (JITSymbol ptr defaultJITSymbolFlags))
 
 
-repl :: ModuleBuilderT IO ()
+repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
 repl = do
   liftIO $ hPutStr stderr "ready> "
   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
@@ -32,13 +69,34 @@ repl = do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
       case readMaybe l of
         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
         Just ast -> do
-          hoist $ buildAST ast
-          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+          anon <- isAnonExpr <$> hoist (buildAST ast)
+          def <- mostRecentDef
+          
+          llvmAst <- moduleSoFar "main"
+          ctx <- lift $ asks jitEnvContext
+          env <- lift ask
+          liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
+            Text.hPutStrLn stderr $ ppll def
+            let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
+            -- this returns true if the module was modified
+            withPassManager spec $ flip runPassManager mdl
+            when anon (jit env mdl >>= hPrint stderr)
+
+          when anon (removeDef def)
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
       repl
   where
     eofHandler e
       | isEOFError e = return Nothing
       | otherwise = ioError e
+    isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
+    isAnonExpr _ = False
+
+jit :: JITEnv -> Module -> IO Double
+jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
+  withModule compLayer mdlKey mdl $ do
+    mangled <- mangleSymbol compLayer "__anon_expr"
+    Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
+    mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
 
 type Binds = Map.Map String Operand
 
 
 type Binds = Map.Map String Operand
 
@@ -90,3 +148,60 @@ buildExpr (Call callee params) = do
       ptrTyp = Type.PointerType typ (AddrSpace 0)
       ref = GlobalReference ptrTyp nam
   call (ConstantOperand ref) (zip paramOps (repeat []))
       ptrTyp = Type.PointerType typ (AddrSpace 0)
       ref = GlobalReference ptrTyp nam
   call (ConstantOperand ref) (zip paramOps (repeat []))
+
+buildExpr (If cond thenE elseE) = mdo
+  _ifB <- block `named` "if"
+
+  -- since everything is a double, false == 0
+  let zero = ConstantOperand (Float (Double 0))
+  condV <- buildExpr cond
+  cmp <- fcmp ONE zero condV `named` "cmp"
+
+  condBr cmp thenB elseB
+
+  thenB <- block `named` "then"
+  thenOp <- buildExpr thenE
+  br mergeB
+
+  elseB <- block `named` "else"
+  elseOp <- buildExpr elseE
+  br mergeB
+
+  mergeB <- block `named` "ifcont"
+  phi [(thenOp, thenB), (elseOp, elseB)]
+
+buildExpr (For name init cond mStep body) = mdo
+  preheaderB <- block `named` "preheader"
+
+  initV <- buildExpr init `named` "init"
+  
+  -- build the condition expression with 'i' in the bindings
+  initCondV <- withReaderT (Map.insert name initV) $
+                (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
+
+  -- skip the loop if we don't meet the condition with the init
+  condBr initCondV loopB afterB
+
+  loopB <- block `named` "loop"
+  i <- phi [(initV, preheaderB), (nextVar, loopB)] `named` "i"
+
+  -- build the body expression with 'i' in the bindings
+  withReaderT (Map.insert name i) $ buildExpr body `named` "body"
+
+  -- default to 1 if there's no step defined
+  stepV <- case mStep of
+    Just step -> buildExpr step
+    Nothing -> return $ ConstantOperand (Float (Double 1))
+
+  nextVar <- fadd i stepV `named` "nextvar"
+
+  let zero = ConstantOperand (Float (Double 0))
+  -- again we need 'i' in the bindings
+  condV <- withReaderT (Map.insert name i) $
+            (buildExpr cond >>= fcmp ONE zero) `named` "cond"
+  condBr condV loopB afterB
+
+  afterB <- block `named` "after"
+  -- since a for loop doesn't really have a value, return 0
+  return $ ConstantOperand (Float (Double 0))
+