Resolve types in lambda arguments, recursively substitute
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2                                         ; ('a, ('b, 'a))
3 (define (env-lookup env n)
4   (if (null? env) (error #f "empty env")                        ; it's a type equality
5       (if (eq? (caar env) n)
6           (cdar env)
7           (env-lookup (cdr env) n))))
8
9 (define (env-insert env n t)
10   (cons (cons n t) env))
11
12 (define abs-arg cadr)
13
14 (define cur-tvar 0)
15 (define (fresh-tvar)
16   (begin
17     (set! cur-tvar (+ cur-tvar 1))
18     (string->symbol
19      (string-append "t" (number->string (- cur-tvar 1))))))
20
21 (define (last xs)
22   (if (null? (cdr xs))
23       (car xs)
24       (last (cdr xs))))
25
26                                         
27 (define (normalize prog) ; (+ a b) -> ((+ a) b)
28   (cond
29    ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
30    ((lambda? prog)
31     (if (> (length (lambda-args prog)) 1)
32         (list 'lambda (list (car (lambda-args prog)))
33               (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
34         (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
35    ((app? prog)
36     (if (null? (cddr prog))
37         (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
38         (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
39    ((let? prog)
40     (append (list 'let
41                   (map (lambda (x) (cons (car x) (normalize (cdr x))))
42                        (let-bindings prog)))
43             (map normalize (let-body prog))))
44    (else prog)))
45
46
47 ; we typecheck the lambda calculus only (only single arg lambdas)
48 (define (typecheck prog)
49   (define (check env x)
50     ;; (display "check: ")
51     ;; (display x)
52     ;; (display "\n\t")
53     ;; (display env)
54     ;; (newline)
55     (let
56         ((res
57           (cond
58            ((integer? x) (list '() 'int))
59            ((boolean? x) (list '() 'bool))
60            ((eq? x 'inc) (list '() '(abs int int)))
61            ((eq? x '+)   (list '() '(abs int (abs int int))))
62            ((symbol? x)  (list '() (env-lookup env x)))
63
64            ((let? x)
65             (let ((new-env (fold-left
66                             (lambda (acc bind)
67                               (let ((t (check
68                                         (env-insert acc (car bind) (fresh-tvar))
69                                         (cadr bind))))
70                                 (env-insert acc (car bind) (cadr t))))
71                             env (let-bindings x))))
72               (check new-env (last (let-body x)))))
73                   
74
75            ((lambda? x)
76             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
77                    (body-type-res (check new-env (lambda-body x)))
78                    (cs (car body-type-res))
79                    (subd-env (substitute-env (car body-type-res) new-env))
80                    (arg-type (env-lookup subd-env (lambda-arg x)))
81                    (resolved-arg-type (substitute cs arg-type)))
82               ;; (display "lambda:\n\t")
83               ;; (display prog)
84               ;; (display "\n\t")
85               ;; (display cs)
86               ;; (display "\n\t")
87               ;; (display resolved-arg-type)
88               ;; (newline)
89               (list (car body-type-res)
90                     (list 'abs
91                           resolved-arg-type
92                           (cadr body-type-res)))))
93            
94            ((app? x) ; (f a)
95             (let* ((arg-type-res (check env (cadr x)))
96                    (arg-type (cadr arg-type-res))
97                    (func-type-res (check env (car x)))
98                    (func-type (cadr func-type-res))
99                    
100                                         ; f ~ a -> t0
101                    (func-c (unify func-type
102                                   (list 'abs
103                                         arg-type
104                                         (fresh-tvar))))
105                    (cs (consolidate
106                         (consolidate func-c (car arg-type-res))
107                         (car func-type-res)))
108                    
109                    (resolved-func-type (substitute cs func-type))
110                    (resolved-return-type (caddr resolved-func-type)))
111               ;; (display "app:\n")
112               ;; (display cs)
113               ;; (display "\n")
114               ;; (display func-type)
115               ;; (display "\n")
116               ;; (display resolved-func-type)
117               ;; (display "\n")
118               ;; (display arg-type-res)
119               ;; (display "\n")
120               (if (abs? resolved-func-type)
121                   (let ((return-type (substitute cs (caddr resolved-func-type))))
122                     (list cs return-type))
123                   (error #f "not a function")))))))
124       ;; (display "result of ")
125       ;; (display x)
126       ;; (display ":\n\t")
127       ;; (display (cadr res))
128       ;; (display "[")
129       ;; (display (car res))
130       ;; (display "]\n")
131       res))
132   (cadr (check '() (normalize prog))))
133
134
135 (define (abs? t)
136   (and (list? t) (eq? (car t) 'abs)))
137
138 (define (tvar? t)
139   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
140
141 (define (concrete? t)
142   (case t
143     ('int #t)
144     ('bool #t)
145     (else #f)))
146
147                                         ; returns a list of pairs of constraints
148 (define (unify a b)
149   (cond ((eq? a b) '())
150         ((or (tvar? a) (tvar? b)) (~ a b))
151         ((and (abs? a) (abs? b))
152          (consolidate (unify (cadr a) (cadr b))
153                       (unify (caddr a) (caddr b))))
154         (else (error #f "could not unify"))))
155
156                                         ; TODO: what's the most appropriate substitution?
157                                         ; should all constraints just be limited to a pair?
158 (define (substitute cs t)
159                                         ; gets the first concrete type
160                                         ; otherwise returns the last type variable
161
162   (define (get-concrete c)
163     (let ((last (null? (cdr c))))
164       (if (not (tvar? (car c)))
165           (if (abs? (car c))
166               (substitute cs (car c))
167               (car c))
168           (if last
169               (car c)
170               (get-concrete (cdr c))))))
171   (cond
172    ((abs? t) (list 'abs
173                    (substitute cs (cadr t))
174                    (substitute cs (caddr t))))
175    (else
176     (fold-left
177      (lambda (t c)
178        (if (member t c)
179            (get-concrete c)
180            t))
181      t cs))))
182
183 (define (substitute-env cs env)
184   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
185
186 (define (~ a b)
187   (list (list a b)))
188
189 (define (consolidate x y)
190   (define (merge a b)
191     (cond ((null? a) b)
192           ((null? b) a)
193           (else (if (member (car b) a)
194                     (merge a (cdr b))
195                     (cons (car b) (merge a (cdr b)))))))
196   (define (overlap? a b)
197     (if (or (null? a) (null? b))
198         #f
199         (if (fold-left (lambda (acc v)
200                          (or acc (eq? v (car a))))
201                        #f b)
202             #t
203             (overlap? (cdr a) b))))
204
205   (cond ((null? y) x)
206         ((null? x) y)
207         (else (let* ((a (car y))
208                      (merged (fold-left
209                               (lambda (acc b)
210                                 (if acc
211                                     acc
212                                     (if (overlap? a b)
213                                         (cons (merge a b) b)
214                                         #f)))
215                               #f x))
216                      (removed (if merged
217                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
218                                   x)))
219                 (if merged
220                     (consolidate removed (cons (car merged) (cdr y)))
221                     (consolidate (cons a x) (cdr y)))))))