Refactor unify
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     ('void #t)
14     (else #f)))
15
16 (define (pretty-type t)
17   (cond ((abs? t)
18          (string-append
19           (if (abs? (cadr t))
20               (string-append "(" (pretty-type (cadr t)) ")")
21               (pretty-type (cadr t)))
22           " -> "
23           (pretty-type (caddr t))))
24         (else (symbol->string t))))
25
26                                         ; ('a, ('b, 'a))
27 (define (env-lookup env n)
28   (if (null? env) (error #f "empty env")                        ; it's a type equality
29       (if (eq? (caar env) n)
30           (cdar env)
31           (env-lookup (cdr env) n))))
32
33 (define (env-insert env n t)
34   (cons (cons n t) env))
35
36 (define abs-arg cadr)
37
38 (define cur-tvar 0)
39 (define (fresh-tvar)
40   (begin
41     (set! cur-tvar (+ cur-tvar 1))
42     (string->symbol
43      (string-append "t" (number->string (- cur-tvar 1))))))
44
45 (define (last xs)
46   (if (null? (cdr xs))
47       (car xs)
48       (last (cdr xs))))
49                                 
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
62                       ,@(cddr prog))))) ; (f a b)
63     ('let
64         (append (list 'let
65                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
66                            (let-bindings prog)))
67                 (map normalize (let-body prog))))
68     (else (ast-traverse normalize prog))))
69
70 (define (builtin-type x)
71   (case x
72     ('+ '(abs int (abs int int)))
73     ('- '(abs int (abs int int)))
74     ('* '(abs int (abs int int)))
75     ('! '(abs bool bool))
76     ('= '(abs int (abs int bool)))
77     ('bool->int '(abs bool int))
78     ('print '(abs string void))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     ;; (display "check: ")
85     ;; (display x)
86     ;; (display "\n\t")
87     ;; (display env)
88     ;; (newline)
89     (let
90         ((res
91           (case (ast-type x)
92             ('int-literal (list '() 'int))
93             ('bool-literal (list '() 'bool))
94             ('string-literal (list '() 'string))
95             ('builtin (list '() (builtin-type x)))
96
97             ('if
98              (let* ((cond-type-res (check env (cadr x)))
99                     (then-type-res (check env (caddr x)))
100                     (else-type-res (check env (cadddr x)))
101                     (then-eq-else-cs (~ (cadr then-type-res)
102                                         (cadr else-type-res)))
103                     (cs (consolidate
104                          (car then-type-res)
105                          (consolidate (car else-type-res)
106                                       then-eq-else-cs)))
107                     (return-type (substitute cs (cadr then-type-res))))
108                (when (not (eqv? (cadr cond-type-res) 'bool))
109                  (error #f "if condition isn't bool"))
110                (list cs return-type)))
111             
112             ('var (list '() (env-lookup env x)))
113             ('let
114                                         ; takes in the current environment and a scc
115                                         ; returns new environment with scc's types added in
116               (let* ([components (reverse (sccs (graph (let-bindings x))))]
117                      [process-component
118                       (lambda (acc comps)
119                         (let*
120                                         ; create a new env with tvars for each component
121                                         ; e.g. scc of (x y)
122                                         ; scc-env = ((x . t0) (y . t1))
123                             ([scc-env
124                               (fold-left
125                                (lambda (acc c)
126                                  (env-insert acc c (fresh-tvar)))
127                                acc comps)]
128                                         ; typecheck each component
129                              [type-results
130                               (map
131                                (lambda (c)
132                                  (let ([body (cadr (assoc c (let-bindings x)))])
133                                    (check scc-env body)))
134                                comps)]
135                                         ; collect all the constraints in the scc
136                              [cs
137                               (fold-left
138                                (lambda (acc res c)
139                                  (consolidate
140                                   acc
141                                   (consolidate (car res)
142                                         ; unify with tvars from scc-env
143                                         ; result ~ tvar
144                                                (~ (cadr res) (env-lookup scc-env c)))))
145                                '() type-results comps)]
146                                         ; substitute *only* the bindings in this scc
147                              [new-env
148                               (map (lambda (x)
149                                      (if (memv (car x) comps)
150                                          (cons (car x) (substitute cs (cdr x)))
151                                          x))
152                                    scc-env)])
153                           new-env))]
154                      [new-env (fold-left process-component env components)])
155                 (check new-env (last (let-body x)))))
156             
157             ('lambda
158                 (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
159
160                        (body-type-res (check new-env (lambda-body x)))
161                        (cs (car body-type-res))
162                        (subd-env (substitute-env (car body-type-res) new-env))
163                        (arg-type (env-lookup subd-env (lambda-arg x)))
164                        (resolved-arg-type (substitute cs arg-type))]
165                   ;; (display "lambda:\n\t")
166                   ;; (display prog)
167                   ;; (display "\n\t")
168                   ;; (display cs)
169                   ;; (display "\n\t")
170                   ;; (display resolved-arg-type)
171                   ;; (newline)
172                   (list (car body-type-res)
173                         (list 'abs
174                               resolved-arg-type
175                               (cadr body-type-res)))))
176             
177             ('app ; (f a)
178              (if (eqv? (car x) (cadr x))
179                                         ; recursive function (f f)
180                  (let* [(func-type (env-lookup env (car x)))
181                         (return-type (fresh-tvar))
182                         (other-func-type `(abs ,func-type ,return-type))
183                         (cs (~ func-type other-func-type))]
184                    (list cs return-type))
185
186                                         ; regular function
187                  (let* ((arg-type-res (check env (cadr x)))
188                         (arg-type (cadr arg-type-res))
189                         (func-type-res (check env (car x)))
190                         (func-type (cadr func-type-res))
191                         
192                                         ; f ~ a -> t0
193                         (func-c (~
194                                  func-type
195                                  (list 'abs
196                                        arg-type
197                                        (fresh-tvar))))
198                         (cs (consolidate
199                              (consolidate func-c (car arg-type-res))
200                              (car func-type-res)))
201                         
202                         (resolved-func-type (substitute cs func-type))
203                         (resolved-return-type (caddr resolved-func-type)))
204                    ;; (display "app:\n")
205                    ;; (display cs)
206                    ;; (display "\n")
207                    ;; (display func-type)
208                    ;; (display "\n")
209                    ;; (display resolved-func-type)
210                    ;; (display "\n")
211                    ;; (display arg-type-res)
212                    ;; (display "\n")
213                    (if (abs? resolved-func-type)
214                        (let ((return-type (substitute cs (caddr resolved-func-type))))
215                          (list cs return-type))
216                        (error #f "not a function"))))))))
217       ;; (display "result of ")
218       ;; (display x)
219       ;; (display ":\n\t")
220       ;; (display (pretty-type (cadr res)))
221       ;; (display "\n\t[")
222       ;; (display (car res))
223       ;; (display "]\n")
224       res))
225   (cadr (check '() (normalize prog))))
226
227                                         ; returns a list of pairs of constraints
228 (define (~ a b)
229   (let ([res (unify? a b)])
230     (if res
231         res
232         (error #f
233                (format "couldn't unify ~a ~~ ~a" a b)))))
234
235 (define (unify? a b)
236   (cond [(eq? a b) '()]
237         [(or (tvar? a) (tvar? b)) (list (list a b))]
238         [(and (abs? a) (abs? b))
239          (let* [(arg-cs (unify? (cadr a) (cadr b)))
240                 (body-cs (unify? (substitute arg-cs (caddr a))
241                                 (substitute arg-cs (caddr b))))]
242            (consolidate arg-cs body-cs))]
243         [else #f]))
244
245                                         ; TODO: what's the most appropriate substitution?
246                                         ; should all constraints just be limited to a pair?
247 (define (substitute cs t)
248                                         ; gets the first concrete type
249                                         ; otherwise returns the last type variable
250
251   (define cs-without-t
252     (map (lambda (c)
253            (filter (lambda (x) (not (eqv? t x))) c))
254          cs))
255
256   (define (get-concrete c)
257     (let [(last (null? (cdr c)))]
258       (if (not (tvar? (car c)))
259           (if (abs? (car c))
260               (substitute cs-without-t (car c))
261               (car c))
262           (if last
263               (car c)
264               (get-concrete (cdr c))))))
265   
266   (cond
267    ((abs? t) (list 'abs
268                    (substitute cs (cadr t))
269                    (substitute cs (caddr t))))
270    (else
271     (fold-left
272      (lambda (t c)
273        (if (member t c)
274            (get-concrete c)
275            t))
276      t cs))))
277
278 (define (substitute-env cs env)
279   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
280
281 (define (consolidate x y)
282   (define (merge a b)
283     (cond ((null? a) b)
284           ((null? b) a)
285           (else (if (member (car b) a)
286                     (merge a (cdr b))
287                     (cons (car b) (merge a (cdr b)))))))
288   (define (overlap? a b)
289     (if (or (null? a) (null? b))
290         #f
291         (if (fold-left (lambda (acc v)
292                          (or acc (eq? v (car a))))
293                        #f b)
294             #t
295             (overlap? (cdr a) b))))
296
297   (cond ((null? y) x)
298         ((null? x) y)
299         (else
300          (let* ((a (car y))
301                 (merged (fold-left
302                          (lambda (acc b)
303                            (if acc
304                                acc
305                                (if (overlap? a b)
306                                    (cons (merge a b) b)
307                                    #f)))
308                          #f x))
309                 (removed (if merged
310                              (filter (lambda (b) (not (eq? b (cdr merged)))) x)
311                              x)))
312            (if merged
313                (consolidate removed (cons (car merged) (cdr y)))
314                (consolidate (cons a x) (cdr y)))))))
315
316                                         ; a1 -> a2 ~ a3 -> a4;
317                                         ; a1 -> a2 !~ bool -> bool
318                                         ; basically can the tvars be renamed
319 (define (types-equal? x y)
320   (let ([cs (unify? x y)])
321     (if (not cs) #f
322         (let*
323             ([test-kind
324               (lambda (acc c)
325                 (if (tvar? c) acc #f))]
326              [test (lambda (acc c)
327                      (and acc (fold-left test-kind #t c)))])
328           (fold-left test #t cs)))))
329
330                                         ; input: a list of binds ((x . y) (y . 3))
331                                         ; returns: pair of verts, edges ((x y) . (x . y))
332 (define (graph bs)
333   (define (find-refs prog)
334     (ast-collect
335      (lambda (x)
336        (case (ast-type x)
337                                         ; only count a reference if its a binding
338          ['var (if (assoc x bs) (list x) '())]
339          [else '()]))
340      prog))
341   (if (null? bs)
342       '(() . ())
343       (let* [(bind (car bs))
344
345              (vert (car bind))
346              (refs (find-refs (cdr bind)))
347              (edges (map (lambda (x) (cons vert x))
348                          refs))
349
350              (rest (if (null? (cdr bs))
351                        (cons '() '())
352                        (graph (cdr bs))))
353              (total-verts (cons vert (car rest)))
354              (total-edges (append edges (cdr rest)))]
355         (cons total-verts total-edges))))
356
357 (define (successors graph v)
358   (define (go v E)
359     (if (null? E)
360         '()
361         (if (eqv? v (caar E))
362             (cons (cdar E) (go v (cdr E)))
363             (go v (cdr E)))))
364   (go v (cdr graph)))
365
366                                         ; takes in a graph (pair of vertices, edges)
367                                         ; returns a list of strongly connected components
368
369                                         ; ((x y w) . ((x . y) (x . w) (w . x))
370
371                                         ; =>
372                                         ; .->x->y
373                                         ; |  |
374                                         ; |  v
375                                         ; .--w
376
377                                         ; ((x w) (y))
378
379                                         ; this uses tarjan's algorithm, to get reverse
380                                         ; topological sorting for free
381 (define (sccs graph)
382   
383   (let* ([indices (make-hash-table)]
384          [lowlinks (make-hash-table)]
385          [on-stack (make-hash-table)]
386          [current 0]
387          [stack '()]
388          [result '()])
389
390     (define (index v)
391       (get-hash-table indices v #f))
392     (define (lowlink v)
393       (get-hash-table lowlinks v #f))
394
395     (letrec
396         ([strong-connect
397           (lambda (v)
398             (begin
399               (put-hash-table! indices v current)
400               (put-hash-table! lowlinks v current)
401               (set! current (+ current 1))
402               (push! stack v)
403               (put-hash-table! on-stack v #t)
404
405               (for-each
406                (lambda (w)
407                  (if (not (hashtable-contains? indices w))
408                                         ; successor w has not been visited, recurse
409                      (begin
410                        (strong-connect w)
411                        (put-hash-table! lowlinks
412                                         v
413                                         (min (lowlink v) (lowlink w))))
414                                         ; successor w has been visited
415                      (when (get-hash-table on-stack w #f)
416                        (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
417                (successors graph v))
418
419               (when (= (index v) (lowlink v))
420                 (let ([scc
421                        (let new-scc ()
422                          (let ([w (pop! stack)])
423                            (put-hash-table! on-stack w #f)
424                            (if (eqv? w v)
425                                (list w)
426                                (cons w (new-scc)))))])
427                   (set! result (cons scc result))))))])
428       (for-each
429        (lambda (v)
430          (when (not (hashtable-contains? indices v)) ; v.index == -1
431            (strong-connect v)))
432        (car graph)))
433     result))
434