WIP
[kaleidoscope.git] / Kaleidoscope / codegen.cpp
1 #include <iostream>
2 #include <llvm/Passes/PassBuilder.h>
3 #include <llvm/IR/LLVMContext.h>
4 #include <llvm/IR/IRBuilder.h>
5 #include <llvm/IR/Verifier.h>
6 #include <llvm/IR/PassManager.h>
7 #include "llvm/Transforms/Scalar.h"
8 #include "llvm/Transforms/Scalar/Reassociate.h"
9 #include "llvm/Transforms/Scalar/GVN.h"
10 #include <llvm/Transforms/Scalar/DCE.h>
11 #include <llvm/Transforms/Scalar/SimplifyCFG.h>
12 #include <llvm/Transforms/InstCombine/InstCombine.h>
13 #include <llvm/Transforms/IPO/PassManagerBuilder.h>
14 #include <llvm/Analysis/MemorySSA.h>
15 #include <llvm/Analysis/PostDominators.h>
16 #include "ast.hpp"
17 #include "codegen.hpp"
18 #include "shared.h"
19
20 using namespace llvm;
21
22 static LLVMContext TheContext;
23 static IRBuilder<> Builder(TheContext);
24 static AnalysisManager<int, int> TheAM;
25 static std::map<std::string, Value *> NamedValues;
26 static std::unique_ptr<FunctionPassManager> TheFPM;
27 static std::unique_ptr<FunctionAnalysisManager> TheFAM;
28
29 void InitializeModuleAndPassManager(void) {
30         TheModule = llvm::make_unique<Module>("Kaleidoscope jit", TheContext);
31         
32         TheFPM = make_unique<FunctionPassManager>();
33         TheFAM = make_unique<FunctionAnalysisManager>();
34         
35         PassBuilder PB;
36         PB.registerFunctionAnalyses(*TheFAM);
37         
38         TheFPM->addPass(InstCombinePass());
39         TheFPM->addPass(ReassociatePass());
40         TheFPM->addPass(SimplifyCFGPass());
41         TheFPM->addPass(GVNHoistPass());
42         TheFPM->addPass(GVNSinkPass());
43 }
44
45 std::map<std::string, std::unique_ptr<PrototypeAST>> functionProtos;
46
47 Function *getFunction(std::string name) {
48         if (auto *func = TheModule->getFunction(name))
49                 return func;
50         
51         auto iterator = functionProtos.find(name);
52         if (iterator != functionProtos.end())
53                 return iterator->second->codegen();
54         
55         return nullptr;
56 }
57
58 Value *LogErrorV(std::string str) {
59         LogError(str);
60         return nullptr;
61 }
62
63 Value *NumberExprAST::codegen() {
64         return ConstantFP::get(TheContext, APFloat(Val));
65 }
66
67 Value *VariableExprAST::codegen() {
68         Value *V = NamedValues[Name];
69         if (!V)
70                 LogErrorV("Unknown variable name");
71         return V;
72 }
73
74 Value *BinaryExprAST::codegen() {
75         Value *L = LHS->codegen();
76         Value *R = RHS->codegen();
77         
78         if (!L || !R) return nullptr;
79         
80         switch (Op) {
81                 case '+':
82                         return Builder.CreateFAdd(L, R, "addtmp");
83                 case '-':
84                         return Builder.CreateFSub(L, R, "subtmp");
85                 case '*':
86                         return Builder.CreateFMul(L, R, "multmp");
87                 case '<':
88                         L = Builder.CreateFCmpULT(L, R, "cmptmp");
89                         // convert bool 0/1 to double 0.0/1.0
90                         return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
91                 case '>':
92                         L = Builder.CreateFCmpUGT(L, R, "cmptmp");
93                         return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
94                 case '|':
95                         L = Builder.CreateFCmpUGE(L, ConstantFP::get(TheContext, APFloat(1.0)), "leftbooltmp");
96                         R = Builder.CreateFCmpUGE(R, ConstantFP::get(TheContext, APFloat(1.0)), "leftbooltmp");
97                         L = Builder.CreateOr(L, R, "orbooltmp");
98                         return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "ortmp");
99                 default:
100                         break;
101         }
102
103         Function *F = getFunction(std::string("binary") + Op);
104         assert(F && "binary oprator not found!");
105
106         Value *Ops[2] = {L, R};
107         return Builder.CreateCall(F, Ops, "binop");
108 }
109
110 Value *CallExprAST::codegen() {
111         Function *CalleeF = getFunction(Callee);
112         if (!CalleeF)
113                 return LogErrorV("Unknown function referenced");
114         
115         if (CalleeF->arg_size() != Args.size())
116                 return LogErrorV("Incorrect number of arguments passed");
117         
118         std::vector<Value *> ArgsV;
119         for (unsigned long i = 0, e = Args.size(); i != e; ++i) {
120                 ArgsV.push_back(Args[i]->codegen());
121                 if (!ArgsV.back()) return nullptr;
122         }
123         
124         return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
125 }
126
127 Function *PrototypeAST::codegen() {
128         std::vector<Type*> Doubles(Args.size(), Type::getDoubleTy(TheContext));
129         
130         FunctionType *FT = FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
131         Function *func = Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
132         
133         unsigned i = 0;
134         for (auto &arg : func->args())
135                 arg.setName(Args[i++]);
136         
137         return func;
138 }
139
140 Function *FunctionAST::codegen() {
141         //Transfer ownership but keep a reference
142         auto &P = *Prototype;
143         functionProtos[Prototype->getName()] = std::move(Prototype);
144         
145         Function *func = getFunction(P.getName());
146         
147         if (!func) return nullptr;
148
149         if (P.isBinaryOp())
150                 BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
151         
152         BasicBlock *bb = BasicBlock::Create(TheContext, "entry", func);
153         Builder.SetInsertPoint(bb);
154         
155         NamedValues.clear();
156         for (auto &arg: func->args())
157                 NamedValues[arg.getName()] = &arg;
158         
159         if (Value *retVal = Body->codegen()) {
160                 Builder.CreateRet(retVal);
161                 
162                 if (verifyFunction(*func, &errs())) {
163                         func->print(errs());
164                         func->eraseFromParent();
165                         return nullptr;
166                 }
167                 
168                 TheFPM->run(*func, *TheFAM);
169                 return func;
170         }
171         
172         func->eraseFromParent();
173         return nullptr;
174 }
175
176 Value *IfExprAST::codegen() {
177         Value *condV = Cond->codegen();
178         if (!condV) return nullptr;
179         
180         //convert to bool
181         condV = Builder.CreateFCmpONE(condV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
182         
183         Function *func = Builder.GetInsertBlock()->getParent();
184         
185         BasicBlock *thenBB = BasicBlock::Create(TheContext, "then", func);
186         BasicBlock *elseBB = BasicBlock::Create(TheContext, "else", func);
187         BasicBlock *mergeBB = BasicBlock::Create(TheContext, "merge", func);
188         
189         Builder.CreateCondBr(condV, thenBB, elseBB);
190         
191         Builder.SetInsertPoint(thenBB);
192         
193         Value *thenV = Then->codegen();
194         if (!thenV) return nullptr;
195         
196         Builder.CreateBr(mergeBB);
197         
198         //codegen of then can change the current block, so get then again
199         thenBB = Builder.GetInsertBlock();
200         
201         Builder.SetInsertPoint(elseBB);
202         
203         Value *elseV = Else->codegen();
204         if (!elseV) return nullptr;
205         
206         Builder.CreateBr(mergeBB);
207         
208         //codegen of else can change the current block, so get else again
209         elseBB = Builder.GetInsertBlock();
210         
211         Builder.SetInsertPoint(mergeBB);
212         PHINode *phiNode = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
213         phiNode->addIncoming(thenV, thenBB);
214         phiNode->addIncoming(elseV, elseBB);
215         
216         return phiNode;
217 }
218
219 Value *ForExprAST::codegen() {
220         auto *startValue = Start->codegen();
221         if (!startValue) return nullptr;
222         
223         auto *func = Builder.GetInsertBlock()->getParent();
224         auto *preheaderBB = Builder.GetInsertBlock();
225         auto *loopBB = BasicBlock::Create(TheContext, "loop", func);
226         
227         Builder.CreateBr(loopBB);
228         
229         Builder.SetInsertPoint(loopBB);
230         
231         auto *index = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, VarName.c_str());
232         index->addIncoming(startValue, preheaderBB);
233         
234         // if the index variable shadows an existing value, save it to restore later
235         auto *oldVal = NamedValues[VarName];
236         NamedValues[VarName] = index;
237         
238         // emit the loop body
239         if (!Body->codegen()) return nullptr;
240         
241         Value *stepVal = nullptr;
242         if (Step) {
243                 stepVal = Step->codegen();
244                 if (!stepVal) return nullptr;
245         } else {
246                 stepVal = ConstantFP::get(TheContext, APFloat(1.0));
247         }
248         
249         // increment the index
250         auto *nextVar = Builder.CreateFAdd(index, stepVal, "nextvar");
251         
252         auto *endCond = End->codegen();
253         if (!endCond) return nullptr;
254         
255         endCond = Builder.CreateFCmpONE(endCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
256         
257         auto *loopEndBB = Builder.GetInsertBlock();
258         auto *afterBB = BasicBlock::Create(TheContext, "afterloop", func);
259         
260         Builder.CreateCondBr(endCond, loopBB, afterBB);
261         
262         Builder.SetInsertPoint(afterBB);
263         
264         index->addIncoming(nextVar, loopEndBB);
265         
266         // restore shadowed index variable
267         if (oldVal)
268                 NamedValues[VarName] = oldVal;
269         else
270                 NamedValues.erase(VarName);
271         
272         return Constant::getNullValue(Type::getDoubleTy(TheContext));
273 }