Add recursive let-bindings
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     (else #f)))
14
15 (define (pretty-type t)
16   (cond ((abs? t)
17          (string-append
18           (if (abs? (cadr t))
19               (string-append "(" (pretty-type (cadr t)) ")")
20               (pretty-type (cadr t)))
21           " -> "
22           (pretty-type (caddr t))))
23         (else (symbol->string t))))
24
25                                         ; ('a, ('b, 'a))
26 (define (env-lookup env n)
27   (if (null? env) (error #f "empty env")                        ; it's a type equality
28       (if (eq? (caar env) n)
29           (cdar env)
30           (env-lookup (cdr env) n))))
31
32 (define (env-insert env n t)
33   (cons (cons n t) env))
34
35 (define abs-arg cadr)
36
37 (define cur-tvar 0)
38 (define (fresh-tvar)
39   (begin
40     (set! cur-tvar (+ cur-tvar 1))
41     (string->symbol
42      (string-append "t" (number->string (- cur-tvar 1))))))
43
44 (define (last xs)
45   (if (null? (cdr xs))
46       (car xs)
47       (last (cdr xs))))
48
49                                         
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (cond
52    ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
53    ((lambda? prog)
54     (if (> (length (lambda-args prog)) 1)
55         (list 'lambda (list (car (lambda-args prog)))
56               (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57         (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58    ((app? prog)
59     (if (null? (cddr prog))
60         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61         `(,(list (normalize (car prog)) (normalize (cadr prog)))
62           ,(normalize (caddr prog))))) ; (f a b)
63         ;; (list (list (normalize (car prog))
64         ;;          (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
65    ((let? prog)
66     (append (list 'let
67                   (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
68                        (let-bindings prog)))
69             (map normalize (let-body prog))))
70    (else prog)))
71
72 (define (builtin-type x)
73   (case x
74     ('+ '(abs int (abs int int)))
75     ('- '(abs int (abs int int)))
76     ('* '(abs int (abs int int)))
77     ('! '(abs bool bool))
78     ('bool->int '(abs bool int))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     ;; (display "check: ")
85     ;; (display x)
86     ;; (display "\n\t")
87     ;; (display env)
88     ;; (newline)
89     (let
90         ((res
91           (cond
92            ((integer? x) (list '() 'int))
93            ((boolean? x) (list '() 'bool))
94            ((builtin-type x) (list '() (builtin-type x)))
95            ((symbol? x)  (list '() (env-lookup env x)))
96            ((let? x)
97             (let ((new-env (fold-left
98                             (lambda (acc bind)
99                               (let ((t (check
100                                         (env-insert acc (car bind) (fresh-tvar))
101                                         (cadr bind))))
102                                 (env-insert acc (car bind) (cadr t))))
103                             env (let-bindings x))))
104               (check new-env (last (let-body x)))))
105                   
106
107            ((lambda? x)
108             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
109                    (body-type-res (check new-env (lambda-body x)))
110                    (cs (car body-type-res))
111                    (subd-env (substitute-env (car body-type-res) new-env))
112                    (arg-type (env-lookup subd-env (lambda-arg x)))
113                    (resolved-arg-type (substitute cs arg-type)))
114               ;; (display "lambda:\n\t")
115               ;; (display prog)
116               ;; (display "\n\t")
117               ;; (display cs)
118               ;; (display "\n\t")
119               ;; (display resolved-arg-type)
120               ;; (newline)
121               (list (car body-type-res)
122                     (list 'abs
123                           resolved-arg-type
124                           (cadr body-type-res)))))
125            
126            ((app? x) ; (f a)
127             (let* ((arg-type-res (check env (cadr x)))
128                    (arg-type (cadr arg-type-res))
129                    (func-type-res (check env (car x)))
130                    (func-type (cadr func-type-res))
131                    
132                                         ; f ~ a -> t0
133                    (func-c (unify func-type
134                                   (list 'abs
135                                         arg-type
136                                         (fresh-tvar))))
137                    (cs (consolidate
138                         (consolidate func-c (car arg-type-res))
139                         (car func-type-res)))
140                    
141                    (resolved-func-type (substitute cs func-type))
142                    (resolved-return-type (caddr resolved-func-type)))
143               ;; (display "app:\n")
144               ;; (display cs)
145               ;; (display "\n")
146               ;; (display func-type)
147               ;; (display "\n")
148               ;; (display resolved-func-type)
149               ;; (display "\n")
150               ;; (display arg-type-res)
151               ;; (display "\n")
152               (if (abs? resolved-func-type)
153                   (let ((return-type (substitute cs (caddr resolved-func-type))))
154                     (list cs return-type))
155                   (error #f "not a function")))))))
156       ;; (display "result of ")
157       ;; (display x)
158       ;; (display ":\n\t")
159       ;; (display (cadr res))
160       ;; (display "[")
161       ;; (display (car res))
162       ;; (display "]\n")
163       res))
164   (cadr (check '() (normalize prog))))
165
166                                         ; returns a list of pairs of constraints
167 (define (unify a b)
168   (cond ((eq? a b) '())
169         ((or (tvar? a) (tvar? b)) (~ a b))
170         ((and (abs? a) (abs? b))
171          (consolidate (unify (cadr a) (cadr b))
172                       (unify (caddr a) (caddr b))))
173         (else (error #f "could not unify"))))
174
175                                         ; TODO: what's the most appropriate substitution?
176                                         ; should all constraints just be limited to a pair?
177 (define (substitute cs t)
178                                         ; gets the first concrete type
179                                         ; otherwise returns the last type variable
180
181   (define (get-concrete c)
182     (let ((last (null? (cdr c))))
183       (if (not (tvar? (car c)))
184           (if (abs? (car c))
185               (substitute cs (car c))
186               (car c))
187           (if last
188               (car c)
189               (get-concrete (cdr c))))))
190   (cond
191    ((abs? t) (list 'abs
192                    (substitute cs (cadr t))
193                    (substitute cs (caddr t))))
194    (else
195     (fold-left
196      (lambda (t c)
197        (if (member t c)
198            (get-concrete c)
199            t))
200      t cs))))
201
202 (define (substitute-env cs env)
203   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
204
205 (define (~ a b)
206   (list (list a b)))
207
208 (define (consolidate x y)
209   (define (merge a b)
210     (cond ((null? a) b)
211           ((null? b) a)
212           (else (if (member (car b) a)
213                     (merge a (cdr b))
214                     (cons (car b) (merge a (cdr b)))))))
215   (define (overlap? a b)
216     (if (or (null? a) (null? b))
217         #f
218         (if (fold-left (lambda (acc v)
219                          (or acc (eq? v (car a))))
220                        #f b)
221             #t
222             (overlap? (cdr a) b))))
223
224   (cond ((null? y) x)
225         ((null? x) y)
226         (else (let* ((a (car y))
227                      (merged (fold-left
228                               (lambda (acc b)
229                                 (if acc
230                                     acc
231                                     (if (overlap? a b)
232                                         (cons (merge a b) b)
233                                         #f)))
234                               #f x))
235                      (removed (if merged
236                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
237                                   x)))
238                 (if merged
239                     (consolidate removed (cons (car merged) (cdr y)))
240                     (consolidate (cons a x) (cdr y)))))))