Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 import AST as K -- K for Kaleidoscope
5 import Utils
6 import Control.Monad
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
10 import Data.String
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
13 import Foreign.Ptr
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
16 import LLVM.AST.Float
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
20 import LLVM.Context
21 import LLVM.IRBuilder
22 import LLVM.Linking
23 import LLVM.Module
24 import LLVM.OrcJIT
25 import LLVM.OrcJIT.CompileLayer
26 import LLVM.PassManager
27 import LLVM.Pretty
28 import LLVM.Target
29 import System.IO
30 import System.IO.Error
31 import Text.Read (readMaybe)
32
33 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
34
35 data JITEnv = JITEnv
36   { jitEnvContext :: Context
37   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
38   , jitEnvModuleKey :: ModuleKey
39   }
40
41 main :: IO ()
42 main = do
43   loadLibraryPermanently (Just "stdlib.dylib")
44   withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
45     withExecutionSession $ \exSession ->
46       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
47         withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
48           withIRCompileLayer linkingLayer tm $ \compLayer ->
49             withModuleKey exSession $ \mdlKey -> do
50               let env = JITEnv ctx compLayer mdlKey
51               _ast <- runReaderT (buildModuleT "main" repl) env
52               return ()
53
54 -- This can eventually be used to resolve external functions, e.g. a stdlib call
55 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
56 symResolver sym = undefined
57
58 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
59 repl = do
60   liftIO $ hPutStr stderr "ready> "
61   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
62   case mline of
63     Nothing -> return ()
64     Just l -> do
65       case readMaybe l of
66         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
67         Just ast -> do
68           anon <- isAnonExpr <$> hoist (buildAST ast)
69           def <- mostRecentDef
70           
71           llvmAst <- moduleSoFar "main"
72           ctx <- lift $ asks jitEnvContext
73           env <- lift ask
74           liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
75             Text.hPutStrLn stderr $ ppll def
76             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
77             -- this returns true if the module was modified
78             withPassManager spec $ flip runPassManager mdl
79             when anon (jit env mdl >>= hPrint stderr)
80
81           when anon (removeDef def)
82       repl
83   where
84     eofHandler e
85       | isEOFError e = return Nothing
86       | otherwise = ioError e
87     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
88     isAnonExpr _ = False
89
90 jit :: JITEnv -> Module -> IO Double
91 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
92   withModule compLayer mdlKey mdl $ do
93     mangled <- mangleSymbol compLayer "__anon_expr"
94     Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
95     mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
96
97 type Binds = Map.Map String Operand
98
99 buildAST :: AST -> ModuleBuilder Operand
100 buildAST (Function (Prototype nameStr paramStrs) body) = do
101   let n = fromString nameStr
102   function n params Type.double $ \ops -> do
103     let binds = Map.fromList (zip paramStrs ops)
104     flip runReaderT binds $ buildExpr body >>= ret
105   where params = zip (repeat Type.double) (map fromString paramStrs)
106
107 buildAST (Extern (Prototype nameStr params)) =
108   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
109
110 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
111   const $ flip runReaderT mempty $ buildExpr x >>= ret
112
113 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
114 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
115 buildExpr (Var n) = do
116   binds <- ask
117   case binds Map.!? n of
118     Just x -> pure x
119     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
120
121 buildExpr (BinOp op a b) = do
122   opA <- buildExpr a
123   opB <- buildExpr b
124   tmp <- instr opA opB
125   if isCmp
126     then uitofp tmp Type.double
127     else return tmp
128   where isCmp
129           | Cmp _ <- op = True
130           | otherwise = False
131         instr = case op of
132                   K.Add -> fadd
133                   K.Sub -> fsub
134                   K.Mul -> fmul
135                   K.Cmp LT -> fcmp OLT
136                   K.Cmp GT -> fcmp OGT
137                   K.Cmp EQ -> fcmp OEQ
138
139 buildExpr (Call callee params) = do
140   paramOps <- mapM buildExpr params
141   let nam = fromString callee
142       -- get a pointer to the function
143       typ = FunctionType Type.double (replicate (length params) Type.double) False
144       ptrTyp = Type.PointerType typ (AddrSpace 0)
145       ref = GlobalReference ptrTyp nam
146   call (ConstantOperand ref) (zip paramOps (repeat []))
147
148 buildExpr (If cond thenE elseE) = mdo
149   _ifB <- block `named` "if"
150
151   -- since everything is a double, false == 0
152   let zero = ConstantOperand (Float (Double 0))
153   condV <- buildExpr cond
154   cmp <- fcmp ONE zero condV `named` "cmp"
155
156   condBr cmp thenB elseB
157
158   thenB <- block `named` "then"
159   thenOp <- buildExpr thenE
160   br mergeB
161
162   elseB <- block `named` "else"
163   elseOp <- buildExpr elseE
164   br mergeB
165
166   mergeB <- block `named` "ifcont"
167   phi [(thenOp, thenB), (elseOp, elseB)]
168
169 buildExpr (For name init cond mStep body) = mdo
170   preheaderB <- block `named` "preheader"
171
172   initV <- buildExpr init `named` "init"
173   
174   -- build the condition expression with 'i' in the bindings
175   initCondV <- withReaderT (Map.insert name initV) $
176                 (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
177
178   -- skip the loop if we don't meet the condition with the init
179   condBr initCondV loopB afterB
180
181   loopB <- block `named` "loop"
182   i <- phi [(initV, preheaderB), (nextVar, loopB)] `named` "i"
183
184   -- build the body expression with 'i' in the bindings
185   withReaderT (Map.insert name i) $ buildExpr body `named` "body"
186
187   -- default to 1 if there's no step defined
188   stepV <- case mStep of
189     Just step -> buildExpr step
190     Nothing -> return $ ConstantOperand (Float (Double 1))
191
192   nextVar <- fadd i stepV `named` "nextvar"
193
194   let zero = ConstantOperand (Float (Double 0))
195   -- again we need 'i' in the bindings
196   condV <- withReaderT (Map.insert name i) $
197             (buildExpr cond >>= fcmp ONE zero) `named` "cond"
198   condBr condV loopB afterB
199
200   afterB <- block `named` "after"
201   -- since a for loop doesn't really have a value, return 0
202   return $ ConstantOperand (Float (Double 0))
203