Codegen if statements
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     (else #f)))
14
15 (define (pretty-type t)
16   (cond ((abs? t)
17          (string-append
18           (if (abs? (cadr t))
19               (string-append "(" (pretty-type (cadr t)) ")")
20               (pretty-type (cadr t)))
21           " -> "
22           (pretty-type (caddr t))))
23         (else (symbol->string t))))
24
25                                         ; ('a, ('b, 'a))
26 (define (env-lookup env n)
27   (if (null? env) (error #f "empty env")                        ; it's a type equality
28       (if (eq? (caar env) n)
29           (cdar env)
30           (env-lookup (cdr env) n))))
31
32 (define (env-insert env n t)
33   (cons (cons n t) env))
34
35 (define abs-arg cadr)
36
37 (define cur-tvar 0)
38 (define (fresh-tvar)
39   (begin
40     (set! cur-tvar (+ cur-tvar 1))
41     (string->symbol
42      (string-append "t" (number->string (- cur-tvar 1))))))
43
44 (define (last xs)
45   (if (null? (cdr xs))
46       (car xs)
47       (last (cdr xs))))
48                                 
49 (define (normalize prog) ; (+ a b) -> ((+ a) b)
50   (case (ast-type prog)
51     ('lambda 
52                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
53         (if (> (length (lambda-args prog)) 1)
54             (list 'lambda (list (car (lambda-args prog)))
55                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
56             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
57     ('app
58      (if (null? (cddr prog))
59          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
60          `(,(list (normalize (car prog)) (normalize (cadr prog)))
61            ,(normalize (caddr prog))))) ; (f a b)
62     ('let
63         (append (list 'let
64                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
65                            (let-bindings prog)))
66                 (map normalize (let-body prog))))
67     (else (ast-traverse normalize prog))))
68
69 (define (builtin-type x)
70   (case x
71     ('+ '(abs int (abs int int)))
72     ('- '(abs int (abs int int)))
73     ('* '(abs int (abs int int)))
74     ('! '(abs bool bool))
75     ('= '(abs int (abs int bool)))
76     ('bool->int '(abs bool int))
77     (else #f)))
78
79 ; we typecheck the lambda calculus only (only single arg lambdas)
80 (define (typecheck prog)
81   (define (check env x)
82     ;; (display "check: ")
83     ;; (display x)
84     ;; (display "\n\t")
85     ;; (display env)
86     ;; (newline)
87     (let
88         ((res
89           (case (ast-type x)
90            ('int-literal (list '() 'int))
91            ('bool-literal (list '() 'bool))
92            ('builtin (list '() (builtin-type x)))
93
94            ('if
95             (let* ((cond-type-res (check env (cadr x)))
96                    (then-type-res (check env (caddr x)))
97                    (else-type-res (check env (cadddr x)))
98                    (then-eq-else-cs (unify (cadr then-type-res)
99                                            (cadr else-type-res)))
100                    (cs (consolidate
101                         (car then-type-res)
102                         (consolidate (car else-type-res)
103                                      then-eq-else-cs)))
104                    (return-type (substitute cs (cadr then-type-res))))
105               (when (not (eqv? (cadr cond-type-res) 'bool))
106                 (error #f "if condition isn't bool"))
107               (list cs return-type)))
108            
109            ('var  (list '() (env-lookup env x)))
110            ('let
111             (let ((new-env (fold-left
112                             (lambda (acc bind)
113                               (let ((t (check
114                                         (env-insert acc (car bind) (fresh-tvar))
115                                         (cadr bind))))
116                                 (env-insert acc (car bind) (cadr t))))
117                             env (let-bindings x))))
118               (check new-env (last (let-body x)))))
119                   
120
121            ('lambda
122             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
123                    (body-type-res (check new-env (lambda-body x)))
124                    (cs (car body-type-res))
125                    (subd-env (substitute-env (car body-type-res) new-env))
126                    (arg-type (env-lookup subd-env (lambda-arg x)))
127                    (resolved-arg-type (substitute cs arg-type)))
128               ;; (display "lambda:\n\t")
129               ;; (display prog)
130               ;; (display "\n\t")
131               ;; (display cs)
132               ;; (display "\n\t")
133               ;; (display resolved-arg-type)
134               ;; (newline)
135               (list (car body-type-res)
136                     (list 'abs
137                           resolved-arg-type
138                           (cadr body-type-res)))))
139            
140            ('app ; (f a)
141             (let* ((arg-type-res (check env (cadr x)))
142                    (arg-type (cadr arg-type-res))
143                    (func-type-res (check env (car x)))
144                    (func-type (cadr func-type-res))
145                    
146                                         ; f ~ a -> t0
147                    (func-c (unify func-type
148                                   (list 'abs
149                                         arg-type
150                                         (fresh-tvar))))
151                    (cs (consolidate
152                         (consolidate func-c (car arg-type-res))
153                         (car func-type-res)))
154                    
155                    (resolved-func-type (substitute cs func-type))
156                    (resolved-return-type (caddr resolved-func-type)))
157               ;; (display "app:\n")
158               ;; (display cs)
159               ;; (display "\n")
160               ;; (display func-type)
161               ;; (display "\n")
162               ;; (display resolved-func-type)
163               ;; (display "\n")
164               ;; (display arg-type-res)
165               ;; (display "\n")
166               (if (abs? resolved-func-type)
167                   (let ((return-type (substitute cs (caddr resolved-func-type))))
168                     (list cs return-type))
169                   (error #f "not a function")))))))
170       ;; (display "result of ")
171       ;; (display x)
172       ;; (display ":\n\t")
173       ;; (display (cadr res))
174       ;; (display "[")
175       ;; (display (car res))
176       ;; (display "]\n")
177       res))
178   (cadr (check '() (normalize prog))))
179
180                                         ; returns a list of pairs of constraints
181 (define (unify a b)
182   (cond ((eq? a b) '())
183         ((or (tvar? a) (tvar? b)) (~ a b))
184         ((and (abs? a) (abs? b))
185          (consolidate (unify (cadr a) (cadr b))
186                       (unify (caddr a) (caddr b))))
187         (else (error #f "could not unify"))))
188
189                                         ; TODO: what's the most appropriate substitution?
190                                         ; should all constraints just be limited to a pair?
191 (define (substitute cs t)
192                                         ; gets the first concrete type
193                                         ; otherwise returns the last type variable
194
195   (define (get-concrete c)
196     (let ((last (null? (cdr c))))
197       (if (not (tvar? (car c)))
198           (if (abs? (car c))
199               (substitute cs (car c))
200               (car c))
201           (if last
202               (car c)
203               (get-concrete (cdr c))))))
204   (cond
205    ((abs? t) (list 'abs
206                    (substitute cs (cadr t))
207                    (substitute cs (caddr t))))
208    (else
209     (fold-left
210      (lambda (t c)
211        (if (member t c)
212            (get-concrete c)
213            t))
214      t cs))))
215
216 (define (substitute-env cs env)
217   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
218
219 (define (~ a b)
220   (list (list a b)))
221
222 (define (consolidate x y)
223   (define (merge a b)
224     (cond ((null? a) b)
225           ((null? b) a)
226           (else (if (member (car b) a)
227                     (merge a (cdr b))
228                     (cons (car b) (merge a (cdr b)))))))
229   (define (overlap? a b)
230     (if (or (null? a) (null? b))
231         #f
232         (if (fold-left (lambda (acc v)
233                          (or acc (eq? v (car a))))
234                        #f b)
235             #t
236             (overlap? (cdr a) b))))
237
238   (cond ((null? y) x)
239         ((null? x) y)
240         (else (let* ((a (car y))
241                      (merged (fold-left
242                               (lambda (acc b)
243                                 (if acc
244                                     acc
245                                     (if (overlap? a b)
246                                         (cons (merge a b) b)
247                                         #f)))
248                               #f x))
249                      (removed (if merged
250                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
251                                   x)))
252                 (if merged
253                     (consolidate removed (cons (car merged) (cdr y)))
254                     (consolidate (cons a x) (cdr y)))))))