Initial commit
[kaleidoscope.git] / Kaleidoscope / lexer.cpp
diff --git a/Kaleidoscope/lexer.cpp b/Kaleidoscope/lexer.cpp
new file mode 100644 (file)
index 0000000..ac75490
--- /dev/null
@@ -0,0 +1,352 @@
+#include <cstdio>
+#include <cctype>
+#include <cstdlib>
+#include <string>
+#include <map>
+#include <iostream>
+#include <iomanip>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/Support/raw_ostream.h>
+#include <llvm/ExecutionEngine/ExecutionEngine.h>
+#include <llvm/ExecutionEngine/GenericValue.h>
+#include <llvm/Support/TargetSelect.h>
+//Needed to force linking of the MCJIT execution engine
+#include <llvm/ExecutionEngine/MCJIT.h>
+#include "lexer.h"
+#include "ast.hpp"
+#include "codegen.hpp"
+#include "shared.h"
+
+static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+                                                                                         std::unique_ptr<ExprAST> LHS);
+static std::unique_ptr<ExprAST> ParseExpression();
+
+static ExecutionEngine* TheEngine;
+
+std::unique_ptr<Module> TheModule;
+
+constexpr auto ANON_EXPR_FUNC_NAME = "__anon_expr";
+
+// Token codes produced by gettok(). Every named token is negative; any other
+// input character is returned by gettok() as its own character value.
+enum Token {
+       // end of input
+       tok_eof = -1,
+
+       // keywords
+       tok_def = -2,
+       tok_extern = -3,
+
+       // primary tokens
+       tok_identifier = -4,
+       tok_number = -5,
+};
+
+// Text of the most recently lexed word (identifier or keyword).
+static std::string IdentifierStr;
+// Value of the most recently lexed number.
+static double NumVal;
+
+///Returns the next token from stdin
+static int gettok() {
+       static int LastChar = ' ';
+
+       while(isspace(LastChar))
+               LastChar = getchar();
+
+       if(isalpha(LastChar)) {
+               IdentifierStr = LastChar;
+               while(isalnum((LastChar = getchar())))
+                       IdentifierStr += LastChar;
+
+               if(IdentifierStr == "def")
+                       return tok_def;
+               if(IdentifierStr == "extern")
+                       return tok_extern;
+               return tok_identifier;
+       }
+
+       if(isdigit(LastChar) || LastChar == '.') {
+               std::string NumStr;
+               do {
+                       NumStr += LastChar;
+                       LastChar = getchar();
+               } while (isdigit(LastChar) || LastChar == '.');
+
+               NumVal = strtod(NumStr.c_str(), 0);
+               return tok_number;
+       }
+
+       if(LastChar == '#') {
+               //Coment until the end of the line
+               do
+                       LastChar = getchar();
+               while(LastChar != EOF && LastChar != '\n' && LastChar != 'r');
+
+               if(LastChar != EOF)
+                       return gettok();
+       }
+
+       //Check for end of file
+       if(LastChar == EOF)
+               return tok_eof;
+
+       int ThisChar = LastChar;
+       LastChar = getchar();
+       return ThisChar;
+}
+
+
+// One-token lookahead: the token the parser is currently looking at.
+static int CurTok;
+
+// Reads the next token from the lexer, stores it in CurTok and returns it.
+static int getNextToken() {
+       CurTok = gettok();
+       return CurTok;
+}
+
+// Passes Str to LogError and returns a null prototype pointer, so prototype
+// parsers can write `return LogErrorP("...")`.
+std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
+       LogError(Str); // whatever LogError returns is deliberately discarded
+       return {};
+}
+
+// numberexpr ::= number
+// Builds a NumberExprAST from the current number token and advances past it.
+static std::unique_ptr<ExprAST> ParseNumberExpr() {
+       double Value = NumVal; // read before the lexer advances and may overwrite it
+       getNextToken(); // consume the number
+       return llvm::make_unique<NumberExprAST>(Value);
+}
+
+// parenexpr ::= '(' expression ')'
+// Returns the inner expression itself; no separate node is created for the parentheses.
+static std::unique_ptr<ExprAST> ParseParenExpr() {
+       getNextToken(); // eat '('
+
+       std::unique_ptr<ExprAST> Inner = ParseExpression();
+       if (Inner == nullptr)
+               return nullptr;
+
+       if (CurTok == ')') {
+               getNextToken(); // eat ')'
+               return Inner;
+       }
+       return LogError("expected ')'");
+}
+
+static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+       std::string IdName = IdentifierStr;
+       
+       getNextToken();
+       
+       if (CurTok != '(')
+               return llvm::make_unique<VariableExprAST>(IdName);
+       
+       // call a function
+       getNextToken();
+       std::vector<std::unique_ptr<ExprAST>> Args;
+       if (CurTok != ')') {
+               while (true) {
+                       if (auto Arg = ParseExpression())
+                               Args.push_back(std::move(Arg));
+                       else
+                               return nullptr;
+                       
+                       if (CurTok == ')')
+                               break;
+                       
+                       if (CurTok != ',')
+                               return LogError("Expected ')' or ',' in argument list");
+                       getNextToken();
+               }
+       }
+       
+       getNextToken();
+       return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
+}
+
+static std::unique_ptr<ExprAST> ParsePrimary() {
+       switch (CurTok) {
+               default:
+                       return LogError("Unknown token when expected an expression");
+               case tok_identifier:
+                       return ParseIdentifierExpr();
+               case tok_number:
+                       return ParseNumberExpr();
+               case '(':
+                       return ParseParenExpr();
+       }
+}
+
+// Precedence of each binary operator; a higher value binds tighter.
+// Filled in by main().
+static std::map<char, int> BinopPrecedence;
+
+// Returns the precedence of CurTok if it is a registered binary operator with
+// a positive precedence, and -1 otherwise.
+static int GetTokPrecedence() {
+       if (!isascii(CurTok))
+               return -1;
+
+       // Look up with find(): operator[] would insert a zero entry into the
+       // table for every non-operator token that is ever queried.
+       const auto Entry = BinopPrecedence.find(static_cast<char>(CurTok));
+       if (Entry == BinopPrecedence.end() || Entry->second <= 0)
+               return -1;
+       return Entry->second;
+}
+
+// expression ::= primary binoprhs
+// Parses one primary and then any chain of binary operators that follows it.
+static std::unique_ptr<ExprAST> ParseExpression() {
+       if (auto Primary = ParsePrimary())
+               return ParseBinOpRHS(0, std::move(Primary));
+       return nullptr;
+}
+
+static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+                                                                                         std::unique_ptr<ExprAST> LHS) {
+       while (true) {
+               int TokPrec = GetTokPrecedence();
+               if (TokPrec < ExprPrec)
+                       return LHS;
+               
+               int BinOp = CurTok;
+               getNextToken(); // eat binop
+               
+               auto RHS = ParsePrimary();
+               if (!RHS)
+                       return nullptr;
+               
+               int NextPrec = GetTokPrecedence();
+               if (TokPrec < NextPrec) {
+                       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+                       if (!RHS)
+                               return nullptr;
+               }
+               
+               LHS = llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
+       }
+}
+
+static std::unique_ptr<PrototypeAST> ParsePrototype() {
+       if (CurTok != tok_identifier)
+               return LogErrorP("Expected function name in prototype");
+       
+       std::string FnName = IdentifierStr;
+       getNextToken();
+       
+       if (CurTok != '(')
+               return LogErrorP("Expected '(' in prototype");
+       
+       std::vector<std::string> ArgNames;
+       while (getNextToken() == tok_identifier)
+               ArgNames.push_back(IdentifierStr);
+       if (CurTok != ')')
+               return LogErrorP("Expected ')' in prototype");
+       
+       getNextToken(); // eat ')'
+       
+       return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
+}
+
+// definition ::= 'def' prototype expression
+static std::unique_ptr<FunctionAST> ParseDefinition() {
+       getNextToken(); // eat 'def'
+
+       auto Signature = ParsePrototype();
+       if (!Signature)
+               return nullptr;
+
+       auto Body = ParseExpression();
+       if (!Body)
+               return nullptr;
+
+       return llvm::make_unique<FunctionAST>(std::move(Signature), std::move(Body));
+}
+
+// external ::= 'extern' prototype
+static std::unique_ptr<PrototypeAST> ParseExtern() {
+       getNextToken(); // eat 'extern'
+       auto Proto = ParsePrototype();
+       return Proto;
+}
+
+// toplevelexpr ::= expression
+// Wraps the expression in a zero-argument function named ANON_EXPR_FUNC_NAME.
+static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
+       auto Body = ParseExpression();
+       if (!Body)
+               return nullptr;
+
+       // make an anonymous prototype
+       auto AnonProto = llvm::make_unique<PrototypeAST>(ANON_EXPR_FUNC_NAME, std::vector<std::string>());
+       return llvm::make_unique<FunctionAST>(std::move(AnonProto), std::move(Body));
+}
+
+static void HandleDefinition() {
+       if (auto def = ParseDefinition()) {
+               if (auto gen = def->codegen()) {
+                       std::cerr << "Read function defintion:";
+                       gen->print(errs());
+                       std::cerr << "\n";
+                       
+                       TheEngine->addModule(std::move(TheModule));
+                       InitializeModuleAndPassManager();
+               }
+       }
+       else
+               getNextToken(); // skip token for error recovery
+}
+
+static void HandleExtern() {
+       if (auto externProto = ParseExtern()) {
+               if (auto gen = externProto->codegen()) {
+                       std::cerr << "Read an extern:\n";
+                       gen->print(errs());
+                       std::cerr << "\n";
+                       functionProtos[externProto->getName()] = std::move(externProto);
+               }
+       }
+       else
+               getNextToken(); // skip token for error recovery
+}
+
+static void HandleTopLevelExpr() {
+       if (auto expr = ParseTopLevelExpr()) {
+               if (auto gen = expr->codegen()) {
+                       gen->print(errs());
+                       
+                       auto module = TheModule.get();
+                       TheEngine->addModule(std::move(TheModule));
+                       InitializeModuleAndPassManager();
+                       
+                       auto func = TheEngine->FindFunctionNamed(ANON_EXPR_FUNC_NAME);
+                       GenericValue gv = TheEngine->runFunction(func, std::vector<GenericValue>());
+                       std::cerr << "Evaluated to " << std::fixed << std::setw(5) << gv.DoubleVal << std::endl;
+                       
+                       TheEngine->removeModule(module);
+               }
+       }
+       else
+               getNextToken(); // skip token for error recovery
+}
+
+void mainLoop() {
+       while (true) {
+               fprintf(stderr, "ready> ");
+               switch(CurTok) {
+                       case tok_eof:
+                               return;
+                       case ';':
+                               getNextToken();
+                               break;
+                       case tok_def:
+                               HandleDefinition();
+                               break;
+                       case tok_extern:
+                               HandleExtern();
+                               break;
+                       default:
+                               HandleTopLevelExpr();
+                               break;
+               }
+       }
+}
+
+int main() {
+       InitializeNativeTarget();
+       InitializeNativeTargetAsmPrinter();
+       InitializeNativeTargetAsmParser();
+       
+       BinopPrecedence['<'] = 10;
+       BinopPrecedence['+'] = 20;
+       BinopPrecedence['-'] = 20;
+       BinopPrecedence['*'] = 40;
+       
+       InitializeModuleAndPassManager();
+       
+       std::string engineError;
+       TheEngine = EngineBuilder(std::move(TheModule)).setErrorStr(&engineError).create();
+       if (!engineError.empty())
+               std::cout << engineError << "\n";
+       
+       InitializeModuleAndPassManager();
+       
+       // prime the first token
+       std::cerr << "ready> ";
+       getNextToken();
+       
+       mainLoop();
+       
+       PrintModule();
+
+       
+       return 0;
+}