Add more binary ops
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2                                         ; ('a, ('b, 'a))
3 (define (env-lookup env n)
4   (if (null? env) (error #f "empty env")                        ; it's a type equality
5       (if (eq? (caar env) n)
6           (cdar env)
7           (env-lookup (cdr env) n))))
8
9 (define (env-insert env n t)
10   (cons (cons n t) env))
11
12 (define abs-arg cadr)
13
14 (define cur-tvar 0)
15 (define (fresh-tvar)
16   (begin
17     (set! cur-tvar (+ cur-tvar 1))
18     (string->symbol
19      (string-append "t" (number->string (- cur-tvar 1))))))
20
21 (define (last xs)
22   (if (null? (cdr xs))
23       (car xs)
24       (last (cdr xs))))
25
26                                         
27 (define (normalize prog) ; (+ a b) -> ((+ a) b)
28   (cond
29    ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
30    ((lambda? prog)
31     (if (> (length (lambda-args prog)) 1)
32         (list 'lambda (list (car (lambda-args prog)))
33               (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
34         (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
35    ((app? prog)
36     (if (null? (cddr prog))
37         (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
38         (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
39    ((let? prog)
40     (append (list 'let
41                   (map (lambda (x) (cons (car x) (normalize (cdr x))))
42                        (let-bindings prog)))
43             (map normalize (let-body prog))))
44    (else prog)))
45
46 (define (builtin-type x)
47   (case x
48     ('+ '(abs int (abs int int)))
49     ('- '(abs int (abs int int)))
50     ('* '(abs int (abs int int)))
51     ('! '(abs bool bool))
52     ('bool->int '(abs bool int))
53     (else #f)))
54
55 ; we typecheck the lambda calculus only (only single arg lambdas)
56 (define (typecheck prog)
57   (define (check env x)
58     ;; (display "check: ")
59     ;; (display x)
60     ;; (display "\n\t")
61     ;; (display env)
62     ;; (newline)
63     (let
64         ((res
65           (cond
66            ((integer? x) (list '() 'int))
67            ((boolean? x) (list '() 'bool))
68            ((builtin-type x) (list '() (builtin-type x)))
69            ((symbol? x)  (list '() (env-lookup env x)))
70            ((let? x)
71             (let ((new-env (fold-left
72                             (lambda (acc bind)
73                               (let ((t (check
74                                         (env-insert acc (car bind) (fresh-tvar))
75                                         (cadr bind))))
76                                 (env-insert acc (car bind) (cadr t))))
77                             env (let-bindings x))))
78               (check new-env (last (let-body x)))))
79                   
80
81            ((lambda? x)
82             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
83                    (body-type-res (check new-env (lambda-body x)))
84                    (cs (car body-type-res))
85                    (subd-env (substitute-env (car body-type-res) new-env))
86                    (arg-type (env-lookup subd-env (lambda-arg x)))
87                    (resolved-arg-type (substitute cs arg-type)))
88               ;; (display "lambda:\n\t")
89               ;; (display prog)
90               ;; (display "\n\t")
91               ;; (display cs)
92               ;; (display "\n\t")
93               ;; (display resolved-arg-type)
94               ;; (newline)
95               (list (car body-type-res)
96                     (list 'abs
97                           resolved-arg-type
98                           (cadr body-type-res)))))
99            
100            ((app? x) ; (f a)
101             (let* ((arg-type-res (check env (cadr x)))
102                    (arg-type (cadr arg-type-res))
103                    (func-type-res (check env (car x)))
104                    (func-type (cadr func-type-res))
105                    
106                                         ; f ~ a -> t0
107                    (func-c (unify func-type
108                                   (list 'abs
109                                         arg-type
110                                         (fresh-tvar))))
111                    (cs (consolidate
112                         (consolidate func-c (car arg-type-res))
113                         (car func-type-res)))
114                    
115                    (resolved-func-type (substitute cs func-type))
116                    (resolved-return-type (caddr resolved-func-type)))
117               ;; (display "app:\n")
118               ;; (display cs)
119               ;; (display "\n")
120               ;; (display func-type)
121               ;; (display "\n")
122               ;; (display resolved-func-type)
123               ;; (display "\n")
124               ;; (display arg-type-res)
125               ;; (display "\n")
126               (if (abs? resolved-func-type)
127                   (let ((return-type (substitute cs (caddr resolved-func-type))))
128                     (list cs return-type))
129                   (error #f "not a function")))))))
130       ;; (display "result of ")
131       ;; (display x)
132       ;; (display ":\n\t")
133       ;; (display (cadr res))
134       ;; (display "[")
135       ;; (display (car res))
136       ;; (display "]\n")
137       res))
138   (cadr (check '() (normalize prog))))
139
140
141 (define (abs? t)
142   (and (list? t) (eq? (car t) 'abs)))
143
144 (define (tvar? t)
145   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
146
147 (define (concrete? t)
148   (case t
149     ('int #t)
150     ('bool #t)
151     (else #f)))
152
153                                         ; returns a list of pairs of constraints
154 (define (unify a b)
155   (cond ((eq? a b) '())
156         ((or (tvar? a) (tvar? b)) (~ a b))
157         ((and (abs? a) (abs? b))
158          (consolidate (unify (cadr a) (cadr b))
159                       (unify (caddr a) (caddr b))))
160         (else (error #f "could not unify"))))
161
162                                         ; TODO: what's the most appropriate substitution?
163                                         ; should all constraints just be limited to a pair?
164 (define (substitute cs t)
165                                         ; gets the first concrete type
166                                         ; otherwise returns the last type variable
167
168   (define (get-concrete c)
169     (let ((last (null? (cdr c))))
170       (if (not (tvar? (car c)))
171           (if (abs? (car c))
172               (substitute cs (car c))
173               (car c))
174           (if last
175               (car c)
176               (get-concrete (cdr c))))))
177   (cond
178    ((abs? t) (list 'abs
179                    (substitute cs (cadr t))
180                    (substitute cs (caddr t))))
181    (else
182     (fold-left
183      (lambda (t c)
184        (if (member t c)
185            (get-concrete c)
186            t))
187      t cs))))
188
189 (define (substitute-env cs env)
190   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
191
192 (define (~ a b)
193   (list (list a b)))
194
195 (define (consolidate x y)
196   (define (merge a b)
197     (cond ((null? a) b)
198           ((null? b) a)
199           (else (if (member (car b) a)
200                     (merge a (cdr b))
201                     (cons (car b) (merge a (cdr b)))))))
202   (define (overlap? a b)
203     (if (or (null? a) (null? b))
204         #f
205         (if (fold-left (lambda (acc v)
206                          (or acc (eq? v (car a))))
207                        #f b)
208             #t
209             (overlap? (cdr a) b))))
210
211   (cond ((null? y) x)
212         ((null? x) y)
213         (else (let* ((a (car y))
214                      (merged (fold-left
215                               (lambda (acc b)
216                                 (if acc
217                                     acc
218                                     (if (overlap? a b)
219                                         (cons (merge a b) b)
220                                         #f)))
221                               #f x))
222                      (removed (if merged
223                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
224                                   x)))
225                 (if merged
226                     (consolidate removed (cons (car merged) (cdr y)))
227                     (consolidate (cons a x) (cdr y)))))))