From: Luke Lau Date: Sun, 10 Mar 2019 14:34:09 +0000 (+0000) Subject: Update AST to match Kaleidoscope more closely X-Git-Url: https://git.lukelau.me/?p=kaleidoscope-hs-old.git;a=commitdiff_plain;h=808a6ae35fb1f2ff61a84991af1e7722298a4d62 Update AST to match Kaleidoscope more closely --- diff --git a/AST.hs b/AST.hs index 2bea002..1965ae6 100644 --- a/AST.hs +++ b/AST.hs @@ -4,6 +4,15 @@ import Data.Char import Text.Read import Text.ParserCombinators.ReadP hiding ((+++), choice) +newtype Program = Program [AST] + deriving Show + +instance Read Program where + readPrec = fmap Program $ lift $ sepBy1 (readPrec_to_P readPrec 0) $ do + skipSpaces + char ';' + skipSpaces + data AST = Function String [String] Expr | Eval Expr deriving Show @@ -25,9 +34,7 @@ instance Read AST where params <- between (char '(') (char ')') $ sepBy (munch1 isAlpha) skipSpaces skipSpaces - body <- between (char '{') (char '}') $ - readS_to_P reads - skipSpaces + body <- readS_to_P reads return (Function name params body) instance Read Expr where diff --git a/Main.hs b/Main.hs index 8f4610a..fcfe5df 100644 --- a/Main.hs +++ b/Main.hs @@ -3,27 +3,28 @@ module Main where import qualified AST +import qualified Data.Map as Map import qualified Data.Text.Lazy.IO as Text +import Data.String import Foreign.Ptr import System.IO import LLVM.Context -import LLVM.CodeModel import LLVM.ExecutionEngine import LLVM.Module import LLVM.IRBuilder +import LLVM.AST.AddrSpace import LLVM.AST.Constant import LLVM.AST.Float import LLVM.AST.Operand -import LLVM.AST.Type +import LLVM.AST.Type as Type import LLVM.Pretty foreign import ccall "dynamic" exprFun :: FunPtr (IO Float) -> IO Float main :: IO () main = do - ast <- read <$> getContents - let mdl = buildModule "main" $ - function "expr" [] float $ \_ -> build ast >>= ret + program <- read <$> getContents + let mdl = buildModule "main" $ mapM buildAST program Text.hPutStrLn stderr (ppllvm mdl) withContext $ \ctx -> withMCJIT ctx Nothing Nothing Nothing Nothing $ \mcjit -> @@ -33,11 +34,32 @@ main = do let f' = castFunPtr f :: FunPtr (IO Float) exprFun f' >>= print -build :: AST.Expr -> IRBuilderT ModuleBuilder Operand -build (AST.Num a) = pure $ ConstantOperand (Float (Single a)) -build (AST.BinOp op a b) = do - va <- build a - vb <- build b +buildAST :: AST.AST -> ModuleBuilder Operand +buildAST (AST.Function nameStr paramStrs body) = do + let name = fromString nameStr + function name params float $ \binds -> do + let bindMap = Map.fromList (zip paramStrs binds) + buildExpr bindMap body >>= ret + where params = zip (repeat float) (map fromString paramStrs) +buildAST (AST.Eval e) = + function "expr" [] float $ \_ -> buildExpr mempty e >>= ret + +buildExpr :: Map.Map String Operand -> AST.Expr -> IRBuilderT ModuleBuilder Operand +buildExpr _ (AST.Num a) = pure $ ConstantOperand (Float (Single a)) +buildExpr binds (AST.Var name) = pure $ binds Map.! name + +buildExpr binds (AST.Call nameStr params) = do + paramOps <- mapM (buildExpr binds) params + let name = fromString nameStr + -- get a pointer to the function + typ = FunctionType float (replicate (length params) float) False + ptrTyp = Type.PointerType typ (AddrSpace 0) + ref = GlobalReference ptrTyp name + call (ConstantOperand ref) (zip paramOps (repeat [])) + +buildExpr binds (AST.BinOp op a b) = do + va <- buildExpr binds a + vb <- buildExpr binds b let instr = case op of AST.Add -> fadd AST.Sub -> fsub