Generate code for functions and variables
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad.Trans.Reader
6 import Control.Monad.IO.Class
7 import Data.String
8 import qualified Data.Map as Map
9 import qualified Data.Text.Lazy.IO as Text
10 import LLVM.AST.Constant
11 import LLVM.AST.Float
12 import LLVM.AST.FloatingPointPredicate hiding (False, True)
13 import LLVM.AST.Operand
14 import LLVM.AST.Type as Type
15 import LLVM.IRBuilder
16 import LLVM.Pretty
17 import System.IO
18 import Text.Read (readMaybe)
19
20 main = buildModuleT "main" repl
21
22 repl :: ModuleBuilderT IO ()
23 repl = do
24   liftIO $ hPutStr stderr "ready> "
25   ast <- liftIO $ readMaybe <$> getLine
26   case ast of
27     Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
28     Just x -> do
29       hoist $ buildAST x
30       mostRecentDef >>= liftIO . Text.hPutStrLn stderr . ppll
31   repl
32   where 
33
34 type Binds = Map.Map String Operand
35
36 buildAST :: AST -> ModuleBuilder Operand
37 buildAST (Function (Prototype nameStr paramStrs) body) = do
38   let n = fromString nameStr
39   function n params Type.double $ \ops -> do
40     let binds = Map.fromList (zip paramStrs ops)
41     flip runReaderT binds $ buildExpr body >>= ret
42   where params = zip (repeat Type.double) (map fromString paramStrs)
43
44 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
45   const $ flip runReaderT mempty $ buildExpr x >>= ret
46
47 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
48 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
49 buildExpr (Var n) = do
50   binds <- ask
51   case binds Map.!? n of
52     Just x -> pure x
53     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
54
55 buildExpr (BinOp op a b) = do
56   opA <- buildExpr a
57   opB <- buildExpr b
58   tmp <- instr opA opB
59   if isCmp
60     then uitofp tmp Type.double
61     else return tmp
62   where isCmp
63           | Cmp _ <- op = True
64           | otherwise = False
65         instr = case op of
66                   K.Add -> fadd
67                   K.Sub -> fsub
68                   K.Mul -> fmul
69                   K.Cmp LT -> fcmp OLT
70                   K.Cmp GT -> fcmp OGT
71                   K.Cmp EQ -> fcmp OEQ