Add pretty printing for types
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     (else #f)))
14
15 (define (pretty-type t)
16   (cond ((abs? t)
17          (string-append
18           (if (abs? (cadr t))
19               (string-append "(" (pretty-type (cadr t)) ")")
20               (pretty-type (cadr t)))
21           " -> "
22           (pretty-type (caddr t))))
23         (else (symbol->string t))))
24
25                                         ; ('a, ('b, 'a))
26 (define (env-lookup env n)
27   (if (null? env) (error #f "empty env")                        ; it's a type equality
28       (if (eq? (caar env) n)
29           (cdar env)
30           (env-lookup (cdr env) n))))
31
32 (define (env-insert env n t)
33   (cons (cons n t) env))
34
35 (define abs-arg cadr)
36
37 (define cur-tvar 0)
38 (define (fresh-tvar)
39   (begin
40     (set! cur-tvar (+ cur-tvar 1))
41     (string->symbol
42      (string-append "t" (number->string (- cur-tvar 1))))))
43
44 (define (last xs)
45   (if (null? (cdr xs))
46       (car xs)
47       (last (cdr xs))))
48
49                                         
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (cond
52    ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
53    ((lambda? prog)
54     (if (> (length (lambda-args prog)) 1)
55         (list 'lambda (list (car (lambda-args prog)))
56               (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57         (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58    ((app? prog)
59     (if (null? (cddr prog))
60         (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
61         (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
62    ((let? prog)
63     (append (list 'let
64                   (map (lambda (x) (cons (car x) (normalize (cdr x))))
65                        (let-bindings prog)))
66             (map normalize (let-body prog))))
67    (else prog)))
68
69 (define (builtin-type x)
70   (case x
71     ('+ '(abs int (abs int int)))
72     ('- '(abs int (abs int int)))
73     ('* '(abs int (abs int int)))
74     ('! '(abs bool bool))
75     ('bool->int '(abs bool int))
76     (else #f)))
77
78 ; we typecheck the lambda calculus only (only single arg lambdas)
79 (define (typecheck prog)
80   (define (check env x)
81     ;; (display "check: ")
82     ;; (display x)
83     ;; (display "\n\t")
84     ;; (display env)
85     ;; (newline)
86     (let
87         ((res
88           (cond
89            ((integer? x) (list '() 'int))
90            ((boolean? x) (list '() 'bool))
91            ((builtin-type x) (list '() (builtin-type x)))
92            ((symbol? x)  (list '() (env-lookup env x)))
93            ((let? x)
94             (let ((new-env (fold-left
95                             (lambda (acc bind)
96                               (let ((t (check
97                                         (env-insert acc (car bind) (fresh-tvar))
98                                         (cadr bind))))
99                                 (env-insert acc (car bind) (cadr t))))
100                             env (let-bindings x))))
101               (check new-env (last (let-body x)))))
102                   
103
104            ((lambda? x)
105             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
106                    (body-type-res (check new-env (lambda-body x)))
107                    (cs (car body-type-res))
108                    (subd-env (substitute-env (car body-type-res) new-env))
109                    (arg-type (env-lookup subd-env (lambda-arg x)))
110                    (resolved-arg-type (substitute cs arg-type)))
111               ;; (display "lambda:\n\t")
112               ;; (display prog)
113               ;; (display "\n\t")
114               ;; (display cs)
115               ;; (display "\n\t")
116               ;; (display resolved-arg-type)
117               ;; (newline)
118               (list (car body-type-res)
119                     (list 'abs
120                           resolved-arg-type
121                           (cadr body-type-res)))))
122            
123            ((app? x) ; (f a)
124             (let* ((arg-type-res (check env (cadr x)))
125                    (arg-type (cadr arg-type-res))
126                    (func-type-res (check env (car x)))
127                    (func-type (cadr func-type-res))
128                    
129                                         ; f ~ a -> t0
130                    (func-c (unify func-type
131                                   (list 'abs
132                                         arg-type
133                                         (fresh-tvar))))
134                    (cs (consolidate
135                         (consolidate func-c (car arg-type-res))
136                         (car func-type-res)))
137                    
138                    (resolved-func-type (substitute cs func-type))
139                    (resolved-return-type (caddr resolved-func-type)))
140               ;; (display "app:\n")
141               ;; (display cs)
142               ;; (display "\n")
143               ;; (display func-type)
144               ;; (display "\n")
145               ;; (display resolved-func-type)
146               ;; (display "\n")
147               ;; (display arg-type-res)
148               ;; (display "\n")
149               (if (abs? resolved-func-type)
150                   (let ((return-type (substitute cs (caddr resolved-func-type))))
151                     (list cs return-type))
152                   (error #f "not a function")))))))
153       ;; (display "result of ")
154       ;; (display x)
155       ;; (display ":\n\t")
156       ;; (display (cadr res))
157       ;; (display "[")
158       ;; (display (car res))
159       ;; (display "]\n")
160       res))
161   (cadr (check '() (normalize prog))))
162
163                                         ; returns a list of pairs of constraints
164 (define (unify a b)
165   (cond ((eq? a b) '())
166         ((or (tvar? a) (tvar? b)) (~ a b))
167         ((and (abs? a) (abs? b))
168          (consolidate (unify (cadr a) (cadr b))
169                       (unify (caddr a) (caddr b))))
170         (else (error #f "could not unify"))))
171
172                                         ; TODO: what's the most appropriate substitution?
173                                         ; should all constraints just be limited to a pair?
174 (define (substitute cs t)
175                                         ; gets the first concrete type
176                                         ; otherwise returns the last type variable
177
178   (define (get-concrete c)
179     (let ((last (null? (cdr c))))
180       (if (not (tvar? (car c)))
181           (if (abs? (car c))
182               (substitute cs (car c))
183               (car c))
184           (if last
185               (car c)
186               (get-concrete (cdr c))))))
187   (cond
188    ((abs? t) (list 'abs
189                    (substitute cs (cadr t))
190                    (substitute cs (caddr t))))
191    (else
192     (fold-left
193      (lambda (t c)
194        (if (member t c)
195            (get-concrete c)
196            t))
197      t cs))))
198
199 (define (substitute-env cs env)
200   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
201
202 (define (~ a b)
203   (list (list a b)))
204
205 (define (consolidate x y)
206   (define (merge a b)
207     (cond ((null? a) b)
208           ((null? b) a)
209           (else (if (member (car b) a)
210                     (merge a (cdr b))
211                     (cons (car b) (merge a (cdr b)))))))
212   (define (overlap? a b)
213     (if (or (null? a) (null? b))
214         #f
215         (if (fold-left (lambda (acc v)
216                          (or acc (eq? v (car a))))
217                        #f b)
218             #t
219             (overlap? (cdr a) b))))
220
221   (cond ((null? y) x)
222         ((null? x) y)
223         (else (let* ((a (car y))
224                      (merged (fold-left
225                               (lambda (acc b)
226                                 (if acc
227                                     acc
228                                     (if (overlap? a b)
229                                         (cons (merge a b) b)
230                                         #f)))
231                               #f x))
232                      (removed (if merged
233                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
234                                   x)))
235                 (if merged
236                     (consolidate removed (cons (car merged) (cdr y)))
237                     (consolidate (cons a x) (cdr y)))))))