c185a3e0c699f4445a0435c494953c13c8a19343
[kaleidoscope.git] / Kaleidoscope / codegen.cpp
1 #include <iostream>
2 #include <llvm/Passes/PassBuilder.h>
3 #include <llvm/IR/LLVMContext.h>
4 #include <llvm/IR/IRBuilder.h>
5 #include <llvm/IR/Verifier.h>
6 #include <llvm/IR/PassManager.h>
7 #include "llvm/Transforms/Scalar.h"
8 #include "llvm/Transforms/Scalar/Reassociate.h"
9 #include "llvm/Transforms/Scalar/GVN.h"
10 #include <llvm/Transforms/Scalar/DCE.h>
11 #include <llvm/Transforms/Scalar/SimplifyCFG.h>
12 #include <llvm/Transforms/InstCombine/InstCombine.h>
13 #include <llvm/Transforms/IPO/PassManagerBuilder.h>
14 #include <llvm/Analysis/OptimizationDiagnosticInfo.h>
15 #include <llvm/Analysis/MemorySSA.h>
16 #include <llvm/Analysis/PostDominators.h>
17 #include "ast.hpp"
18 #include "codegen.hpp"
19 #include "shared.h"
20
21 using namespace llvm;
22
23 static LLVMContext TheContext;
24 static IRBuilder<> Builder(TheContext);
25 static AnalysisManager<int, int> TheAM;
26 static std::map<std::string, Value *> NamedValues;
27 static std::unique_ptr<FunctionPassManager> TheFPM;
28 static std::unique_ptr<FunctionAnalysisManager> TheFAM;
29
30 void InitializeModuleAndPassManager(void) {
31         TheModule = llvm::make_unique<Module>("Kaleidoscope jit", TheContext);
32         
33         TheFPM = make_unique<FunctionPassManager>();
34         TheFAM = make_unique<FunctionAnalysisManager>();
35         
36         PassBuilder PB;
37         PB.registerFunctionAnalyses(*TheFAM);
38         
39         TheFPM->addPass(InstCombinePass());
40         TheFPM->addPass(ReassociatePass());
41         TheFPM->addPass(SimplifyCFGPass());
42         TheFPM->addPass(GVNHoistPass());
43         TheFPM->addPass(GVNSinkPass());
44 }
45
46 void PrintModule() {
47         TheModule->print(errs(), nullptr);
48 }
49
50 std::map<std::string, std::unique_ptr<PrototypeAST>> functionProtos;
51
52 Function *getFunction(std::string name) {
53         if (auto *func = TheModule->getFunction(name))
54                 return func;
55         
56         auto iterator = functionProtos.find(name);
57         if (iterator != functionProtos.end())
58                 return iterator->second->codegen();
59         
60         return nullptr;
61 }
62
63 Value *LogErrorV(const char *Str) {
64         LogError(Str);
65         return nullptr;
66 }
67
68 Value *NumberExprAST::codegen() {
69         return ConstantFP::get(TheContext, APFloat(Val));
70 }
71
72 Value *VariableExprAST::codegen() {
73         Value *V = NamedValues[Name];
74         if (!V)
75                 LogErrorV("Unknown variable name");
76         return V;
77 }
78
79 Value *BinaryExprAST::codegen() {
80         Value *L = LHS->codegen();
81         Value *R = RHS->codegen();
82         
83         if (!L || !R) return nullptr;
84         
85         switch (Op) {
86                 case '+':
87                         return Builder.CreateFAdd(L, R, "addtmp");
88                 case '-':
89                         return Builder.CreateFSub(L, R, "subtmp");
90                 case '*':
91                         return Builder.CreateFMul(L, R, "multmp");
92                 case '<':
93                         L = Builder.CreateFCmpULT(L, R, "cmptmp");
94                         // convert bool 0/1 to double 0.0/1.0
95                         return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
96                 default:
97                         return LogErrorV("Invalid binary operator");
98         }
99 }
100
101 Value *CallExprAST::codegen() {
102         Function *CalleeF = getFunction(Callee);
103         if (!CalleeF)
104                 return LogErrorV("Unknown function referenced");
105         
106         if (CalleeF->arg_size() != Args.size())
107                 return LogErrorV("Incorrect number of arguments passed");
108         
109         std::vector<Value *> ArgsV;
110         for (unsigned long i = 0, e = Args.size(); i != e; ++i) {
111                 ArgsV.push_back(Args[i]->codegen());
112                 if (!ArgsV.back()) return nullptr;
113         }
114         
115         return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
116 }
117
118 Function *PrototypeAST::codegen() {
119         std::vector<Type*> Doubles(Args.size(), Type::getDoubleTy(TheContext));
120         
121         FunctionType *FT = FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
122         Function *func = Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
123         
124         unsigned i = 0;
125         for (auto &arg : func->args())
126                 arg.setName(Args[i++]);
127         
128         return func;
129 }
130
131 Function *FunctionAST::codegen() {
132         //Transfer ownership but keep a reference
133         auto &P = *Prototype;
134         functionProtos[Prototype->getName()] = std::move(Prototype);
135         
136         Function *func = getFunction(P.getName());
137         
138         if (!func) return nullptr;
139         
140         BasicBlock *bb = BasicBlock::Create(TheContext, "entry", func);
141         Builder.SetInsertPoint(bb);
142         
143         NamedValues.clear();
144         for (auto &arg: func->args())
145                 NamedValues[arg.getName()] = &arg;
146         
147         if (Value *retVal = Body->codegen()) {
148                 Builder.CreateRet(retVal);
149                 verifyFunction(*func);
150                 TheFPM->run(*func, *TheFAM);
151                 return func;
152         }
153         
154         func->eraseFromParent();
155         return nullptr;
156 }