Add strings and print primitive
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     ('void #t)
14     (else #f)))
15
16 (define (pretty-type t)
17   (cond ((abs? t)
18          (string-append
19           (if (abs? (cadr t))
20               (string-append "(" (pretty-type (cadr t)) ")")
21               (pretty-type (cadr t)))
22           " -> "
23           (pretty-type (caddr t))))
24         (else (symbol->string t))))
25
26                                         ; ('a, ('b, 'a))
27 (define (env-lookup env n)
28   (if (null? env) (error #f "empty env")                        ; it's a type equality
29       (if (eq? (caar env) n)
30           (cdar env)
31           (env-lookup (cdr env) n))))
32
33 (define (env-insert env n t)
34   (cons (cons n t) env))
35
36 (define abs-arg cadr)
37
38 (define cur-tvar 0)
39 (define (fresh-tvar)
40   (begin
41     (set! cur-tvar (+ cur-tvar 1))
42     (string->symbol
43      (string-append "t" (number->string (- cur-tvar 1))))))
44
45 (define (last xs)
46   (if (null? (cdr xs))
47       (car xs)
48       (last (cdr xs))))
49                                 
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          `(,(list (normalize (car prog)) (normalize (cadr prog)))
62            ,(normalize (caddr prog))))) ; (f a b)
63     ('let
64         (append (list 'let
65                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
66                            (let-bindings prog)))
67                 (map normalize (let-body prog))))
68     (else (ast-traverse normalize prog))))
69
70 (define (builtin-type x)
71   (case x
72     ('+ '(abs int (abs int int)))
73     ('- '(abs int (abs int int)))
74     ('* '(abs int (abs int int)))
75     ('! '(abs bool bool))
76     ('= '(abs int (abs int bool)))
77     ('bool->int '(abs bool int))
78     ('print '(abs string void))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     ;; (display "check: ")
85     ;; (display x)
86     ;; (display "\n\t")
87     ;; (display env)
88     ;; (newline)
89     (let
90         ((res
91           (case (ast-type x)
92            ('int-literal (list '() 'int))
93            ('bool-literal (list '() 'bool))
94            ('string-literal (list '() 'string))
95            ('builtin (list '() (builtin-type x)))
96
97            ('if
98             (let* ((cond-type-res (check env (cadr x)))
99                    (then-type-res (check env (caddr x)))
100                    (else-type-res (check env (cadddr x)))
101                    (then-eq-else-cs (unify (cadr then-type-res)
102                                            (cadr else-type-res)))
103                    (cs (consolidate
104                         (car then-type-res)
105                         (consolidate (car else-type-res)
106                                      then-eq-else-cs)))
107                    (return-type (substitute cs (cadr then-type-res))))
108               (when (not (eqv? (cadr cond-type-res) 'bool))
109                 (error #f "if condition isn't bool"))
110               (list cs return-type)))
111            
112            ('var  (list '() (env-lookup env x)))
113            ('let
114             (let ((new-env (fold-left
115                             (lambda (acc bind)
116                               (let ((t (check
117                                         (env-insert acc (car bind) (fresh-tvar))
118                                         (cadr bind))))
119                                 (env-insert acc (car bind) (cadr t))))
120                             env (let-bindings x))))
121               (check new-env (last (let-body x)))))
122                   
123
124            ('lambda
125             (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
126                    (body-type-res (check new-env (lambda-body x)))
127                    (cs (car body-type-res))
128                    (subd-env (substitute-env (car body-type-res) new-env))
129                    (arg-type (env-lookup subd-env (lambda-arg x)))
130                    (resolved-arg-type (substitute cs arg-type)))
131               ;; (display "lambda:\n\t")
132               ;; (display prog)
133               ;; (display "\n\t")
134               ;; (display cs)
135               ;; (display "\n\t")
136               ;; (display resolved-arg-type)
137               ;; (newline)
138               (list (car body-type-res)
139                     (list 'abs
140                           resolved-arg-type
141                           (cadr body-type-res)))))
142            
143            ('app ; (f a)
144             (let* ((arg-type-res (check env (cadr x)))
145                    (arg-type (cadr arg-type-res))
146                    (func-type-res (check env (car x)))
147                    (func-type (cadr func-type-res))
148                    
149                                         ; f ~ a -> t0
150                    (func-c (unify func-type
151                                   (list 'abs
152                                         arg-type
153                                         (fresh-tvar))))
154                    (cs (consolidate
155                         (consolidate func-c (car arg-type-res))
156                         (car func-type-res)))
157                    
158                    (resolved-func-type (substitute cs func-type))
159                    (resolved-return-type (caddr resolved-func-type)))
160               ;; (display "app:\n")
161               ;; (display cs)
162               ;; (display "\n")
163               ;; (display func-type)
164               ;; (display "\n")
165               ;; (display resolved-func-type)
166               ;; (display "\n")
167               ;; (display arg-type-res)
168               ;; (display "\n")
169               (if (abs? resolved-func-type)
170                   (let ((return-type (substitute cs (caddr resolved-func-type))))
171                     (list cs return-type))
172                   (error #f "not a function")))))))
173       ;; (display "result of ")
174       ;; (display x)
175       ;; (display ":\n\t")
176       ;; (display (cadr res))
177       ;; (display "[")
178       ;; (display (car res))
179       ;; (display "]\n")
180       res))
181   (cadr (check '() (normalize prog))))
182
183                                         ; returns a list of pairs of constraints
184 (define (unify a b)
185   (cond ((eq? a b) '())
186         ((or (tvar? a) (tvar? b)) (~ a b))
187         ((and (abs? a) (abs? b))
188          (consolidate (unify (cadr a) (cadr b))
189                       (unify (caddr a) (caddr b))))
190         (else (error #f "could not unify"))))
191
192                                         ; TODO: what's the most appropriate substitution?
193                                         ; should all constraints just be limited to a pair?
194 (define (substitute cs t)
195                                         ; gets the first concrete type
196                                         ; otherwise returns the last type variable
197
198   (define (get-concrete c)
199     (let ((last (null? (cdr c))))
200       (if (not (tvar? (car c)))
201           (if (abs? (car c))
202               (substitute cs (car c))
203               (car c))
204           (if last
205               (car c)
206               (get-concrete (cdr c))))))
207   (cond
208    ((abs? t) (list 'abs
209                    (substitute cs (cadr t))
210                    (substitute cs (caddr t))))
211    (else
212     (fold-left
213      (lambda (t c)
214        (if (member t c)
215            (get-concrete c)
216            t))
217      t cs))))
218
219 (define (substitute-env cs env)
220   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
221
222 (define (~ a b)
223   (list (list a b)))
224
225 (define (consolidate x y)
226   (define (merge a b)
227     (cond ((null? a) b)
228           ((null? b) a)
229           (else (if (member (car b) a)
230                     (merge a (cdr b))
231                     (cons (car b) (merge a (cdr b)))))))
232   (define (overlap? a b)
233     (if (or (null? a) (null? b))
234         #f
235         (if (fold-left (lambda (acc v)
236                          (or acc (eq? v (car a))))
237                        #f b)
238             #t
239             (overlap? (cdr a) b))))
240
241   (cond ((null? y) x)
242         ((null? x) y)
243         (else (let* ((a (car y))
244                      (merged (fold-left
245                               (lambda (acc b)
246                                 (if acc
247                                     acc
248                                     (if (overlap? a b)
249                                         (cons (merge a b) b)
250                                         #f)))
251                               #f x))
252                      (removed (if merged
253                                   (filter (lambda (b) (not (eq? b (cdr merged)))) x)
254                                   x)))
255                 (if merged
256                     (consolidate removed (cons (car merged) (cdr y)))
257                     (consolidate (cons a x) (cdr y)))))))