Initial commit
[kaleidoscope.git] / Kaleidoscope / codegen.cpp
diff --git a/Kaleidoscope/codegen.cpp b/Kaleidoscope/codegen.cpp
new file mode 100644 (file)
index 0000000..c185a3e
--- /dev/null
@@ -0,0 +1,156 @@
+#include <iostream>
+#include <llvm/Passes/PassBuilder.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/Verifier.h>
+#include <llvm/IR/PassManager.h>
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/Reassociate.h"
+#include "llvm/Transforms/Scalar/GVN.h"
+#include <llvm/Transforms/Scalar/DCE.h>
+#include <llvm/Transforms/Scalar/SimplifyCFG.h>
+#include <llvm/Transforms/InstCombine/InstCombine.h>
+#include <llvm/Transforms/IPO/PassManagerBuilder.h>
+#include <llvm/Analysis/OptimizationDiagnosticInfo.h>
+#include <llvm/Analysis/MemorySSA.h>
+#include <llvm/Analysis/PostDominators.h>
+#include "ast.hpp"
+#include "codegen.hpp"
+#include "shared.h"
+
+using namespace llvm;
+
+static LLVMContext TheContext;
+static IRBuilder<> Builder(TheContext);
+static AnalysisManager<int, int> TheAM;
+static std::map<std::string, Value *> NamedValues;
+static std::unique_ptr<FunctionPassManager> TheFPM;
+static std::unique_ptr<FunctionAnalysisManager> TheFAM;
+
+void InitializeModuleAndPassManager(void) {
+       TheModule = llvm::make_unique<Module>("Kaleidoscope jit", TheContext);
+       
+       TheFPM = make_unique<FunctionPassManager>();
+       TheFAM = make_unique<FunctionAnalysisManager>();
+       
+       PassBuilder PB;
+       PB.registerFunctionAnalyses(*TheFAM);
+       
+       TheFPM->addPass(InstCombinePass());
+       TheFPM->addPass(ReassociatePass());
+       TheFPM->addPass(SimplifyCFGPass());
+       TheFPM->addPass(GVNHoistPass());
+       TheFPM->addPass(GVNSinkPass());
+}
+
+void PrintModule() {
+       TheModule->print(errs(), nullptr);
+}
+
+std::map<std::string, std::unique_ptr<PrototypeAST>> functionProtos;
+
+Function *getFunction(std::string name) {
+       if (auto *func = TheModule->getFunction(name))
+               return func;
+       
+       auto iterator = functionProtos.find(name);
+       if (iterator != functionProtos.end())
+               return iterator->second->codegen();
+       
+       return nullptr;
+}
+
+Value *LogErrorV(const char *Str) {
+       LogError(Str);
+       return nullptr;
+}
+
+Value *NumberExprAST::codegen() {
+       return ConstantFP::get(TheContext, APFloat(Val));
+}
+
+Value *VariableExprAST::codegen() {
+       Value *V = NamedValues[Name];
+       if (!V)
+               LogErrorV("Unknown variable name");
+       return V;
+}
+
+Value *BinaryExprAST::codegen() {
+       Value *L = LHS->codegen();
+       Value *R = RHS->codegen();
+       
+       if (!L || !R) return nullptr;
+       
+       switch (Op) {
+               case '+':
+                       return Builder.CreateFAdd(L, R, "addtmp");
+               case '-':
+                       return Builder.CreateFSub(L, R, "subtmp");
+               case '*':
+                       return Builder.CreateFMul(L, R, "multmp");
+               case '<':
+                       L = Builder.CreateFCmpULT(L, R, "cmptmp");
+                       // convert bool 0/1 to double 0.0/1.0
+                       return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
+               default:
+                       return LogErrorV("Invalid binary operator");
+       }
+}
+
+Value *CallExprAST::codegen() {
+       Function *CalleeF = getFunction(Callee);
+       if (!CalleeF)
+               return LogErrorV("Unknown function referenced");
+       
+       if (CalleeF->arg_size() != Args.size())
+               return LogErrorV("Incorrect number of arguments passed");
+       
+       std::vector<Value *> ArgsV;
+       for (unsigned long i = 0, e = Args.size(); i != e; ++i) {
+               ArgsV.push_back(Args[i]->codegen());
+               if (!ArgsV.back()) return nullptr;
+       }
+       
+       return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
+}
+
+Function *PrototypeAST::codegen() {
+       std::vector<Type*> Doubles(Args.size(), Type::getDoubleTy(TheContext));
+       
+       FunctionType *FT = FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
+       Function *func = Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
+       
+       unsigned i = 0;
+       for (auto &arg : func->args())
+               arg.setName(Args[i++]);
+       
+       return func;
+}
+
+Function *FunctionAST::codegen() {
+       //Transfer ownership but keep a reference
+       auto &P = *Prototype;
+       functionProtos[Prototype->getName()] = std::move(Prototype);
+       
+       Function *func = getFunction(P.getName());
+       
+       if (!func) return nullptr;
+       
+       BasicBlock *bb = BasicBlock::Create(TheContext, "entry", func);
+       Builder.SetInsertPoint(bb);
+       
+       NamedValues.clear();
+       for (auto &arg: func->args())
+               NamedValues[arg.getName()] = &arg;
+       
+       if (Value *retVal = Body->codegen()) {
+               Builder.CreateRet(retVal);
+               verifyFunction(*func);
+               TheFPM->run(*func, *TheFAM);
+               return func;
+       }
+       
+       func->eraseFromParent();
+       return nullptr;
+}