X-Git-Url: https://git.lukelau.me/?p=kaleidoscope-hs-old.git;a=blobdiff_plain;f=Main.hs;h=f096ea6fa5177b8deb5671c570cb97ba79f37192;hp=37c6cbc788a9849ef721995f58b9c1191ba0385d;hb=c67238064a9570e5c22566413b68906c9fcf39fe;hpb=4dc076eed8f88df110e2682bb24e208d19816790 diff --git a/Main.hs b/Main.hs index 37c6cbc..f096ea6 100644 --- a/Main.hs +++ b/Main.hs @@ -1,40 +1,155 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecursiveDo #-} module Main where import qualified AST +import Control.Monad +import Control.Monad.Trans.Class +import qualified Data.Map as Map import qualified Data.Text.Lazy.IO as Text +import Data.String import Foreign.Ptr +import System.Exit +import System.IO import LLVM.Context -import LLVM.CodeModel -import LLVM.ExecutionEngine +import LLVM.OrcJIT +import LLVM.OrcJIT.CompileLayer import LLVM.Module +import LLVM.PassManager import LLVM.IRBuilder +import LLVM.AST.AddrSpace import LLVM.AST.Constant import LLVM.AST.Float +import LLVM.AST.FloatingPointPredicate hiding (False, True) import LLVM.AST.Operand -import LLVM.AST.Type +import LLVM.AST.Type as Type +import LLVM.AST.Typed import LLVM.Pretty +import LLVM.Linking +import LLVM.Target -foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float +import Control.Concurrent.MVar + +type ModuleBuilderE = ModuleBuilderT (Either String) + +foreign import ccall "dynamic" mkFun :: FunPtr (IO Double) -> IO Double main :: IO () main = do - ast <- read <$> getContents - let mdl = buildModule "main" $ - function "expr" [] float $ \_ -> build ast >>= ret - Text.putStrLn (ppllvm mdl) - withContext $ \ctx -> - withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> - withModuleFromAST ctx mdl $ \mdl' -> - withModuleInEngine mcjit mdl' $ \emdl -> do - Just f <- getFunction emdl "expr" - let f' = castFunPtr f :: FunPtr (IO Float) - exprFun f' >>= print - -build :: AST.Expr -> IRBuilderT ModuleBuilder Operand -build (AST.Num a) = pure $ ConstantOperand (Float (Single a)) -build (AST.Add a b) = do - va <- build a - vb <- build b - fadd va vb + AST.Program asts <- read <$> getContents + let eitherMdl = buildModuleT "main" $ mapM buildAST asts + + case eitherMdl of + Left err -> die err + Right mdl' -> withContext $ \ctx -> + withHostTargetMachine $ \tm -> do + -- hPutStrLn stderr "Before optimisation:" + -- Text.hPutStrLn stderr (ppllvm mdl') + + withModuleFromAST ctx mdl' $ \mdl -> do + let spec = defaultCuratedPassSetSpec { optLevel = Just 3 } + withPassManager spec $ flip runPassManager mdl + -- hPutStrLn stderr "After optimisation:" + -- Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl + jit tm mdl >>= print + +jit :: TargetMachine -> Module -> IO Double +jit tm mdl = do + loadLibraryPermanently (Just "stdlib.dylib") >>= guard . not + compLayerVar <- newEmptyMVar + + -- jit time + withExecutionSession $ \exSession -> + withSymbolResolver exSession (SymbolResolver (symResolver compLayerVar)) $ \symResolverPtr -> + withObjectLinkingLayer exSession (const $ pure symResolverPtr) $ \linkingLayer -> + withModuleKey exSession $ \mdlKey -> + withIRCompileLayer linkingLayer tm $ \compLayer -> do + putMVar compLayerVar compLayer + + withModule compLayer mdlKey mdl $ do + mangled <- mangleSymbol compLayer "expr" + Right (JITSymbol fPtr _) <- findSymbolIn compLayer mdlKey mangled False + mkFun (castPtrToFunPtr (wordPtrToPtr fPtr)) + + where symResolver clv sym = do + cl <- readMVar clv + ms <- findSymbol cl sym False + case ms of + Right s -> return (return s) + _ -> do + addr <- getSymbolAddressInProcess sym + return $ return (JITSymbol addr (JITSymbolFlags False False True True)) + +evalProg :: AST.Program -> IO (Maybe Double) +evalProg (AST.Program asts) = do + let eitherMdl = buildModuleT "main" $ mapM buildAST asts + case eitherMdl of + Left _ -> return Nothing + Right mdl' -> withContext $ \ctx -> + withHostTargetMachine $ \tm -> + withModuleFromAST ctx mdl' (fmap Just . jit tm) + +-- | Builds up programs at the top-level of an LLVM Module +-- >>> evalProg (read "31 - 5") +-- Just 26.0 +buildAST :: AST.AST -> ModuleBuilderE Operand +buildAST (AST.Function nameStr paramStrs body) = do + let n = fromString nameStr + function n params Type.double $ \binds -> do + let bindMap = Map.fromList (zip paramStrs binds) + buildExpr bindMap body >>= ret + where params = zip (repeat Type.double) (map fromString paramStrs) +buildAST (AST.Extern nameStr params) = + extern (fromString nameStr) (replicate (length params) Type.double) Type.double +buildAST (AST.Eval e) = + function "expr" [] Type.double $ \_ -> buildExpr mempty e >>= ret + +-- | Builds up expressions, which are operands in LLVM IR +-- >>> evalProg (read "def foo(x) x * 2; foo(6)") +-- Just 12.0 +-- >>> evalProg (read "if 3 > 2 then 42 else 12") +-- Just 42.0 +buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand +buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Double a)) +buildExpr binds (AST.Var n) = case binds Map.!? n of + Just x -> pure x + Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope" + +buildExpr binds (AST.Call nameStr params) = do + paramOps <- mapM (buildExpr binds) params + let name = fromString nameStr + -- get a pointer to the function + typ = FunctionType Type.double (replicate (length params) Type.double) False + ptrTyp = Type.PointerType typ (AddrSpace 0) + ref = GlobalReference ptrTyp name + call (ConstantOperand ref) (zip paramOps (repeat [])) + +buildExpr binds (AST.BinOp op a b) = do + va <- buildExpr binds a + vb <- buildExpr binds b + let instr = case op of + AST.Add -> fadd + AST.Sub -> fsub + AST.Mul -> fmul + AST.Cmp GT -> fcmp OGT + AST.Cmp LT -> fcmp OLT + AST.Cmp EQ -> fcmp OEQ + instr va vb + +buildExpr binds (AST.If cond thenE elseE) = mdo + _ifB <- block `named` "if" + condV <- buildExpr binds cond + when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean" + condBr condV thenB elseB + + thenB <- block `named` "then" + thenOp <- buildExpr binds thenE + br mergeB + + elseB <- block `named` "else" + elseOp <- buildExpr binds elseE + br mergeB + + mergeB <- block `named` "ifcont" + phi [(thenOp, thenB), (elseOp, elseB)]