Add LLVM IR codegen
authorLuke Lau <luke_lau@icloud.com>
Mon, 3 Jun 2019 14:47:05 +0000 (15:47 +0100)
committerLuke Lau <luke_lau@icloud.com>
Mon, 3 Jun 2019 23:11:03 +0000 (00:11 +0100)
Now that we have our AST built up, its time to start thinking about
semantics. And to think about semantics, we need to start building up
code that does what our AST says.

In most compilers, we don't directly convert the AST right down to
machine code: Usually there's an intermediate representation involved
that's somewhere between our programming language and machine code. LLVM
has an intermediate representation called LLVM IR, and that's what we'll
be converting our AST to.

llvm-hs provides a monadic way of building up modules and functions,
with ModuleBuilder and IRBuilder respectively. To generate our code we
will traverse our AST inside these monads, spitting out LLVM IR as we go
along.

AST.hs
Main.hs
Utils.hs [new file with mode: 0644]

diff --git a/AST.hs b/AST.hs
index b57d7cb8cede024a9c2bb0e366d8d066474d331d..9ff555a74d10c6d970fc8b7df0a37c1d98fdbffc 100644 (file)
--- a/AST.hs
+++ b/AST.hs
@@ -4,7 +4,7 @@ import Data.Char
 import Text.Read 
 import Text.ParserCombinators.ReadP hiding ((+++), choice)
 
-data Expr = Num Float
+data Expr = Num Double
           | Var String
           | BinOp BinOp Expr Expr
           | Call String [Expr]
@@ -18,6 +18,8 @@ instance Read Expr where
                              , parseVar
                              , parseCall
                              , parseBinOp "<" 10 (Cmp LT)
+                             , parseBinOp ">" 10 (Cmp GT)
+                             , parseBinOp "==" 10 (Cmp EQ)
                              , parseBinOp "+" 20 Add
                              , parseBinOp "-" 20 Sub
                              , parseBinOp "*" 40 Mul
diff --git a/Main.hs b/Main.hs
index ec0de8c8bcfd731b298ed5a09c5df8784e9f0fd3..2a5a7e0fbc6c490b0d90af2da2e54a843138341f 100644 (file)
--- a/Main.hs
+++ b/Main.hs
@@ -1,10 +1,92 @@
-import AST
+{-# LANGUAGE OverloadedStrings #-}
+
+import AST as K -- K for Kaleidoscope
+import Utils
+import Control.Monad.Trans.Reader
+import Control.Monad.IO.Class
+import Data.String
+import qualified Data.Map as Map
+import qualified Data.Text.Lazy.IO as Text
+import LLVM.AST.AddrSpace
+import LLVM.AST.Constant
+import LLVM.AST.Float
+import LLVM.AST.FloatingPointPredicate hiding (False, True)
+import LLVM.AST.Operand
+import LLVM.AST.Type as Type
+import LLVM.IRBuilder
+import LLVM.Pretty
 import System.IO
-import Text.Read
-main = do
-  hPutStr stderr "ready> "
-  ast <- (readMaybe <$> getLine) :: IO (Maybe AST)
-  case ast of
-    Just x -> hPrint stderr x
-    Nothing ->  hPutStrLn stderr "Couldn't parse"
-  main
+import System.IO.Error
+import Text.Read (readMaybe)
+
+main :: IO ()
+main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
+
+repl :: ModuleBuilderT IO ()
+repl = do
+  liftIO $ hPutStr stderr "ready> "
+  mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
+  case mline of
+    Nothing -> return ()
+    Just l -> do
+      case readMaybe l of
+        Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
+        Just ast -> do
+          hoist $ buildAST ast
+          mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
+      repl
+  where
+    eofHandler e
+      | isEOFError e = return Nothing
+      | otherwise = ioError e
+
+type Binds = Map.Map String Operand
+
+buildAST :: AST -> ModuleBuilder Operand
+buildAST (Function (Prototype nameStr paramStrs) body) = do
+  let n = fromString nameStr
+  function n params Type.double $ \ops -> do
+    let binds = Map.fromList (zip paramStrs ops)
+    flip runReaderT binds $ buildExpr body >>= ret
+  where params = zip (repeat Type.double) (map fromString paramStrs)
+
+buildAST (Extern (Prototype nameStr params)) =
+  extern (fromString nameStr) (replicate (length params) Type.double) Type.double
+
+buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
+  const $ flip runReaderT mempty $ buildExpr x >>= ret
+
+buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
+buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
+buildExpr (Var n) = do
+  binds <- ask
+  case binds Map.!? n of
+    Just x -> pure x
+    Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
+
+buildExpr (BinOp op a b) = do
+  opA <- buildExpr a
+  opB <- buildExpr b
+  tmp <- instr opA opB
+  if isCmp
+    then uitofp tmp Type.double
+    else return tmp
+  where isCmp
+          | Cmp _ <- op = True
+          | otherwise = False
+        instr = case op of
+                  K.Add -> fadd
+                  K.Sub -> fsub
+                  K.Mul -> fmul
+                  K.Cmp LT -> fcmp OLT
+                  K.Cmp GT -> fcmp OGT
+                  K.Cmp EQ -> fcmp OEQ
+
+buildExpr (Call callee params) = do
+  paramOps <- mapM buildExpr params
+  let nam = fromString callee
+      -- get a pointer to the function
+      typ = FunctionType Type.double (replicate (length params) Type.double) False
+      ptrTyp = Type.PointerType typ (AddrSpace 0)
+      ref = GlobalReference ptrTyp nam
+  call (ConstantOperand ref) (zip paramOps (repeat []))
diff --git a/Utils.hs b/Utils.hs
new file mode 100644 (file)
index 0000000..2a83c35
--- /dev/null
+++ b/Utils.hs
@@ -0,0 +1,17 @@
+{-|
+Shoving away gross stuff into this one module.
+-}
+module Utils where
+
+import Control.Monad.Trans.State
+import Data.Functor.Identity
+import LLVM.AST
+import LLVM.IRBuilder.Module
+import LLVM.IRBuilder.Internal.SnocList
+
+mostRecentDef :: Monad m => ModuleBuilderT m Definition
+mostRecentDef = last . getSnocList . builderDefs <$> liftModuleState get
+
+hoist :: Monad m => ModuleBuilder a -> ModuleBuilderT m a
+hoist m = ModuleBuilderT $ StateT $
+  return . runIdentity . runStateT (unModuleBuilderT m)