Add ast-traverse helper
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     (else #f)))
14
15 (define (pretty-type t)
16   (cond ((abs? t)
17          (string-append
18           (if (abs? (cadr t))
19               (string-append "(" (pretty-type (cadr t)) ")")
20               (pretty-type (cadr t)))
21           " -> "
22           (pretty-type (caddr t))))
23         (else (symbol->string t))))
24
25                                         ; ('a, ('b, 'a))
26 (define (env-lookup env n)
27   (if (null? env) (error #f "empty env")                        ; it's a type equality
28       (if (eq? (caar env) n)
29           (cdar env)
30           (env-lookup (cdr env) n))))
31
32 (define (env-insert env n t)
33   (cons (cons n t) env))
34
35 (define abs-arg cadr)
36
37 (define cur-tvar 0)
38 (define (fresh-tvar)
39   (begin
40     (set! cur-tvar (+ cur-tvar 1))
41     (string->symbol
42      (string-append "t" (number->string (- cur-tvar 1))))))
43
44 (define (last xs)
45   (if (null? (cdr xs))
46       (car xs)
47       (last (cdr xs))))
48
49                                         
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          `(,(list (normalize (car prog)) (normalize (cadr prog)))
62            ,(normalize (caddr prog))))) ; (f a b)
63     ;; (list (list (normalize (car prog))
64     ;;      (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
65     ('let
66         (append (list 'let
67                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
68                            (let-bindings prog)))
69                 (map normalize (let-body prog))))
70     ('if `(if ,(normalize (cadr prog))
71               ,(normalize (caddr prog))
72               ,(normalize (cadddr prog))))
73     (else prog)))
74
75 (define (builtin-type x)
76   (case x
77     ('+ '(abs int (abs int int)))
78     ('- '(abs int (abs int int)))
79     ('* '(abs int (abs int int)))
80     ('! '(abs bool bool))
81     ('= '(abs int (abs int bool)))
82     ('bool->int '(abs bool int))
83     (else #f)))
84
85 ; we typecheck the lambda calculus only (only single arg lambdas)
86 (define (typecheck prog)
87   (define (check env x)
88     ;; (display "check: ")
89     ;; (display x)
90     ;; (display "\n\t")
91     ;; (display env)
92     ;; (newline)
93     (let
94         ((res
95           (case (ast-type x)
96            ('int-literal (list '() 'int))
97            ('bool-literal (list '() 'bool))
98            ('builtin (list '() (builtin-type x)))
99
100            ('if
101             (let* ((cond-type-res (check env (cadr x)))
102                    (then-type-res (check env (caddr x)))
103                    (else-type-res (check env (cadddr x)))
104                    (then-eq-else-cs (unify (cadr then-type-res)
105                                            (cadr else-type-res)))
106                    (cs (consolidate
107                         (car then-type-res)
108                         (consolidate (car else-type-res)
109                                      then-eq-else-cs)))
110                    (return-type (substitute cs (cadr then-type-res))))
111               (when (not (eqv? (cadr cond-type-res) 'bool))
112                 (error #f "if condition isn't bool"))
113               (list cs return-type)))
114            
115            ('var  (list '() (env-lookup env x)))
116            ('let
117             (let ((new-env (fold-left
118                             (lambda (acc bind)
119                               (let ((t (check
120                                         (env-insert acc (car bind) (fresh-tvar))
121                                         (cadr bind))))
122                                 (env-insert acc (car bind) (cadr t))))
123                             env (let-bindings x))))
124               (check new-env (last (let-body x)))))
125                   
126
127            ('lambda
128             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
129                    (body-type-res (check new-env (lambda-body x)))
130                    (cs (car body-type-res))
131                    (subd-env (substitute-env (car body-type-res) new-env))
132                    (arg-type (env-lookup subd-env (lambda-arg x)))
133                    (resolved-arg-type (substitute cs arg-type)))
134               ;; (display "lambda:\n\t")
135               ;; (display prog)
136               ;; (display "\n\t")
137               ;; (display cs)
138               ;; (display "\n\t")
139               ;; (display resolved-arg-type)
140               ;; (newline)
141               (list (car body-type-res)
142                     (list 'abs
143                           resolved-arg-type
144                           (cadr body-type-res)))))
145            
146            ('app ; (f a)
147             (let* ((arg-type-res (check env (cadr x)))
148                    (arg-type (cadr arg-type-res))
149                    (func-type-res (check env (car x)))
150                    (func-type (cadr func-type-res))
151                    
152                                         ; f ~ a -> t0
153                    (func-c (unify func-type
154                                   (list 'abs
155                                         arg-type
156                                         (fresh-tvar))))
157                    (cs (consolidate
158                         (consolidate func-c (car arg-type-res))
159                         (car func-type-res)))
160                    
161                    (resolved-func-type (substitute cs func-type))
162                    (resolved-return-type (caddr resolved-func-type)))
163               ;; (display "app:\n")
164               ;; (display cs)
165               ;; (display "\n")
166               ;; (display func-type)
167               ;; (display "\n")
168               ;; (display resolved-func-type)
169               ;; (display "\n")
170               ;; (display arg-type-res)
171               ;; (display "\n")
172               (if (abs? resolved-func-type)
173                   (let ((return-type (substitute cs (caddr resolved-func-type))))
174                     (list cs return-type))
175                   (error #f "not a function")))))))
176       ;; (display "result of ")
177       ;; (display x)
178       ;; (display ":\n\t")
179       ;; (display (cadr res))
180       ;; (display "[")
181       ;; (display (car res))
182       ;; (display "]\n")
183       res))
184   (cadr (check '() (normalize prog))))
185
186                                         ; returns a list of pairs of constraints
187 (define (unify a b)
188   (cond ((eq? a b) '())
189         ((or (tvar? a) (tvar? b)) (~ a b))
190         ((and (abs? a) (abs? b))
191          (consolidate (unify (cadr a) (cadr b))
192                       (unify (caddr a) (caddr b))))
193         (else (error #f "could not unify"))))
194
195                                         ; TODO: what's the most appropriate substitution?
196                                         ; should all constraints just be limited to a pair?
197 (define (substitute cs t)
198                                         ; gets the first concrete type
199                                         ; otherwise returns the last type variable
200
201   (define (get-concrete c)
202     (let ((last (null? (cdr c))))
203       (if (not (tvar? (car c)))
204           (if (abs? (car c))
205               (substitute cs (car c))
206               (car c))
207           (if last
208               (car c)
209               (get-concrete (cdr c))))))
210   (cond
211    ((abs? t) (list 'abs
212                    (substitute cs (cadr t))
213                    (substitute cs (caddr t))))
214    (else
215     (fold-left
216      (lambda (t c)
217        (if (member t c)
218            (get-concrete c)
219            t))
220      t cs))))
221
222 (define (substitute-env cs env)
223   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
224
225 (define (~ a b)
226   (list (list a b)))
227
228 (define (consolidate x y)
229   (define (merge a b)
230     (cond ((null? a) b)
231           ((null? b) a)
232           (else (if (member (car b) a)
233                     (merge a (cdr b))
234                     (cons (car b) (merge a (cdr b)))))))
235   (define (overlap? a b)
236     (if (or (null? a) (null? b))
237         #f
238         (if (fold-left (lambda (acc v)
239                          (or acc (eq? v (car a))))
240                        #f b)
241             #t
242             (overlap? (cdr a) b))))
243
244   (cond ((null? y) x)
245         ((null? x) y)
246         (else (let* ((a (car y))
247                      (merged (fold-left
248                               (lambda (acc b)
249                                 (if acc
250                                     acc
251                                     (if (overlap? a b)
252                                         (cons (merge a b) b)
253                                         #f)))
254                               #f x))
255                      (removed (if merged
256                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
257                                   x)))
258                 (if merged
259                     (consolidate removed (cons (car merged) (cdr y)))
260                     (consolidate (cons a x) (cdr y)))))))