X-Git-Url: https://git.lukelau.me/?p=kaleidoscope-hs-old.git;a=blobdiff_plain;f=Main.hs;h=4d8275828510ceed0bbcc99f3b2b95798042b2ae;hp=328aaf3c72c56d9b1d13f4f1de4376cf07b5681d;hb=749e5a29af22fc74b8c597485de9be6485ccc62f;hpb=f250f7c0e621f1f94b6d2377f6a634314f306ace diff --git a/Main.hs b/Main.hs index 328aaf3..4d82758 100644 --- a/Main.hs +++ b/Main.hs @@ -26,11 +26,9 @@ import LLVM.AST.Type as Type import LLVM.AST.Typed import LLVM.Pretty -import Debug.Trace - type ModuleBuilderE = ModuleBuilderT (Either String) -foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float +foreign import ccall "dynamic" exprFun :: FunPtr (IO Double) -> IO Double main :: IO () main = do @@ -42,28 +40,53 @@ main = do hPutStrLn stderr "Before optimisation:" Text.hPutStrLn stderr (ppllvm mdl) withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> - withModuleFromAST ctx mdl $ \mdl' -> - withPassManager defaultCuratedPassSetSpec $ \pm -> do - runPassManager pm mdl' >>= guard + withModuleFromAST ctx mdl $ \mdl' -> do + -- withPassManager defaultCuratedPassSetSpec $ \pm -> do + -- runPassManager pm mdl' >>= guard hPutStrLn stderr "After optimisation:" Text.hPutStrLn stderr . ppllvm =<< moduleAST mdl' withModuleInEngine mcjit mdl' $ \emdl -> do Just f <- getFunction emdl "expr" - let f' = castFunPtr f :: FunPtr (IO Float) + let f' = castFunPtr f :: FunPtr (IO Double) exprFun f' >>= print +evalProg :: AST.Program -> IO (Maybe Double) +evalProg (AST.Program asts) = do + let eitherMdl = buildModuleT "main" $ mapM buildAST asts + case eitherMdl of + Left _ -> return Nothing + Right mdl -> withContext $ \ctx -> + withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> + withModuleFromAST ctx mdl $ \mdl' -> + withModuleInEngine mcjit mdl' $ \emdl -> do + Just f <- getFunction emdl "expr" + let f' = castFunPtr f :: FunPtr (IO Double) + Just <$> exprFun f' + +-- | Builds up programs at the top-level of an LLVM Module +-- >>> evalProg (read "31 - 5") +-- Just 26.0 +-- >>> evalProg (read "extern pow(x e); pow(3,2)") +-- Just 9.0 buildAST :: AST.AST -> ModuleBuilderE Operand buildAST (AST.Function nameStr paramStrs body) = do let n = fromString nameStr - function n params float $ \binds -> do + function n params Type.double $ \binds -> do let bindMap = Map.fromList (zip paramStrs binds) buildExpr bindMap body >>= ret - where params = zip (repeat float) (map fromString paramStrs) + where params = zip (repeat Type.double) (map fromString paramStrs) +buildAST (AST.Extern nameStr params) = + extern (fromString nameStr) (replicate (length params) Type.double) Type.double buildAST (AST.Eval e) = - function "expr" [] float $ \_ -> buildExpr mempty e >>= ret + function "expr" [] Type.double $ \_ -> buildExpr mempty e >>= ret +-- | Builds up expressions, which are operands in LLVM IR +-- >>> evalProg (read "def foo(x) x * 2; foo(6)") +-- Just 12.0 +-- >>> evalProg (read "if 3 > 2 then 42 else 12") +-- Just 42.0 buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilderE Operand -buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a)) +buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Double a)) buildExpr binds (AST.Var n) = case binds Map.!? n of Just x -> pure x Nothing -> lift $ lift $ Left $ "'" <> n <> "' doesn't exist in scope" @@ -72,7 +95,7 @@ buildExpr binds (AST.Call nameStr params) = do paramOps <- mapM (buildExpr binds) params let name = fromString nameStr -- get a pointer to the function - typ = FunctionType float (replicate (length params) float) False + typ = FunctionType Type.double (replicate (length params) Type.double) False ptrTyp = Type.PointerType typ (AddrSpace 0) ref = GlobalReference ptrTyp name call (ConstantOperand ref) (zip paramOps (repeat [])) @@ -91,7 +114,6 @@ buildExpr binds (AST.BinOp op a b) = do buildExpr binds (AST.If cond thenE elseE) = mdo _ifB <- block `named` "if" - condV <- buildExpr binds cond when (typeOf condV /= i1) $ lift $ lift $ Left "Not a boolean" condBr condV thenB elseB @@ -105,4 +127,4 @@ buildExpr binds (AST.If cond thenE elseE) = mdo br mergeB mergeB <- block `named` "ifcont" - traceShowId <$> phi [(thenOp, thenB), (elseOp, elseB)] + phi [(thenOp, thenB), (elseOp, elseB)]