Generate code for externs
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad.Trans.Reader
6 import Control.Monad.IO.Class
7 import Data.String
8 import qualified Data.Map as Map
9 import qualified Data.Text.Lazy.IO as Text
10 import LLVM.AST.Constant
11 import LLVM.AST.Float
12 import LLVM.AST.FloatingPointPredicate hiding (False, True)
13 import LLVM.AST.Operand
14 import LLVM.AST.Type as Type
15 import LLVM.IRBuilder
16 import LLVM.Pretty
17 import System.IO
18 import Text.Read (readMaybe)
19
20 main = buildModuleT "main" repl
21
22 repl :: ModuleBuilderT IO ()
23 repl = do
24   liftIO $ hPutStr stderr "ready> "
25   ast <- liftIO $ readMaybe <$> getLine
26   case ast of
27     Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
28     Just x -> do
29       hoist $ buildAST x
30       mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
31   repl
32   where 
33
34 type Binds = Map.Map String Operand
35
36 buildAST :: AST -> ModuleBuilder Operand
37 buildAST (Function (Prototype nameStr paramStrs) body) = do
38   let n = fromString nameStr
39   function n params Type.double $ \ops -> do
40     let binds = Map.fromList (zip paramStrs ops)
41     flip runReaderT binds $ buildExpr body >>= ret
42   where params = zip (repeat Type.double) (map fromString paramStrs)
43
44 buildAST (Extern (Prototype nameStr params)) =
45   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
46
47 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
48   const $ flip runReaderT mempty $ buildExpr x >>= ret
49
50 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
51 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
52 buildExpr (Var n) = do
53   binds <- ask
54   case binds Map.!? n of
55     Just x -> pure x
56     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
57
58 buildExpr (BinOp op a b) = do
59   opA <- buildExpr a
60   opB <- buildExpr b
61   tmp <- instr opA opB
62   if isCmp
63     then uitofp tmp Type.double
64     else return tmp
65   where isCmp
66           | Cmp _ <- op = True
67           | otherwise = False
68         instr = case op of
69                   K.Add -> fadd
70                   K.Sub -> fsub
71                   K.Mul -> fmul
72                   K.Cmp LT -> fcmp OLT
73                   K.Cmp GT -> fcmp OGT
74                   K.Cmp EQ -> fcmp OEQ