Generate code for call expressions
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad.Trans.Reader
6 import Control.Monad.IO.Class
7 import Data.String
8 import qualified Data.Map as Map
9 import qualified Data.Text.Lazy.IO as Text
10 import LLVM.AST.AddrSpace
11 import LLVM.AST.Constant
12 import LLVM.AST.Float
13 import LLVM.AST.FloatingPointPredicate hiding (False, True)
14 import LLVM.AST.Operand
15 import LLVM.AST.Type as Type
16 import LLVM.IRBuilder
17 import LLVM.Pretty
18 import System.IO
19 import System.IO.Error
20 import Text.Read (readMaybe)
21
22 main :: IO ()
23 main = buildModuleT "main" repl >>= Text.hPutStrLn stderr . ("\n" <>) . ppll
24
25 repl :: ModuleBuilderT IO ()
26 repl = do
27   liftIO $ hPutStr stderr "ready> "
28   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
29   case mline of
30     Nothing -> return ()
31     Just l -> do
32       case readMaybe l of
33         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
34         Just ast -> do
35           hoist $ buildAST ast
36           mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
37       repl
38   where
39     eofHandler e
40       | isEOFError e = return Nothing
41       | otherwise = ioError e
42
43 type Binds = Map.Map String Operand
44
45 buildAST :: AST -> ModuleBuilder Operand
46 buildAST (Function (Prototype nameStr paramStrs) body) = do
47   let n = fromString nameStr
48   function n params Type.double $ \ops -> do
49     let binds = Map.fromList (zip paramStrs ops)
50     flip runReaderT binds $ buildExpr body >>= ret
51   where params = zip (repeat Type.double) (map fromString paramStrs)
52
53 buildAST (Extern (Prototype nameStr params)) =
54   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
55
56 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
57   const $ flip runReaderT mempty $ buildExpr x >>= ret
58
59 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
60 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
61 buildExpr (Var n) = do
62   binds <- ask
63   case binds Map.!? n of
64     Just x -> pure x
65     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
66
67 buildExpr (BinOp op a b) = do
68   opA <- buildExpr a
69   opB <- buildExpr b
70   tmp <- instr opA opB
71   if isCmp
72     then uitofp tmp Type.double
73     else return tmp
74   where isCmp
75           | Cmp _ <- op = True
76           | otherwise = False
77         instr = case op of
78                   K.Add -> fadd
79                   K.Sub -> fsub
80                   K.Mul -> fmul
81                   K.Cmp LT -> fcmp OLT
82                   K.Cmp GT -> fcmp OGT
83                   K.Cmp EQ -> fcmp OEQ
84
85 buildExpr (Call callee params) = do
86   paramOps <- mapM buildExpr params
87   let nam = fromString callee
88       -- get a pointer to the function
89       typ = FunctionType Type.double (replicate (length params) Type.double) False
90       ptrTyp = Type.PointerType typ (AddrSpace 0)
91       ref = GlobalReference ptrTyp nam
92   call (ConstantOperand ref) (zip paramOps (repeat []))