X-Git-Url: https://git.lukelau.me/?p=kaleidoscope-hs-old.git;a=blobdiff_plain;f=Main.hs;h=7fa4499024c94db07b3264d0a08964af7b194c67;hp=eba1ce1c7ae6e8072b4e78862c2daf750280bdd4;hb=d9a6be382ca58e6d1c4ed988856ccbdf76a3bcdf;hpb=e9d3bf0386f5654857fe6eaabe74ca651f9df76d diff --git a/Main.hs b/Main.hs index eba1ce1..7fa4499 100644 --- a/Main.hs +++ b/Main.hs @@ -13,7 +13,8 @@ import Foreign.Ptr import System.Exit import System.IO import LLVM.Context -import LLVM.ExecutionEngine +import LLVM.OrcJIT +import LLVM.OrcJIT.CompileLayer import LLVM.Module import LLVM.PassManager import LLVM.IRBuilder @@ -25,43 +26,69 @@ import LLVM.AST.Operand import LLVM.AST.Type as Type import LLVM.AST.Typed import LLVM.Pretty +import LLVM.Linking +import LLVM.Target + +import Control.Concurrent.MVar type ModuleBuilderE = ModuleBuilderT (Either String) -foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float +foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double main :: IO () main = do AST.Program asts <- read <$> getContents let eitherMdl = buildModuleT "main" $ mapM buildAST asts + case eitherMdl of Left err -> die err - Right mdl -> withContext $ \ctx -> do + Right mdl' -> withContext $ \ctx -> + withHostTargetMachine $ \tm -> do hPutStrLn stderr "Before optimisation:" - Text.hPutStrLn stderr (ppllvm mdl) - withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> - withModuleFromAST ctx mdl $ \mdl' -> - withPassManager defaultCuratedPassSetSpec $ \pm -> do - runPassManager pm mdl' >>= guard + Text.hPutStrLn stderr (ppllvm mdl') + + withModuleFromAST ctx mdl' $ \mdl -> do + let spec = defaultCuratedPassSetSpec { optLevel = Just 3 } + withPassManager spec $ flip runPassManager mdl hPutStrLn stderr "After optimisation:" - Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl' - withModuleInEngine mcjit mdl' $ \emdl -> do - Just f <- getFunction emdl "expr" - let f' = castFunPtr f :: FunPtr (IO Float) - exprFun f' >>= print + Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl + jit tm mdl >>= print + +jit :: TargetMachine -> Module -> IO Double +jit tm mdl = do + loadLibraryPermanently (Just "stdlib.dylib") >>= guard . not + compLayerVar <- newEmptyMVar + + -- jit time + withExecutionSession $ \exSession -> + withSymbolResolver exSession (SymbolResolver (symResolver compLayerVar)) $ \symResolverPtr -> + withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer -> + withModuleKey exSession $ \mdlKey -> + withIRCompileLayer linkingLayer tm $ \compLayer -> do + putMVar compLayerVar compLayer + + withModule compLayer mdlKey mdl $ do + mangled <- mangleSymbol compLayer "expr" + Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False + mkFun (castPtrToFunPtr (wordPtrToPtr fPtr)) -evalProg :: AST.Program -> IO (Maybe Float) + where symResolver clv sym = do + cl <- readMVar clv + ms <- findSymbol cl sym False + case ms of + Right s -> return (return s) + _ -> do + addr <- getSymbolAddressInProcess sym + return $ return (JITSymbol addr (JITSymbolFlags False False True True)) + +evalProg :: AST.Program -> IO (Maybe Double) evalProg (AST.Program asts) = do let eitherMdl = buildModuleT "main" $ mapM buildAST asts case eitherMdl of Left _ -> return Nothing - Right mdl -> withContext $ \ctx -> - withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> - withModuleFromAST ctx mdl $ \mdl' -> - withModuleInEngine mcjit mdl' $ \emdl -> do - Just f <- getFunction emdl "expr" - let f' = castFunPtr f :: FunPtr (IO Float) - Just <$> exprFun f' + Right mdl' -> withContext $ \ctx -> + withHostTargetMachine $ \tm -> + withModuleFromAST ctx mdl' (fmap Just . jit tm) -- | Builds up programs at the top-level of an LLVM Module -- >>> evalProg (read "31 - 5") @@ -69,12 +96,14 @@ evalProg (AST.Program asts) = do buildAST :: AST.AST -> ModuleBuilderE Operand buildAST (AST.Function nameStr paramStrs body) = do let n = fromString nameStr - function n params float $ \binds -> do + function n params Type.double $ \binds -> do let bindMap = Map.fromList (zip paramStrs binds) buildExpr bindMap body >>= ret - where params = zip (repeat float) (map fromString paramStrs) + where params = zip (repeat Type.double) (map fromString paramStrs) +buildAST (AST.Extern nameStr params) = + extern (fromString nameStr) (replicate (length params) Type.double) Type.double buildAST (AST.Eval e) = - function "expr" [] float $ \_ -> buildExpr mempty e >>= ret + function "expr" [] Type.double $ \_ -> buildExpr mempty e >>= ret -- | Builds up expressions, which are operands in LLVM IR -- >>> evalProg (read "def foo(x) x * 2; foo(6)") @@ -82,7 +111,7 @@ buildAST (AST.Eval e) = -- >>> evalProg (read "if 3 > 2 then 42 else 12") -- Just 42.0 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand -buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a)) +buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Double a)) buildExpr binds (AST.Var n) = case binds Map.!? n of Just x -> pure x Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope" @@ -91,7 +120,7 @@ buildExpr binds (AST.Call nameStr params) = do paramOps <- mapM (buildExpr binds) params let name = fromString nameStr -- get a pointer to the function - typ = FunctionType float (replicate (length params) float) False + typ = FunctionType Type.double (replicate (length params) Type.double) False ptrTyp = Type.PointerType typ (AddrSpace 0) ref = GlobalReference ptrTyp name call (ConstantOperand ref) (zip paramOps (repeat [])) @@ -110,7 +139,6 @@ buildExpr binds (AST.BinOp op a b) = do buildExpr binds (AST.If cond thenE elseE) = mdo _ifB <- block `named` "if" - condV <- buildExpr binds cond when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean" condBr condV thenB elseB @@ -125,3 +153,33 @@ buildExpr binds (AST.If cond thenE elseE) = mdo mergeB <- block `named` "ifcont" phi [(thenOp, thenB), (elseOp, elseB)] + +buildExpr binds (AST.For ident start cond mStep body) = mdo + startV <- buildExpr binds start + + preheaderB <- block `named` "preheader" + + br loopB + + loopB <- block `named` "loop" + + i <- phi [(startV, preheaderB), (nextVar, loopB)] `named` "i" + + let newBinds = Map.insert ident i binds + + buildExpr newBinds body `named` "body" + + stepV <- case mStep of + Just step -> buildExpr newBinds step + Nothing -> pure $ ConstantOperand (Float (Double 1)) + + nextVar <- fadd i stepV `named` "nextvar" + + condV <- buildExpr newBinds cond `named` "cond" + + condBr condV loopB afterB + + afterB <- block `named` "after" + + return (ConstantOperand (Float (Double 0))) +