{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}

module Main where

import qualified AST
import Control.Monad
import Control.Monad.Trans.Class
import qualified Data.Map as Map
import qualified Data.Text.Lazy.IO as Text
import Data.String
import Foreign.Ptr
import System.Exit
import System.IO
import Text.Read (readMaybe)

import LLVM.Context
import LLVM.ExecutionEngine
import LLVM.Module
import LLVM.PassManager
import LLVM.IRBuilder
import LLVM.AST.AddrSpace
import LLVM.AST.Constant
import LLVM.AST.Float
import LLVM.AST.FloatingPointPredicate hiding (False, True)
import LLVM.AST.Operand
import LLVM.AST.Type as Type
import LLVM.Pretty

import Debug.Trace

-- | Module builder whose failures (e.g. an unbound variable) surface as a
-- 'Left' error message instead of an exception.
type ModuleBuilderE = ModuleBuilderT (Either String)

-- | Converts a pointer to a JIT-compiled nullary function returning @float@
-- into a callable Haskell action.
foreign import ccall "dynamic"
  exprFun :: FunPtr (IO Float) -> IO Float

-- | Reads an 'AST.Program' (in its 'Read' syntax) from stdin and builds an
-- LLVM module from it. It prints the IR to stderr before and after running
-- the curated pass set, then JIT-compiles the module and prints the result
-- of calling its @expr@ function.
main :: IO ()
main = do
  input <- getContents
  AST.Program asts <-
    maybe (die "could not parse program from stdin") pure (readMaybe input)
  case buildModuleT "main" (mapM_ buildAST asts) of
    Left err -> die err
    Right mdl -> withContext $ \ctx -> do
      hPutStrLn stderr "Before optimisation:"
      Text.hPutStrLn stderr (ppllvm mdl)
      withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit ->
        withModuleFromAST ctx mdl $ \mdl' ->
          withPassManager defaultCuratedPassSetSpec $ \pm -> do
            -- runPassManager reports whether the module was modified; an
            -- unmodified module is not an error, so the flag is ignored
            -- (using 'guard' here would throw on an already-optimal module).
            void (runPassManager pm mdl')
            hPutStrLn stderr "After optimisation:"
            Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl'
            withModuleInEngine mcjit mdl' $ \emdl -> do
              mf <- getFunction emdl "expr"
              case mf of
                Nothing -> die "no top-level expression to evaluate"
                Just f ->
                  exprFun (castFunPtr f :: FunPtr (IO Float)) >>= print

-- | Emits the LLVM definition of one top-level item.
--
-- * 'AST.Function' becomes a function whose parameters and result are all
--   @float@.
-- * 'AST.Eval' becomes the nullary @expr@ function that 'main' calls.
buildAST :: AST.AST -> ModuleBuilderE Operand
buildAST (AST.Function nameStr paramStrs body) =
  function (fromString nameStr) params float $ \binds -> do
    let bindMap = Map.fromList (zip paramStrs binds)
    buildExpr bindMap body >>= ret
  where
    params = zip (repeat float) (map fromString paramStrs)
buildAST (AST.Eval e) =
  function "expr" [] float $ \_ -> buildExpr mempty e >>= ret

-- | Emits instructions computing an expression into the current function and
-- returns the operand holding its value.
--
-- * Every value has LLVM type @float@.
-- * A comparison yields 1.0 when true and 0.0 when false.
-- * An 'AST.If' takes its @then@ branch when the condition is ordered and
--   not equal to 0.0; NaN selects the @else@ branch.
-- * A variable missing from the binding map aborts module construction with
--   a 'Left'.
buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand
buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a))
buildExpr binds (AST.Var n) =
  case binds Map.!? n of
    Just x -> pure x
    Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope"
buildExpr binds (AST.Call nameStr params) = do
  paramOps <- mapM (buildExpr binds) params
  let name = fromString nameStr
      -- Get a pointer to the function. Callees are assumed to have the
      -- all-float signature that 'buildAST' gives every function.
      typ = FunctionType float (replicate (length params) float) False
      ptrTyp = Type.PointerType typ (AddrSpace 0)
      ref = GlobalReference ptrTyp name
  call (ConstantOperand ref) (zip paramOps (repeat []))
buildExpr binds (AST.BinOp op a b) = do
  va <- buildExpr binds a
  vb <- buildExpr binds b
  case op of
    AST.Add -> fadd va vb
    AST.Sub -> fsub va vb
    AST.Mul -> fmul va vb
    AST.Cmp ord -> do
      -- fcmp produces an i1; widen it so comparisons are ordinary float
      -- values usable in arithmetic, calls and return instructions.
      flag <- fcmp (predicate ord) va vb
      uitofp flag float
  where
    predicate :: Ordering -> FloatingPointPredicate
    predicate GT = OGT
    predicate LT = OLT
    predicate EQ = OEQ
buildExpr binds (AST.If cond thenE elseE) = mdo
  -- Evaluate the condition in the current block. Opening a new block here
  -- would leave the current one unterminated, and the IR builder closes
  -- such blocks with 'ret void'.
  condV <- buildExpr binds cond
  -- Conditions are floats; compare against 0.0 to get the i1 condBr needs.
  test <- fcmp ONE condV (ConstantOperand (Float (Single 0)))
  condBr test thenB elseB

  thenB <- block `named` "then"
  thenOp <- buildExpr binds thenE
  -- A nested If moves emission into a later block, so the phi must name
  -- the block that actually branches to the merge block.
  thenEnd <- currentBlock
  br mergeB

  elseB <- block `named` "else"
  elseOp <- buildExpr binds elseE
  elseEnd <- currentBlock
  br mergeB

  mergeB <- block `named` "ifcont"
  phi [(thenOp, thenEnd), (elseOp, elseEnd)]