WIP
[kaleidoscope.git] / Kaleidoscope / lexer.cpp
1 #include <cstdio>
2 #include <cctype>
3 #include <cstdlib>
4 #include <string>
5 #include <map>
6 #include <iostream>
7 #include <iomanip>
8 #include <llvm/ADT/STLExtras.h>
9 #include <llvm/Support/raw_ostream.h>
10 #include <llvm/ExecutionEngine/ExecutionEngine.h>
11 #include <llvm/ExecutionEngine/GenericValue.h>
12 #include <llvm/Support/TargetSelect.h>
13 //Needed to force linking interpreter
14 #include <llvm/ExecutionEngine/MCJIT.h>
15 #include "lexer.h"
16 #include "ast.hpp"
17 #include "codegen.hpp"
18 #include "shared.h"
19
20 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
21                                                                                           std::unique_ptr<ExprAST> LHS);
22 static std::unique_ptr<ExprAST> ParseExpression();
23
24 static ExecutionEngine* TheEngine;
25
26 std::unique_ptr<Module> TheModule;
27
28 constexpr auto ANON_EXPR_FUNC_NAME = "__anon_expr";
29
/// Token kinds returned by gettok(). Any other input character is returned as its
/// own (non-negative) character code, so every named token below is negative.
enum Token {
	// end of input
	tok_eof = -1,

	// commands
	tok_def = -2,
	tok_extern = -3,

	// primary
	tok_identifier = -4,
	tok_number = -5,

	// control flow
	tok_if = -6,
	tok_then = -7,
	tok_else = -8,
	tok_for = -9,
	tok_in = -10,

	// user-defined operators
	tok_binary = -11,
	tok_unary = -12,
};
50
// Lexer side outputs, read by the parser right after the matching token.
static std::string IdentifierStr; // text of the last identifier/keyword lexed
static double NumVal = 0.0;       // value of the last tok_number lexed
53
54 ///Returns the next token from stdin
55 static int gettok() {
56         static int LastChar = ' ';
57
58         while (isspace(LastChar))
59                 LastChar = getchar();
60
61         if (isalpha(LastChar)) {
62                 IdentifierStr = LastChar;
63                 while (isalnum((LastChar = getchar())))
64                         IdentifierStr += LastChar;
65
66                 if (IdentifierStr == "def") return tok_def;
67                 if (IdentifierStr == "extern") return tok_extern;
68                 if (IdentifierStr == "if") return tok_if;
69                 if (IdentifierStr == "then") return tok_then;
70                 if (IdentifierStr == "else") return tok_else;
71                 if (IdentifierStr == "binary") return tok_binary;
72                 if (IdentifierStr == "unary") return tok_unary;
73                 
74                 return tok_identifier;
75         }
76
77         if (isdigit(LastChar) || LastChar == '.') {
78                 std::string NumStr;
79                 do {
80                         NumStr += LastChar;
81                         LastChar = getchar();
82                 } while (isdigit(LastChar) || LastChar == '.');
83
84                 NumVal = strtod(NumStr.c_str(), 0);
85                 return tok_number;
86         }
87
88         if (LastChar == '#') {
89                 //Coment until the end of the line
90                 do
91                         LastChar = getchar();
92                 while(LastChar != EOF && LastChar != '\n' && LastChar != '\r');
93
94                 if(LastChar != EOF)
95                         return gettok();
96         }
97
98         //Check for end of file
99         if (LastChar == EOF)
100                 return tok_eof;
101
102         int ThisChar = LastChar;
103         LastChar = getchar();
104         return ThisChar;
105 }
106
107
108 static int CurTok;
109 static int getNextToken() {
110         return CurTok = gettok();
111 }
112
113 std::unique_ptr<PrototypeAST> LogErrorP(std::string str) {
114         LogError(str);
115         return nullptr;
116 }
117
118 static std::unique_ptr<ExprAST> ParseNumberExpr() {
119         auto Result = llvm::make_unique<NumberExprAST>(NumVal);
120         getNextToken(); // consume the number
121         return std::move(Result);
122 }
123
124 static std::unique_ptr<ExprAST> ParseParenExpr() {
125         getNextToken();
126         auto V = ParseExpression();
127         if (!V)
128                 return nullptr;
129         
130         if (CurTok != ')')
131                 return LogError("expected ')'");
132         getNextToken();
133         return V;
134 }
135
136 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
137         std::string IdName = IdentifierStr;
138         
139         getNextToken();
140         
141         if (CurTok != '(')
142                 return llvm::make_unique<VariableExprAST>(IdName);
143         
144         // call a function
145         getNextToken();
146         std::vector<std::unique_ptr<ExprAST>> Args;
147         if (CurTok != ')') {
148                 while (true) {
149                         if (auto Arg = ParseExpression())
150                                 Args.push_back(std::move(Arg));
151                         else
152                                 return nullptr;
153                         
154                         if (CurTok == ')')
155                                 break;
156                         
157                         if (CurTok != ',')
158                                 return LogError("Expected ')' or ',' in argument list");
159                         getNextToken();
160                 }
161         }
162         
163         getNextToken();
164         return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
165 }
166
167 static std::unique_ptr<ExprAST> ParseIfExpr() {
168         getNextToken(); // eat the if
169         
170         auto condExpr = ParseExpression();
171         if (!condExpr) return nullptr;
172         
173         if (CurTok != tok_then)
174                 return LogError("Expected then");
175         getNextToken(); // eat the then
176         
177         auto thenExpr = ParseExpression();
178         if (!thenExpr) return nullptr;
179         
180         if (CurTok != tok_else)
181                 return LogError("Expected else");
182         getNextToken(); // eat the else
183         
184         auto elseExpr = ParseExpression();
185         if (!elseExpr) return nullptr;
186         
187         return llvm::make_unique<IfExprAST>(std::move(condExpr),
188                                                                                 std::move(thenExpr),
189                                                                                 std::move(elseExpr));
190 }
191
192 static std::unique_ptr<ExprAST> ParseForExpr() {
193         getNextToken(); // eat the for
194         
195         if (CurTok != tok_identifier)
196                 return LogError("Expected identifier after for");
197         
198         std::string id = IdentifierStr;
199         getNextToken(); // eat the identifier
200         
201         if (CurTok != '=')
202                 return LogError("Expected = after for");
203         getNextToken(); // eat the =
204         
205         auto start = ParseExpression();
206         if (!start) return nullptr;
207         
208         if (CurTok != ',')
209                 return LogError("Expected , after for start value");
210         getNextToken(); // eat the ,
211         
212         auto end = ParseExpression();
213         if (!end)
214                 return nullptr;
215         
216         // optional step value
217         std::unique_ptr<ExprAST> step;
218         if (CurTok == ',') {
219                 getNextToken(); // eat the ,
220                 step = ParseExpression();
221                 if (!step) return nullptr;
222         }
223         
224         if (CurTok != tok_in)
225                 return LogError("Expected in after for");
226         getNextToken(); // eat the in
227         
228         auto body = ParseExpression();
229         if (!body) return nullptr;
230         
231         return llvm::make_unique<ForExprAST>(id,
232                                                                                  std::move(start),
233                                                                                  std::move(end),
234                                                                                  std::move(step),
235                                                                                  std::move(body));
236 }
237
238 static std::unique_ptr<ExprAST> ParsePrimary() {
239         switch (CurTok) {
240                 default:
241                         return LogError(std::string("Unknown token when expected an expression: ")
242                                                         + std::to_string(CurTok));
243                 case tok_identifier:
244                         return ParseIdentifierExpr();
245                 case tok_number:
246                         return ParseNumberExpr();
247                 case '(':
248                         return ParseParenExpr();
249                 case tok_if:
250                         return ParseIfExpr();
251                 case tok_for:
252                         return ParseForExpr();
253         }
254 }
255
256 static int GetTokPrecedence() {
257         if (!isascii(CurTok))
258                 return -1;
259         
260         int TokPrec = BinopPrecedence[CurTok];
261         if (TokPrec <= 0) return -1;
262         return TokPrec;
263 }
264
265 static std::unique_ptr<ExprAST> ParseExpression() {
266         auto LHS = ParsePrimary();
267         if (!LHS)
268                 return nullptr;
269         return ParseBinOpRHS(0, std::move(LHS));
270 }
271
272 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
273                                                                                           std::unique_ptr<ExprAST> LHS) {
274         while (true) {
275                 int TokPrec = GetTokPrecedence();
276                 if (TokPrec < ExprPrec)
277                         return LHS;
278                 
279                 int BinOp = CurTok;
280                 getNextToken(); // eat binop
281                 
282                 auto RHS = ParsePrimary();
283                 if (!RHS)
284                         return nullptr;
285                 
286                 int NextPrec = GetTokPrecedence();
287                 if (TokPrec < NextPrec) {
288                         RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
289                         if (!RHS)
290                                 return nullptr;
291                 }
292                 
293                 LHS = llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
294         }
295 }
296
297 static std::unique_ptr<PrototypeAST> ParsePrototype() {
298         std::string FnName;
299         
300         unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary
301         unsigned BinaryPrecedence = 30;
302
303         switch (CurTok) {
304         default:
305                 return LogErrorP("Expected function name in prototype");
306         case tok_identifier:
307                 FnName = IdentifierStr;
308                 Kind = 0;
309                 getNextToken();
310                 break;
311         case tok_binary:
312                 getNextToken();
313                 if (!isascii(CurTok)) return LogErrorP("Expected binary operator");
314                 FnName = "binary";
315                 FnName += (char)CurTok;
316                 Kind = 2;
317                 getNextToken();
318                 if (CurTok == tok_number) {
319                         if (NumVal < 1 || NumVal > 100)
320                                 return LogErrorP("Invalid precedence: must be 1..100");
321                         BinaryPrecedence = (unsigned)NumVal;
322                         getNextToken();
323                 }
324                 break;
325         }
326         
327         if (CurTok != '(')
328                 return LogErrorP("Expected '(' in prototype");
329         
330         std::vector<std::string> ArgNames;
331         while (getNextToken() == tok_identifier)
332                 ArgNames.push_back(IdentifierStr);
333         if (CurTok != ')')
334                 return LogErrorP("Expected ')' in prototype");
335         
336         getNextToken(); // eat ')'
337
338         if (Kind && ArgNames.size() != Kind)
339                 return LogErrorP("Invalid number of of arguments for operator");
340         
341         return llvm::make_unique<PrototypeAST>(FnName,
342                                                                                    std::move(ArgNames),
343                                                                                    Kind != 0,
344                                                                                    BinaryPrecedence);
345 }
346
347 static std::unique_ptr<FunctionAST> ParseDefinition() {
348         getNextToken();
349         auto Proto = ParsePrototype();
350         if (!Proto) return nullptr;
351         
352         if (auto E = ParseExpression())
353                 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
354         return nullptr;
355 }
356
357 static std::unique_ptr<PrototypeAST> ParseExtern() {
358         getNextToken();
359         return ParsePrototype();
360 }
361
362 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
363         if (auto e = ParseExpression()) {
364                 // make an anonymous prototype
365                 auto proto = llvm::make_unique<PrototypeAST>(ANON_EXPR_FUNC_NAME, std::vector<std::string>());
366                 return llvm::make_unique<FunctionAST>(std::move(proto), std::move(e));
367         }
368         return nullptr;
369 }
370
371 static void HandleDefinition() {
372         if (auto def = ParseDefinition()) {
373                 if (auto gen = def->codegen()) {
374                         std::cerr << "Read function defintion:";
375                         gen->print(errs());
376                         std::cerr << "\n";
377                         
378                         TheEngine->addModule(std::move(TheModule));
379                         InitializeModuleAndPassManager();
380                 }
381         }
382         else
383                 getNextToken(); // skip token for error recovery
384 }
385
386 static void HandleExtern() {
387         if (auto externProto = ParseExtern()) {
388                 if (auto gen = externProto->codegen()) {
389                         std::cerr << "Read an extern:\n";
390                         gen->print(errs());
391                         std::cerr << "\n";
392                         functionProtos[externProto->getName()] = std::move(externProto);
393                 }
394         }
395         else
396                 getNextToken(); // skip token for error recovery
397 }
398
399 static void HandleTopLevelExpr() {
400         if (auto expr = ParseTopLevelExpr()) {
401                 if (auto gen = expr->codegen()) {
402                         auto module = TheModule.get();
403                         TheEngine->addModule(std::move(TheModule));
404                         InitializeModuleAndPassManager();
405                         
406                         auto func = TheEngine->FindFunctionNamed(ANON_EXPR_FUNC_NAME);
407                         GenericValue gv = TheEngine->runFunction(func, std::vector<GenericValue>());
408                         std::cerr << "Evaluated to " << std::fixed << std::setw(5) << gv.DoubleVal << std::endl;
409                         
410                         TheEngine->removeModule(module);
411                 }
412         }
413         else
414                 getNextToken(); // skip token for error recovery
415 }
416
417 void mainLoop() {
418         while (true) {
419                 fprintf(stderr, "ready> ");
420                 switch(CurTok) {
421                         case tok_eof:
422                                 return;
423                         case ';':
424                                 getNextToken();
425                                 break;
426                         case tok_def:
427                                 HandleDefinition();
428                                 break;
429                         case tok_extern:
430                                 HandleExtern();
431                                 break;
432                         default:
433                                 HandleTopLevelExpr();
434                                 break;
435                 }
436         }
437 }
438
439 int main() {
440         InitializeNativeTarget();
441         InitializeNativeTargetAsmPrinter();
442         InitializeNativeTargetAsmParser();
443         
444         BinopPrecedence['<'] = 10;
445         BinopPrecedence['>'] = 10;
446         BinopPrecedence['|'] = 5;
447         BinopPrecedence['+'] = 20;
448         BinopPrecedence['-'] = 20;
449         BinopPrecedence['*'] = 40;
450         
451         InitializeModuleAndPassManager();
452         
453         std::string engineError;
454         TheEngine = EngineBuilder(std::move(TheModule)).setEngineKind(EngineKind::JIT).setErrorStr(&engineError).create();
455         if (!engineError.empty())
456                 std::cout << engineError << "\n";
457         
458         InitializeModuleAndPassManager();
459         
460         // prime the first token
461         std::cerr << "ready> ";
462         getNextToken();
463         
464         mainLoop();
465
466         
467         return 0;
468 }