Break up lets into SCCs before typechecking
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     ('void #t)
14     (else #f)))
15
16 (define (pretty-type t)
17   (cond ((abs? t)
18          (string-append
19           (if (abs? (cadr t))
20               (string-append "(" (pretty-type (cadr t)) ")")
21               (pretty-type (cadr t)))
22           " -> "
23           (pretty-type (caddr t))))
24         (else (symbol->string t))))
25
26                                         ; ('a, ('b, 'a))
27 (define (env-lookup env n)
28   (if (null? env) (error #f "empty env")                        ; it's a type equality
29       (if (eq? (caar env) n)
30           (cdar env)
31           (env-lookup (cdr env) n))))
32
33 (define (env-insert env n t)
34   (cons (cons n t) env))
35
36 (define abs-arg cadr)
37
38 (define cur-tvar 0)
39 (define (fresh-tvar)
40   (begin
41     (set! cur-tvar (+ cur-tvar 1))
42     (string->symbol
43      (string-append "t" (number->string (- cur-tvar 1))))))
44
45 (define (last xs)
46   (if (null? (cdr xs))
47       (car xs)
48       (last (cdr xs))))
49                                 
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
62                       ,@(cddr prog))))) ; (f a b)
63     ('let
64         (append (list 'let
65                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
66                            (let-bindings prog)))
67                 (map normalize (let-body prog))))
68     (else (ast-traverse normalize prog))))
69
70 (define (builtin-type x)
71   (case x
72     ('+ '(abs int (abs int int)))
73     ('- '(abs int (abs int int)))
74     ('* '(abs int (abs int int)))
75     ('! '(abs bool bool))
76     ('= '(abs int (abs int bool)))
77     ('bool->int '(abs bool int))
78     ('print '(abs string void))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     (display "check: ")
85     (display x)
86     (display "\n\t")
87     (display env)
88     (newline)
89     (let
90         ((res
91           (case (ast-type x)
92             ('int-literal (list '() 'int))
93             ('bool-literal (list '() 'bool))
94             ('string-literal (list '() 'string))
95             ('builtin (list '() (builtin-type x)))
96
97             ('if
98              (let* ((cond-type-res (check env (cadr x)))
99                     (then-type-res (check env (caddr x)))
100                     (else-type-res (check env (cadddr x)))
101                     (then-eq-else-cs (unify (cadr then-type-res)
102                                             (cadr else-type-res)))
103                     (cs (consolidate
104                          (car then-type-res)
105                          (consolidate (car else-type-res)
106                                       then-eq-else-cs)))
107                     (return-type (substitute cs (cadr then-type-res))))
108                (when (not (eqv? (cadr cond-type-res) 'bool))
109                  (error #f "if condition isn't bool"))
110                (list cs return-type)))
111             
112             ('var (list '() (env-lookup env x)))
113             ('let
114                                         ; takes in the current environment and a scc
115                                         ; returns new environment with scc's types added in
116               (let* ([components (reverse (sccs (graph (let-bindings x))))]
117                      [process-component
118                       (lambda (acc comps)
119                         (display comps)
120                         (newline)
121                         (let*
122                             ([scc-env
123                               (fold-left
124                                (lambda (acc c)
125                                  (env-insert acc c (fresh-tvar)))
126                                acc comps)]
127                              [type-results
128                               (map
129                                (lambda (c)
130                                  (begin (display scc-env) (newline)
131                                  (let ([body (cadr (assoc c (let-bindings x)))])
132                                    (display body)(newline)(check scc-env body))))
133                                comps)]
134                              [cs
135                               (fold-left
136                                (lambda (acc res c)
137                                  (consolidate
138                                   acc
139                                   (unify (cadr res) (env-lookup scc-env c))))
140                                '() type-results comps)])
141                           (display "process-component env:\n")
142                           (display (substitute-env cs scc-env))
143                           (newline)
144                           (substitute-env cs scc-env)))]
145                      [new-env (fold-left process-component env components)])
146                 (check new-env (last (let-body x)))))
147             
148             ;; (let ((new-env (fold-left
149             ;;          (lambda (acc bind)
150             ;;            (let* [(bind-tvar (fresh-tvar))
151             ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
152             ;;                   (bind-res (check env-with-tvar (cadr bind)))
153             ;;                   (bind-type (cadr bind-res))
154             ;;                   (cs (consolidate (car bind-res)
155             ;;                                    (unify bind-type bind-tvar)))]
156             ;;              (substitute-env cs env-with-tvar)))
157             ;;          env (let-bindings x))))
158             ;;   (display "sccs of graph\n")
159             ;;   (display (sccs (graph (let-bindings x))))
160             ;;   (newline)
161             ;;   (display "env when checking body:\n\t")
162             ;;   (display new-env)
163             ;;   (newline)
164             ;;   (check new-env (last (let-body x)))))
165             
166
167             ('lambda
168                 (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
169
170                        (body-type-res (check new-env (lambda-body x)))
171                        (cs (car body-type-res))
172                        (subd-env (substitute-env (car body-type-res) new-env))
173                        (arg-type (env-lookup subd-env (lambda-arg x)))
174                        (resolved-arg-type (substitute cs arg-type))]
175                   ;; (display "lambda:\n\t")
176                   ;; (display prog)
177                   ;; (display "\n\t")
178                   ;; (display cs)
179                   ;; (display "\n\t")
180                   ;; (display resolved-arg-type)
181                   ;; (newline)
182                   (list (car body-type-res)
183                         (list 'abs
184                               resolved-arg-type
185                               (cadr body-type-res)))))
186             
187             ('app ; (f a)
188              (if (eqv? (car x) (cadr x))
189                                         ; recursive function (f f)
190                  (let* [(func-type (env-lookup env (car x)))
191                         (return-type (fresh-tvar))
192                         (other-func-type `(abs ,func-type ,return-type))
193                         (cs (unify func-type other-func-type))]
194                    (list cs return-type))
195
196                                         ; regular function
197                  (let* ((arg-type-res (check env (cadr x)))
198                         (arg-type (cadr arg-type-res))
199                         (func-type-res (check env (car x)))
200                         (func-type (cadr func-type-res))
201                         
202                                         ; f ~ a -> t0
203                         (func-c (unify func-type
204                                        (list 'abs
205                                              arg-type
206                                              (fresh-tvar))))
207                         (cs (consolidate
208                              (consolidate func-c (car arg-type-res))
209                              (car func-type-res)))
210                         
211                         (resolved-func-type (substitute cs func-type))
212                         (resolved-return-type (caddr resolved-func-type)))
213                    ;; (display "app:\n")
214                    ;; (display cs)
215                    ;; (display "\n")
216                    ;; (display func-type)
217                    ;; (display "\n")
218                    ;; (display resolved-func-type)
219                    ;; (display "\n")
220                    ;; (display arg-type-res)
221                    ;; (display "\n")
222                    (if (abs? resolved-func-type)
223                        (let ((return-type (substitute cs (caddr resolved-func-type))))
224                          (list cs return-type))
225                        (error #f "not a function"))))))))
226       (display "result of ")
227       (display x)
228       (display ":\n\t")
229       (display (pretty-type (cadr res)))
230       (display "\n\t[")
231       (display (car res))
232       (display "]\n")
233       res))
234   (cadr (check '() (normalize prog))))
235
236                                         ; returns a list of pairs of constraints
237 (define (unify a b)
238   (cond ((eq? a b) '())
239         ((or (tvar? a) (tvar? b)) (~ a b))
240         ((and (abs? a) (abs? b))
241          (let* [(arg-cs (unify (cadr a) (cadr b)))
242                 (body-cs (unify (substitute arg-cs (caddr a))
243                                 (substitute arg-cs (caddr b))))]
244            (consolidate arg-cs body-cs)))
245         (else (error #f "could not unify"))))
246
247                                         ; TODO: what's the most appropriate substitution?
248                                         ; should all constraints just be limited to a pair?
249 (define (substitute cs t)
250                                         ; gets the first concrete type
251                                         ; otherwise returns the last type variable
252
253   (define cs-without-t
254     (map (lambda (c)
255            (filter (lambda (x) (not (eqv? t x))) c))
256          cs))
257
258   (define (get-concrete c)
259     (let [(last (null? (cdr c)))]
260       (if (not (tvar? (car c)))
261           (if (abs? (car c))
262               (substitute cs-without-t (car c))
263               (car c))
264           (if last
265               (car c)
266               (get-concrete (cdr c))))))
267   
268   (cond
269    ((abs? t) (list 'abs
270                    (substitute cs (cadr t))
271                    (substitute cs (caddr t))))
272    (else
273     (fold-left
274      (lambda (t c)
275        (if (member t c)
276            (get-concrete c)
277            t))
278      t cs))))
279
280 (define (substitute-env cs env)
281   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
282
283 (define (~ a b)
284   (list (list a b)))
285
286 (define (consolidate x y)
287   (define (merge a b)
288     (cond ((null? a) b)
289           ((null? b) a)
290           (else (if (member (car b) a)
291                     (merge a (cdr b))
292                     (cons (car b) (merge a (cdr b)))))))
293   (define (overlap? a b)
294     (if (or (null? a) (null? b))
295         #f
296         (if (fold-left (lambda (acc v)
297                          (or acc (eq? v (car a))))
298                        #f b)
299             #t
300             (overlap? (cdr a) b))))
301
302   (cond ((null? y) x)
303         ((null? x) y)
304         (else
305          (let* ((a (car y))
306                 (merged (fold-left
307                          (lambda (acc b)
308                            (if acc
309                                acc
310                                (if (overlap? a b)
311                                    (cons (merge a b) b)
312                                    #f)))
313                          #f x))
314                 (removed (if merged
315                              (filter (lambda (b) (not (eq? b (cdr merged)))) x)
316                              x)))
317            (if merged
318                (consolidate removed (cons (car merged) (cdr y)))
319                (consolidate (cons a x) (cdr y)))))))
320
321                                         ; a1 -> a2 ~ a3 -> a4;
322                                         ; a1 -> a2 !~ bool -> bool
323                                         ; basically can the tvars be renamed
324 (define (types-equal? x y)
325   (error #f "todo"))
326
327                                         ; input: a list of binds ((x . y) (y . 3))
328                                         ; returns: pair of verts, edges ((x y) . (x . y))
329 (define (graph bs)
330   (define (find-refs prog)
331     (ast-collect
332      (lambda (x)
333        (case (ast-type x)
334                                         ; only count a reference if its a binding
335          ['var (if (assoc x bs) (list x) '())]
336          [else '()]))
337      prog))
338   (let* [(bind (car bs))
339
340          (vert (car bind))
341          (refs (find-refs (cdr bind)))
342          (edges (map (lambda (x) (cons vert x))
343                      refs))
344
345          (rest (if (null? (cdr bs))
346                    (cons '() '())
347                    (graph (cdr bs))))
348          (total-verts (cons vert (car rest)))
349          (total-edges (append edges (cdr rest)))]
350     (cons total-verts total-edges)))
351
352 (define (successors graph v)
353   (define (go v E)
354     (if (null? E)
355         '()
356         (if (eqv? v (caar E))
357             (cons (cdar E) (go v (cdr E)))
358             (go v (cdr E)))))
359   (go v (cdr graph)))
360
361                                         ; takes in a graph (pair of vertices, edges)
362                                         ; returns a list of strongly connected components
363
364                                         ; ((x y w) . ((x . y) (x . w) (w . x))
365
366                                         ; =>
367                                         ; .->x->y
368                                         ; |  |
369                                         ; |  v
370                                         ; .--w
371
372                                         ; ((x w) (y))
373
374                                         ; this uses tarjan's algorithm, to get reverse
375                                         ; topological sorting for free
376 (define (sccs graph)
377   
378   (let* ([indices (make-hash-table)]
379          [lowlinks (make-hash-table)]
380          [on-stack (make-hash-table)]
381          [current 0]
382          [stack '()]
383          [result '()])
384
385     (define (index v)
386       (get-hash-table indices v #f))
387     (define (lowlink v)
388       (get-hash-table lowlinks v #f))
389
390     (letrec
391         ([strong-connect
392           (lambda (v)
393             (begin
394               (put-hash-table! indices v current)
395               (put-hash-table! lowlinks v current)
396               (set! current (+ current 1))
397               (push! stack v)
398               (put-hash-table! on-stack v #t)
399
400               (for-each
401                (lambda (w)
402                  (if (not (hashtable-contains? indices w))
403                                         ; successor w has not been visited, recurse
404                      (begin
405                        (strong-connect w)
406                        (put-hash-table! lowlinks
407                                         v
408                                         (min (lowlink v) (lowlink w))))
409                                         ; successor w has been visited
410                      (when (get-hash-table on-stack w #f)
411                        (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
412                (successors graph v))
413
414               (when (= (index v) (lowlink v))
415                 (let ([scc
416                        (let new-scc ()
417                          (let ([w (pop! stack)])
418                            (put-hash-table! on-stack w #f)
419                            (if (eqv? w v)
420                                (list w)
421                                (cons w (new-scc)))))])
422                   (set! result (cons scc result))))))])
423       
424       (for-each
425        (lambda (v)
426          (when (not (hashtable-contains? indices v)) ; v.index == -1
427            (strong-connect v)))
428        (car graph)))
429     result))
430