Find our JIT'ed function and run it
[kaleidoscope-hs.git] / Main.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE RecursiveDo #-}
3
4 import AST as K -- K for Kaleidoscope
5 import Utils
6 import Control.Monad
7 import Control.Monad.Trans.Class
8 import Control.Monad.Trans.Reader
9 import Control.Monad.IO.Class
10 import Data.String
11 import qualified Data.Map as Map
12 import qualified Data.Text.Lazy.IO as Text
13 import Foreign.Ptr
14 import LLVM.AST.AddrSpace
15 import LLVM.AST.Constant
16 import LLVM.AST.Float
17 import LLVM.AST.FloatingPointPredicate hiding (False, True)
18 import LLVM.AST.Operand
19 import LLVM.AST.Type as Type
20 import LLVM.Context
21 import LLVM.IRBuilder
22 import LLVM.Linking
23 import LLVM.Module
24 import LLVM.OrcJIT
25 import LLVM.OrcJIT.CompileLayer
26 import LLVM.PassManager
27 import LLVM.Pretty
28 import LLVM.Target
29 import Numeric
30 import System.IO
31 import System.IO.Error
32 import Text.Read (readMaybe)
33
34 foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double
35
36 data JITEnv = JITEnv
37   { jitEnvContext :: Context
38   , jitEnvCompileLayer :: IRCompileLayer ObjectLinkingLayer
39   , jitEnvModuleKey :: ModuleKey
40   }
41
42 main :: IO ()
43 main = do
44   loadLibraryPermanently (Just "stdlib.dylib")
45   withContext $ \ctx -> withHostTargetMachineDefault $ \tm ->
46     withExecutionSession $ \exSession ->
47       withSymbolResolver exSession (SymbolResolver symResolver) $ \symResolverPtr ->
48         withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer ->
49           withIRCompileLayer linkingLayer tm $ \compLayer ->
50             withModuleKey exSession $ \mdlKey -> do
51               let env = JITEnv ctx compLayer mdlKey
52               _ast <- runReaderT (buildModuleT "main" repl) env
53               return ()
54
55 -- This can eventually be used to resolve external functions, e.g. a stdlib call
56 symResolver :: MangledSymbol -> IO (Either JITSymbolError JITSymbol)
57 symResolver sym = do
58   ptr <- getSymbolAddressInProcess sym
59   putStrLn $ "Resolving " <> show sym <> " to 0x" <> showHex ptr ""
60   return (Right (JITSymbol ptr defaultJITSymbolFlags))
61
62 repl :: ModuleBuilderT (ReaderT JITEnv IO) ()
63 repl = do
64   liftIO $ hPutStr stderr "ready> "
65   mline <- liftIO $ catchIOError (Just <$> getLine) eofHandler
66   case mline of
67     Nothing -> return ()
68     Just l -> do
69       case readMaybe l of
70         Nothing ->  liftIO $ hPutStrLn stderr "Couldn't parse"
71         Just ast -> do
72           anon <- isAnonExpr <$> hoist (buildAST ast)
73           def <- mostRecentDef
74           
75           llvmAst <- moduleSoFar "main"
76           ctx <- lift $ asks jitEnvContext
77           env <- lift ask
78           liftIO $ withModuleFromAST ctx llvmAst $ \mdl -> do
79             Text.hPutStrLn stderr $ ppll def
80             let spec = defaultCuratedPassSetSpec { optLevel = Just 3 }
81             -- this returns true if the module was modified
82             withPassManager spec $ flip runPassManager mdl
83             when anon (jit env mdl >>= hPrint stderr)
84
85           when anon (removeDef def)
86       repl
87   where
88     eofHandler e
89       | isEOFError e = return Nothing
90       | otherwise = ioError e
91     isAnonExpr (ConstantOperand (GlobalReference _ "__anon_expr")) = True
92     isAnonExpr _ = False
93
94 jit :: JITEnv -> Module -> IO Double
95 jit JITEnv{jitEnvCompileLayer=compLayer, jitEnvModuleKey=mdlKey} mdl =
96   withModule compLayer mdlKey mdl $ do
97     mangled <- mangleSymbol compLayer "__anon_expr"
98     Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False
99     mkFun (castPtrToFunPtr (wordPtrToPtr fPtr))
100
101 type Binds = Map.Map String Operand
102
103 buildAST :: AST -> ModuleBuilder Operand
104 buildAST (Function (Prototype nameStr paramStrs) body) = do
105   let n = fromString nameStr
106   function n params Type.double $ \ops -> do
107     let binds = Map.fromList (zip paramStrs ops)
108     flip runReaderT binds $ buildExpr body >>= ret
109   where params = zip (repeat Type.double) (map fromString paramStrs)
110
111 buildAST (Extern (Prototype nameStr params)) =
112   extern (fromString nameStr) (replicate (length params) Type.double) Type.double
113
114 buildAST (TopLevelExpr x) = function "__anon_expr" [] Type.double $
115   const $ flip runReaderT mempty $ buildExpr x >>= ret
116
117 buildExpr :: Expr -> ReaderT Binds (IRBuilderT ModuleBuilder) Operand
118 buildExpr (Num x) = pure $ ConstantOperand (Float (Double x))
119 buildExpr (Var n) = do
120   binds <- ask
121   case binds Map.!? n of
122     Just x -> pure x
123     Nothing -> error $ "'" <> n <> "' doesn't exist in scope"
124
125 buildExpr (BinOp op a b) = do
126   opA <- buildExpr a
127   opB <- buildExpr b
128   tmp <- instr opA opB
129   if isCmp
130     then uitofp tmp Type.double
131     else return tmp
132   where isCmp
133           | Cmp _ <- op = True
134           | otherwise = False
135         instr = case op of
136                   K.Add -> fadd
137                   K.Sub -> fsub
138                   K.Mul -> fmul
139                   K.Cmp LT -> fcmp OLT
140                   K.Cmp GT -> fcmp OGT
141                   K.Cmp EQ -> fcmp OEQ
142
143 buildExpr (Call callee params) = do
144   paramOps <- mapM buildExpr params
145   let nam = fromString callee
146       -- get a pointer to the function
147       typ = FunctionType Type.double (replicate (length params) Type.double) False
148       ptrTyp = Type.PointerType typ (AddrSpace 0)
149       ref = GlobalReference ptrTyp nam
150   call (ConstantOperand ref) (zip paramOps (repeat []))
151
152 buildExpr (If cond thenE elseE) = mdo
153   _ifB <- block `named` "if"
154
155   -- since everything is a double, false == 0
156   let zero = ConstantOperand (Float (Double 0))
157   condV <- buildExpr cond
158   cmp <- fcmp ONE zero condV `named` "cmp"
159
160   condBr cmp thenB elseB
161
162   thenB <- block `named` "then"
163   thenOp <- buildExpr thenE
164   br mergeB
165
166   elseB <- block `named` "else"
167   elseOp <- buildExpr elseE
168   br mergeB
169
170   mergeB <- block `named` "ifcont"
171   phi [(thenOp, thenB), (elseOp, elseB)]
172
173 buildExpr (For name init cond mStep body) = mdo
174   preheaderB <- block `named` "preheader"
175
176   initV <- buildExpr init `named` "init"
177   
178   -- build the condition expression with 'i' in the bindings
179   initCondV <- withReaderT (Map.insert name initV) $
180                 (buildExpr cond >>= fcmp ONE zero) `named` "initcond"
181
182   -- skip the loop if we don't meet the condition with the init
183   condBr initCondV loopB afterB
184
185   loopB <- block `named` "loop"
186   i <- phi [(initV, preheaderB), (nextVar, loopB)] `named` "i"
187
188   -- build the body expression with 'i' in the bindings
189   withReaderT (Map.insert name i) $ buildExpr body `named` "body"
190
191   -- default to 1 if there's no step defined
192   stepV <- case mStep of
193     Just step -> buildExpr step
194     Nothing -> return $ ConstantOperand (Float (Double 1))
195
196   nextVar <- fadd i stepV `named` "nextvar"
197
198   let zero = ConstantOperand (Float (Double 0))
199   -- again we need 'i' in the bindings
200   condV <- withReaderT (Map.insert name i) $
201             (buildExpr cond >>= fcmp ONE zero) `named` "cond"
202   condBr condV loopB afterB
203
204   afterB <- block `named` "after"
205   -- since a for loop doesn't really have a value, return 0
206   return $ ConstantOperand (Float (Double 0))
207