Detect and remove __anon_exprs in repl
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2
3 import AST as K -- K for Kaleidoscope
4 import Utils
5 import Control.Monad
6 import Control.Monad.Trans.Class
7 import Control.Monad.Trans.Reader
8 import Control.Monad.IO.Class
9 import Data.String
10 import qualified Data.Map as Map
11 import qualified Data.Text.Lazy.IO as Text
12 import LLVM.AST.AddrSpace
13 import LLVM.AST.Constant
14 import LLVM.AST.Float
15 import LLVM.AST.FloatingPointPredicate hiding (False, True)
16 import LLVM.AST.Operand
17 import LLVM.AST.Type as Type
18 import LLVM.Context
19 import LLVM.IRBuilder
20 import LLVM.Module
21 import LLVM.PassManager
22 import LLVM.Pretty
23 import LLVM.Target
24 import System.IO
25 import System.IO.Error
26 import Text.Read (readMaybe)
27
28 main :: IO ()
29 main = do
30   withContext $ \ctx -> withHostTargetMachine $ \tm -> do
31     ast <- runReaderT (buildModuleT "main" repl) ctx
32     return ()
33
34 repl :: ModuleBuilderT (ReaderT Context IO) ()
35 repl = do
36   liftIO $ hPutStr stderr "ready> "
37   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
38   case mline of
39     Nothing -> return ()
40     Just l -> do
41       case readMaybe l of
42         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
43         Just ast -> do
44           anon <- isAnonExpr <$> hoist (buildAST ast)
45           def <- mostRecentDef
46           
47           ast <- moduleSoFar "main"
48           ctx <- lift ask
49           liftIO $ withModuleFromAST ctx ast $ \mdl -> do
50             Text.hPutStrLn stderr $ ppll def
51             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
52             -- this returns true if the module was modified
53             withPassManager spec $ flip runPassManager mdl
54             Text.hPutStrLn stderr . ("\n" <>) . ppllvm =<< moduleAST mdl
55             when anon (jit mdl >>= hPrint stderr)
56
57           when anon (removeDef def)
58       repl
59   where
60     eofHandler e
61       | isEOFError e = return Nothing
62       | otherwise = ioError e
63     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
64     isAnonExpr _ = False
65
66 jit :: Module -> IO Double
67 jit _mdl = putStrLn "Working on it!" >> return 0
68
69 type Binds = Map.Map String Operand
70
71 buildAST :: AST -> ModuleBuilder Operand
72 buildAST (Function (Prototype nameStr paramStrs) body) = do
73   let n = fromString nameStr
74   function n params Type.double $ \ops -> do
75     let binds = Map.fromList (zip paramStrs ops)
76     flip runReaderT binds $ buildExpr body >>= ret
77   where params = zip (repeat Type.double) (map fromString paramStrs)
78
79 buildAST (Extern (Prototype nameStr params)) =
80   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
81
82 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
83   const $ flip runReaderT mempty $ buildExpr x >>= ret
84
85 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
86 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
87 buildExpr (Var n) = do
88   binds <- ask
89   case binds Map.!? n of
90     Just x -> pure x
91     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
92
93 buildExpr (BinOp op a b) = do
94   opA <- buildExpr a
95   opB <- buildExpr b
96   tmp <- instr opA opB
97   if isCmp
98     then uitofp tmp Type.double
99     else return tmp
100   where isCmp
101           | Cmp _ <- op = True
102           | otherwise = False
103         instr = case op of
104                   K.Add -> fadd
105                   K.Sub -> fsub
106                   K.Mul -> fmul
107                   K.Cmp LT -> fcmp OLT
108                   K.Cmp GT -> fcmp OGT
109                   K.Cmp EQ -> fcmp OEQ
110
111 buildExpr (Call callee params) = do
112   paramOps <- mapM buildExpr params
113   let nam = fromString callee
114       -- get a pointer to the function
115       typ = FunctionType Type.double (replicate (length params) Type.double) False
116       ptrTyp = Type.PointerType typ (AddrSpace 0)
117       ref = GlobalReference ptrTyp nam
118   call (ConstantOperand ref) (zip paramOps (repeat []))