Substitute only the variables in the scc
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     ('void #t)
14     (else #f)))
15
16 (define (pretty-type t)
17   (cond ((abs? t)
18          (string-append
19           (if (abs? (cadr t))
20               (string-append "(" (pretty-type (cadr t)) ")")
21               (pretty-type (cadr t)))
22           " -> "
23           (pretty-type (caddr t))))
24         (else (symbol->string t))))
25
26                                         ; ('a, ('b, 'a))
27 (define (env-lookup env n)
28   (if (null? env) (error #f "empty env")                        ; it's a type equality
29       (if (eq? (caar env) n)
30           (cdar env)
31           (env-lookup (cdr env) n))))
32
33 (define (env-insert env n t)
34   (cons (cons n t) env))
35
36 (define abs-arg cadr)
37
38 (define cur-tvar 0)
39 (define (fresh-tvar)
40   (begin
41     (set! cur-tvar (+ cur-tvar 1))
42     (string->symbol
43      (string-append "t" (number->string (- cur-tvar 1))))))
44
45 (define (last xs)
46   (if (null? (cdr xs))
47       (car xs)
48       (last (cdr xs))))
49                                 
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
62                       ,@(cddr prog))))) ; (f a b)
63     ('let
64         (append (list 'let
65                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
66                            (let-bindings prog)))
67                 (map normalize (let-body prog))))
68     (else (ast-traverse normalize prog))))
69
70 (define (builtin-type x)
71   (case x
72     ('+ '(abs int (abs int int)))
73     ('- '(abs int (abs int int)))
74     ('* '(abs int (abs int int)))
75     ('! '(abs bool bool))
76     ('= '(abs int (abs int bool)))
77     ('bool->int '(abs bool int))
78     ('print '(abs string void))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     ;; (display "check: ")
85     ;; (display x)
86     ;; (display "\n\t")
87     ;; (display env)
88     ;; (newline)
89     (let
90         ((res
91           (case (ast-type x)
92             ('int-literal (list '() 'int))
93             ('bool-literal (list '() 'bool))
94             ('string-literal (list '() 'string))
95             ('builtin (list '() (builtin-type x)))
96
97             ('if
98              (let* ((cond-type-res (check env (cadr x)))
99                     (then-type-res (check env (caddr x)))
100                     (else-type-res (check env (cadddr x)))
101                     (then-eq-else-cs (unify (cadr then-type-res)
102                                             (cadr else-type-res)))
103                     (cs (consolidate
104                          (car then-type-res)
105                          (consolidate (car else-type-res)
106                                       then-eq-else-cs)))
107                     (return-type (substitute cs (cadr then-type-res))))
108                (when (not (eqv? (cadr cond-type-res) 'bool))
109                  (error #f "if condition isn't bool"))
110                (list cs return-type)))
111             
112             ('var (list '() (env-lookup env x)))
113             ('let
114                                         ; takes in the current environment and a scc
115                                         ; returns new environment with scc's types added in
116               (let* ([components (reverse (sccs (graph (let-bindings x))))]
117                      [process-component
118                       (lambda (acc comps)
119                         (let*
120                                         ; create a new env with tvars for each component
121                                         ; e.g. scc of (x y)
122                                         ; scc-env = ((x . t0) (y . t1))
123                             ([scc-env
124                               (fold-left
125                                (lambda (acc c)
126                                  (env-insert acc c (fresh-tvar)))
127                                acc comps)]
128                                         ; typecheck each component
129                              [type-results
130                               (map
131                                (lambda (c)
132                                  (let ([body (cadr (assoc c (let-bindings x)))])
133                                    (check scc-env body)))
134                                comps)]
135                                         ; collect all the constraints in the scc
136                              [cs
137                               (fold-left
138                                (lambda (acc res c)
139                                  (consolidate
140                                   acc
141                                   (consolidate (car res)
142                                         ; unify with tvars from scc-env
143                                         ; result ~ tvar
144                                                (unify (cadr res) (env-lookup scc-env c)))))
145                                '() type-results comps)]
146                                         ; substitute *only* the bindings in this scc
147                              [new-env
148                               (map (lambda (x)
149                                      (if (memv (car x) comps)
150                                          (cons (car x) (substitute cs (cdr x)))
151                                          x))
152                                    scc-env)])
153                           new-env))]
154                      [new-env (fold-left process-component env components)])
155                 (check new-env (last (let-body x)))))
156             
157             ('lambda
158                 (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
159
160                        (body-type-res (check new-env (lambda-body x)))
161                        (cs (car body-type-res))
162                        (subd-env (substitute-env (car body-type-res) new-env))
163                        (arg-type (env-lookup subd-env (lambda-arg x)))
164                        (resolved-arg-type (substitute cs arg-type))]
165                   ;; (display "lambda:\n\t")
166                   ;; (display prog)
167                   ;; (display "\n\t")
168                   ;; (display cs)
169                   ;; (display "\n\t")
170                   ;; (display resolved-arg-type)
171                   ;; (newline)
172                   (list (car body-type-res)
173                         (list 'abs
174                               resolved-arg-type
175                               (cadr body-type-res)))))
176             
177             ('app ; (f a)
178              (if (eqv? (car x) (cadr x))
179                                         ; recursive function (f f)
180                  (let* [(func-type (env-lookup env (car x)))
181                         (return-type (fresh-tvar))
182                         (other-func-type `(abs ,func-type ,return-type))
183                         (cs (unify func-type other-func-type))]
184                    (list cs return-type))
185
186                                         ; regular function
187                  (let* ((arg-type-res (check env (cadr x)))
188                         (arg-type (cadr arg-type-res))
189                         (func-type-res (check env (car x)))
190                         (func-type (cadr func-type-res))
191                         
192                                         ; f ~ a -> t0
193                         (func-c (unify func-type
194                                        (list 'abs
195                                              arg-type
196                                              (fresh-tvar))))
197                         (cs (consolidate
198                              (consolidate func-c (car arg-type-res))
199                              (car func-type-res)))
200                         
201                         (resolved-func-type (substitute cs func-type))
202                         (resolved-return-type (caddr resolved-func-type)))
203                    ;; (display "app:\n")
204                    ;; (display cs)
205                    ;; (display "\n")
206                    ;; (display func-type)
207                    ;; (display "\n")
208                    ;; (display resolved-func-type)
209                    ;; (display "\n")
210                    ;; (display arg-type-res)
211                    ;; (display "\n")
212                    (if (abs? resolved-func-type)
213                        (let ((return-type (substitute cs (caddr resolved-func-type))))
214                          (list cs return-type))
215                        (error #f "not a function"))))))))
216       ;; (display "result of ")
217       ;; (display x)
218       ;; (display ":\n\t")
219       ;; (display (pretty-type (cadr res)))
220       ;; (display "\n\t[")
221       ;; (display (car res))
222       ;; (display "]\n")
223       res))
224   (cadr (check '() (normalize prog))))
225
226                                         ; returns a list of pairs of constraints
227 (define (unify a b)
228   (let ([res (unify? a b)])
229     (if res
230         res
231         (error #f
232                (format "couldn't unify ~a ~~ ~a" a b)))))
233
234 (define (unify? a b)
235   (cond ((eq? a b) '())
236         ((or (tvar? a) (tvar? b)) (~ a b))
237         ((and (abs? a) (abs? b))
238          (let* [(arg-cs (unify (cadr a) (cadr b)))
239                 (body-cs (unify (substitute arg-cs (caddr a))
240                                 (substitute arg-cs (caddr b))))]
241            (consolidate arg-cs body-cs)))
242         (else #f)))
243
244                                         ; TODO: what's the most appropriate substitution?
245                                         ; should all constraints just be limited to a pair?
246 (define (substitute cs t)
247                                         ; gets the first concrete type
248                                         ; otherwise returns the last type variable
249
250   (define cs-without-t
251     (map (lambda (c)
252            (filter (lambda (x) (not (eqv? t x))) c))
253          cs))
254
255   (define (get-concrete c)
256     (let [(last (null? (cdr c)))]
257       (if (not (tvar? (car c)))
258           (if (abs? (car c))
259               (substitute cs-without-t (car c))
260               (car c))
261           (if last
262               (car c)
263               (get-concrete (cdr c))))))
264   
265   (cond
266    ((abs? t) (list 'abs
267                    (substitute cs (cadr t))
268                    (substitute cs (caddr t))))
269    (else
270     (fold-left
271      (lambda (t c)
272        (if (member t c)
273            (get-concrete c)
274            t))
275      t cs))))
276
277 (define (substitute-env cs env)
278   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
279
280 (define (~ a b)
281   (list (list a b)))
282
283 (define (consolidate x y)
284   (define (merge a b)
285     (cond ((null? a) b)
286           ((null? b) a)
287           (else (if (member (car b) a)
288                     (merge a (cdr b))
289                     (cons (car b) (merge a (cdr b)))))))
290   (define (overlap? a b)
291     (if (or (null? a) (null? b))
292         #f
293         (if (fold-left (lambda (acc v)
294                          (or acc (eq? v (car a))))
295                        #f b)
296             #t
297             (overlap? (cdr a) b))))
298
299   (cond ((null? y) x)
300         ((null? x) y)
301         (else
302          (let* ((a (car y))
303                 (merged (fold-left
304                          (lambda (acc b)
305                            (if acc
306                                acc
307                                (if (overlap? a b)
308                                    (cons (merge a b) b)
309                                    #f)))
310                          #f x))
311                 (removed (if merged
312                              (filter (lambda (b) (not (eq? b (cdr merged)))) x)
313                              x)))
314            (if merged
315                (consolidate removed (cons (car merged) (cdr y)))
316                (consolidate (cons a x) (cdr y)))))))
317
318                                         ; a1 -> a2 ~ a3 -> a4;
319                                         ; a1 -> a2 !~ bool -> bool
320                                         ; basically can the tvars be renamed
321 (define (types-equal? x y)
322   (let ([cs (unify? x y)])
323     (if (not cs) #f
324         (let*
325             ([test-kind
326               (lambda (acc c)
327                 (if (tvar? c) acc #f))]
328              [test (lambda (acc c)
329                      (and acc (fold-left test-kind #t c)))])
330           (fold-left test #t cs)))))
331
332                                         ; input: a list of binds ((x . y) (y . 3))
333                                         ; returns: pair of verts, edges ((x y) . (x . y))
334 (define (graph bs)
335   (define (find-refs prog)
336     (ast-collect
337      (lambda (x)
338        (case (ast-type x)
339                                         ; only count a reference if its a binding
340          ['var (if (assoc x bs) (list x) '())]
341          [else '()]))
342      prog))
343   (if (null? bs)
344       '(() . ())
345       (let* [(bind (car bs))
346
347              (vert (car bind))
348              (refs (find-refs (cdr bind)))
349              (edges (map (lambda (x) (cons vert x))
350                          refs))
351
352              (rest (if (null? (cdr bs))
353                        (cons '() '())
354                        (graph (cdr bs))))
355              (total-verts (cons vert (car rest)))
356              (total-edges (append edges (cdr rest)))]
357         (cons total-verts total-edges))))
358
359 (define (successors graph v)
360   (define (go v E)
361     (if (null? E)
362         '()
363         (if (eqv? v (caar E))
364             (cons (cdar E) (go v (cdr E)))
365             (go v (cdr E)))))
366   (go v (cdr graph)))
367
368                                         ; takes in a graph (pair of vertices, edges)
369                                         ; returns a list of strongly connected components
370
371                                         ; ((x y w) . ((x . y) (x . w) (w . x))
372
373                                         ; =>
374                                         ; .->x->y
375                                         ; |  |
376                                         ; |  v
377                                         ; .--w
378
379                                         ; ((x w) (y))
380
381                                         ; this uses tarjan's algorithm, to get reverse
382                                         ; topological sorting for free
383 (define (sccs graph)
384   
385   (let* ([indices (make-hash-table)]
386          [lowlinks (make-hash-table)]
387          [on-stack (make-hash-table)]
388          [current 0]
389          [stack '()]
390          [result '()])
391
392     (define (index v)
393       (get-hash-table indices v #f))
394     (define (lowlink v)
395       (get-hash-table lowlinks v #f))
396
397     (letrec
398         ([strong-connect
399           (lambda (v)
400             (begin
401               (put-hash-table! indices v current)
402               (put-hash-table! lowlinks v current)
403               (set! current (+ current 1))
404               (push! stack v)
405               (put-hash-table! on-stack v #t)
406
407               (for-each
408                (lambda (w)
409                  (if (not (hashtable-contains? indices w))
410                                         ; successor w has not been visited, recurse
411                      (begin
412                        (strong-connect w)
413                        (put-hash-table! lowlinks
414                                         v
415                                         (min (lowlink v) (lowlink w))))
416                                         ; successor w has been visited
417                      (when (get-hash-table on-stack w #f)
418                        (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
419                (successors graph v))
420
421               (when (= (index v) (lowlink v))
422                 (let ([scc
423                        (let new-scc ()
424                          (let ([w (pop! stack)])
425                            (put-hash-table! on-stack w #f)
426                            (if (eqv? w v)
427                                (list w)
428                                (cons w (new-scc)))))])
429                   (set! result (cons scc result))))))])
430       (for-each
431        (lambda (v)
432          (when (not (hashtable-contains? indices v)) ; v.index == -1
433            (strong-connect v)))
434        (car graph)))
435     result))
436