ac75490ff45d8c343200b0e3a4da55bbbadc7b08
[kaleidoscope.git] / Kaleidoscope / lexer.cpp
1 #include <cstdio>
2 #include <cctype>
3 #include <cstdlib>
4 #include <string>
5 #include <map>
6 #include <iostream>
7 #include <iomanip>
8 #include <llvm/ADT/STLExtras.h>
9 #include <llvm/Support/raw_ostream.h>
10 #include <llvm/ExecutionEngine/ExecutionEngine.h>
11 #include <llvm/ExecutionEngine/GenericValue.h>
12 #include <llvm/Support/TargetSelect.h>
13 //Needed to force linking interpreter
14 #include <llvm/ExecutionEngine/MCJIT.h>
15 #include "lexer.h"
16 #include "ast.hpp"
17 #include "codegen.hpp"
18 #include "shared.h"
19
20 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
21                                                                                           std::unique_ptr<ExprAST> LHS);
22 static std::unique_ptr<ExprAST> ParseExpression();
23
24 static ExecutionEngine* TheEngine;
25
26 std::unique_ptr<Module> TheModule;
27
28 constexpr auto ANON_EXPR_FUNC_NAME = "__anon_expr";
29
/// Token codes returned by gettok(). Any non-negative value returned by gettok()
/// is instead the raw character that was read (e.g. '(' or '+').
enum Token {
    // end of input
    tok_eof = -1,

    // keywords
    tok_def = -2,
    tok_extern = -3,

    // primary tokens
    tok_identifier = -4,
    tok_number = -5
};

/// Text of the most recent identifier; filled in when gettok() returns tok_identifier.
static std::string IdentifierStr;
/// Value of the most recent numeric literal; filled in when gettok() returns tok_number.
static double NumVal;
42
43 ///Returns the next token from stdin
44 static int gettok() {
45         static int LastChar = ' ';
46
47         while(isspace(LastChar))
48                 LastChar = getchar();
49
50         if(isalpha(LastChar)) {
51                 IdentifierStr = LastChar;
52                 while(isalnum((LastChar = getchar())))
53                         IdentifierStr += LastChar;
54
55                 if(IdentifierStr == "def")
56                         return tok_def;
57                 if(IdentifierStr == "extern")
58                         return tok_extern;
59                 return tok_identifier;
60         }
61
62         if(isdigit(LastChar) || LastChar == '.') {
63                 std::string NumStr;
64                 do {
65                         NumStr += LastChar;
66                         LastChar = getchar();
67                 } while (isdigit(LastChar) || LastChar == '.');
68
69                 NumVal = strtod(NumStr.c_str(), 0);
70                 return tok_number;
71         }
72
73         if(LastChar == '#') {
74                 //Coment until the end of the line
75                 do
76                         LastChar = getchar();
77                 while(LastChar != EOF && LastChar != '\n' && LastChar != 'r');
78
79                 if(LastChar != EOF)
80                         return gettok();
81         }
82
83         //Check for end of file
84         if(LastChar == EOF)
85                 return tok_eof;
86
87         int ThisChar = LastChar;
88         LastChar = getchar();
89         return ThisChar;
90 }
91
92
93 static int CurTok;
94 static int getNextToken() {
95         return CurTok = gettok();
96 }
97
98 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
99         LogError(Str);
100         return nullptr;
101 }
102
103 static std::unique_ptr<ExprAST> ParseNumberExpr() {
104         auto Result = llvm::make_unique<NumberExprAST>(NumVal);
105         getNextToken(); // consume the number
106         return std::move(Result);
107 }
108
109 static std::unique_ptr<ExprAST> ParseParenExpr() {
110         getNextToken();
111         auto V = ParseExpression();
112         if (!V)
113                 return nullptr;
114         
115         if (CurTok != ')')
116                 return LogError("expected ')'");
117         getNextToken();
118         return V;
119 }
120
121 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
122         std::string IdName = IdentifierStr;
123         
124         getNextToken();
125         
126         if (CurTok != '(')
127                 return llvm::make_unique<VariableExprAST>(IdName);
128         
129         // call a function
130         getNextToken();
131         std::vector<std::unique_ptr<ExprAST>> Args;
132         if (CurTok != ')') {
133                 while (true) {
134                         if (auto Arg = ParseExpression())
135                                 Args.push_back(std::move(Arg));
136                         else
137                                 return nullptr;
138                         
139                         if (CurTok == ')')
140                                 break;
141                         
142                         if (CurTok != ',')
143                                 return LogError("Expected ')' or ',' in argument list");
144                         getNextToken();
145                 }
146         }
147         
148         getNextToken();
149         return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
150 }
151
152 static std::unique_ptr<ExprAST> ParsePrimary() {
153         switch (CurTok) {
154                 default:
155                         return LogError("Unknown token when expected an expression");
156                 case tok_identifier:
157                         return ParseIdentifierExpr();
158                 case tok_number:
159                         return ParseNumberExpr();
160                 case '(':
161                         return ParseParenExpr();
162         }
163 }
164
165 static std::map<char, int> BinopPrecedence;
166
167 static int GetTokPrecedence() {
168         if (!isascii(CurTok))
169                 return -1;
170         
171         int TokPrec = BinopPrecedence[CurTok];
172         if (TokPrec <= 0) return -1;
173         return TokPrec;
174 }
175
176 static std::unique_ptr<ExprAST> ParseExpression() {
177         auto LHS = ParsePrimary();
178         if (!LHS)
179                 return nullptr;
180         return ParseBinOpRHS(0, std::move(LHS));
181 }
182
183 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
184                                                                                           std::unique_ptr<ExprAST> LHS) {
185         while (true) {
186                 int TokPrec = GetTokPrecedence();
187                 if (TokPrec < ExprPrec)
188                         return LHS;
189                 
190                 int BinOp = CurTok;
191                 getNextToken(); // eat binop
192                 
193                 auto RHS = ParsePrimary();
194                 if (!RHS)
195                         return nullptr;
196                 
197                 int NextPrec = GetTokPrecedence();
198                 if (TokPrec < NextPrec) {
199                         RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
200                         if (!RHS)
201                                 return nullptr;
202                 }
203                 
204                 LHS = llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
205         }
206 }
207
208 static std::unique_ptr<PrototypeAST> ParsePrototype() {
209         if (CurTok != tok_identifier)
210                 return LogErrorP("Expected function name in prototype");
211         
212         std::string FnName = IdentifierStr;
213         getNextToken();
214         
215         if (CurTok != '(')
216                 return LogErrorP("Expected '(' in prototype");
217         
218         std::vector<std::string> ArgNames;
219         while (getNextToken() == tok_identifier)
220                 ArgNames.push_back(IdentifierStr);
221         if (CurTok != ')')
222                 return LogErrorP("Expected ')' in prototype");
223         
224         getNextToken(); // eat ')'
225         
226         return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
227 }
228
229 static std::unique_ptr<FunctionAST> ParseDefinition() {
230         getNextToken();
231         auto Proto = ParsePrototype();
232         if (!Proto) return nullptr;
233         
234         if (auto E = ParseExpression())
235                 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
236         return nullptr;
237 }
238
239 static std::unique_ptr<PrototypeAST> ParseExtern() {
240         getNextToken();
241         return ParsePrototype();
242 }
243
244 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
245         if (auto e = ParseExpression()) {
246                 // make an anonymous prototype
247                 auto proto = llvm::make_unique<PrototypeAST>(ANON_EXPR_FUNC_NAME, std::vector<std::string>());
248                 return llvm::make_unique<FunctionAST>(std::move(proto), std::move(e));
249         }
250         return nullptr;
251 }
252
253 static void HandleDefinition() {
254         if (auto def = ParseDefinition()) {
255                 if (auto gen = def->codegen()) {
256                         std::cerr << "Read function defintion:";
257                         gen->print(errs());
258                         std::cerr << "\n";
259                         
260                         TheEngine->addModule(std::move(TheModule));
261                         InitializeModuleAndPassManager();
262                 }
263         }
264         else
265                 getNextToken(); // skip token for error recovery
266 }
267
268 static void HandleExtern() {
269         if (auto externProto = ParseExtern()) {
270                 if (auto gen = externProto->codegen()) {
271                         std::cerr << "Read an extern:\n";
272                         gen->print(errs());
273                         std::cerr << "\n";
274                         functionProtos[externProto->getName()] = std::move(externProto);
275                 }
276         }
277         else
278                 getNextToken(); // skip token for error recovery
279 }
280
281 static void HandleTopLevelExpr() {
282         if (auto expr = ParseTopLevelExpr()) {
283                 if (auto gen = expr->codegen()) {
284                         gen->print(errs());
285                         
286                         auto module = TheModule.get();
287                         TheEngine->addModule(std::move(TheModule));
288                         InitializeModuleAndPassManager();
289                         
290                         auto func = TheEngine->FindFunctionNamed(ANON_EXPR_FUNC_NAME);
291                         GenericValue gv = TheEngine->runFunction(func, std::vector<GenericValue>());
292                         std::cerr << "Evaluated to " << std::fixed << std::setw(5) << gv.DoubleVal << std::endl;
293                         
294                         TheEngine->removeModule(module);
295                 }
296         }
297         else
298                 getNextToken(); // skip token for error recovery
299 }
300
301 void mainLoop() {
302         while (true) {
303                 fprintf(stderr, "ready> ");
304                 switch(CurTok) {
305                         case tok_eof:
306                                 return;
307                         case ';':
308                                 getNextToken();
309                                 break;
310                         case tok_def:
311                                 HandleDefinition();
312                                 break;
313                         case tok_extern:
314                                 HandleExtern();
315                                 break;
316                         default:
317                                 HandleTopLevelExpr();
318                                 break;
319                 }
320         }
321 }
322
323 int main() {
324         InitializeNativeTarget();
325         InitializeNativeTargetAsmPrinter();
326         InitializeNativeTargetAsmParser();
327         
328         BinopPrecedence['<'] = 10;
329         BinopPrecedence['+'] = 20;
330         BinopPrecedence['-'] = 20;
331         BinopPrecedence['*'] = 40;
332         
333         InitializeModuleAndPassManager();
334         
335         std::string engineError;
336         TheEngine = EngineBuilder(std::move(TheModule)).setErrorStr(&engineError).create();
337         if (!engineError.empty())
338                 std::cout << engineError << "\n";
339         
340         InitializeModuleAndPassManager();
341         
342         // prime the first token
343         std::cerr << "ready> ";
344         getNextToken();
345         
346         mainLoop();
347         
348         PrintModule();
349
350         
351         return 0;
352 }