4 (and (list? t) (eq? (car t) 'abs)))
7 (and (not (list? t)) (not (concrete? t)) (symbol? t)))
16 (define (pretty-type t)
20 (string-append "(" (pretty-type (cadr t)) ")")
21 (pretty-type (cadr t)))
23 (pretty-type (caddr t))))
24 (else (symbol->string t))))
26 (define (pretty-constraints cs)
28 (fold-left string-append
40 (define (env-lookup env n)
41 (if (null? env) (error #f "empty env") ; it's a type equality
42 (if (eq? (caar env) n)
44 (env-lookup (cdr env) n))))
46 (define (env-insert env n t) ; returns a new env with the binding (n . t) prepended; env itself is not mutated
47 (cons (cons n t) env)) ; the front entry is checked first by env-lookup, so it takes precedence over older bindings of n
54 (set! cur-tvar (+ cur-tvar 1))
56 (string-append "t" (number->string (- cur-tvar 1))))))
63 (define (normalize prog) ; (+ a b) -> ((+ a) b)
66 ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
67 (if (> (length (lambda-args prog)) 1)
68 (list 'lambda (list (car (lambda-args prog)))
69 (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
70 (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
72 (if (null? (cddr prog))
73 `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
74 (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
75 ,@(cddr prog))))) ; (f a b)
78 (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
80 (map normalize (let-body prog))))
81 (else (ast-traverse normalize prog))))
83 (define (builtin-type x)
85 ('+ '(abs int (abs int int)))
86 ('- '(abs int (abs int int)))
87 ('* '(abs int (abs int int)))
89 ('= '(abs int (abs int bool)))
90 ('bool->int '(abs bool int))
91 ('print '(abs string void))
95 ;; (display "check: ")
103 ('int-literal (list '() 'int))
104 ('bool-literal (list '() 'bool))
105 ('string-literal (list '() 'string))
106 ('builtin (list '() (builtin-type x)))
109 (let* ((cond-type-res (check env (cadr x)))
110 (then-type-res (check env (caddr x)))
111 (else-type-res (check env (cadddr x)))
112 (then-eq-else-cs (~ (cadr then-type-res)
113 (cadr else-type-res)))
114 (cs (constraint-merge
116 (constraint-merge (car else-type-res)
118 (return-type (substitute cs (cadr then-type-res))))
119 (when (not (eqv? (cadr cond-type-res) 'bool))
120 (error #f "if condition isn't bool"))
121 (list cs return-type)))
123 ('var (list '() (env-lookup env x)))
125 ; takes in the current environment and an SCC
126 ; returns a new environment with the SCC's types added in
127 (let* ([components (reverse (sccs (graph (let-bindings x))))]
131 ; create a new env with tvars for each component
133 ; scc-env = ((x . t0) (y . t1))
137 (env-insert acc c (fresh-tvar)))
139 ; typecheck each component
143 (let ([body (cadr (assoc c (let-bindings x)))])
144 (check scc-env body)))
146 ; collect all the constraints in the scc
152 ; unify with tvars from scc-env
154 (~ (env-lookup scc-env c) (cadr res))
157 '() type-results comps)]
158 ; substitute *only* the bindings in this scc
161 (if (memv (car x) comps)
162 (cons (car x) (substitute cs (cdr x)))
166 [new-env (fold-left process-component env components)])
167 (check new-env (last (let-body x)))))
170 (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
172 (body-type-res (check new-env (lambda-body x)))
173 (cs (car body-type-res))
174 (subd-env (substitute-env (car body-type-res) new-env))
175 (arg-type (env-lookup subd-env (lambda-arg x)))
176 (resolved-arg-type (substitute cs arg-type))]
177 ;; (display "lambda:\n\t")
182 ;; (display (format "subd-env: ~a\n" subd-env))
183 ;; (display resolved-arg-type)
185 (list (car body-type-res)
188 (cadr body-type-res)))))
191 (if (eqv? (car x) (cadr x))
192 ; recursive function (f f)
193 (let* [(func-type (env-lookup env (car x)))
194 (return-type (fresh-tvar))
195 (other-func-type `(abs ,func-type ,return-type))
196 (cs (~ func-type other-func-type))
197 (resolved-return-type (substitute cs return-type))]
198 (list cs resolved-return-type)))
201 (let* ((arg-type-res (check env (cadr x)))
202 (arg-type (cadr arg-type-res))
203 (func-type-res (check env (car x)))
204 (func-type (cadr func-type-res))
208 (substitute (car arg-type-res) func-type)
209 `(abs ,arg-type ,(fresh-tvar))))
210 (cs (constraint-merge
211 (constraint-merge func-c (car arg-type-res))
212 (car func-type-res)))
214 (resolved-func-type (substitute cs func-type))
215 (resolved-return-type (caddr resolved-func-type)))
216 ;; (display "app:\n")
219 ;; (display func-type)
221 ;; (display resolved-func-type)
223 ;; (display arg-type-res)
225 (if (abs? resolved-func-type)
226 (let ((return-type (substitute cs (caddr resolved-func-type))))
227 (list cs return-type))
228 (error #f "not a function")))))))
229 ;; (display "result of ")
232 ;; (display (pretty-type (cadr res)))
234 ;; (display (pretty-constraints (car res)))
238 ; we typecheck the lambda calculus only (only single-arg lambdas); normalize first curries multi-arg lambdas and applications
239 (define (typecheck prog) ; returns only the inferred type of prog (the cadr of check's (constraints type) result)
240 (cadr (check '() (normalize prog)))) ; checking starts from an empty environment
242 ; returns a list of constraints
244 (let ([res (unify? a b)])
248 (format "couldn't unify ~a ~~ ~a" a b)))))
251 (cond [(eq? a b) '()]
252 [(tvar? a) (list (cons a b))]
253 [(tvar? b) (list (cons b a))]
254 [(and (abs? a) (abs? b))
255 (let* [(arg-cs (unify? (cadr a) (cadr b)))
256 (body-cs (unify? (substitute arg-cs (caddr a))
257 (substitute arg-cs (caddr b))))]
258 (constraint-merge body-cs arg-cs))]
261 (define (substitute cs t)
267 [(abs? t) `(abs ,(substitute cs (cadr t))
268 ,(substitute cs (caddr t)))]
271 ; applies the substitutions cs to the type of every binding in env, returning a new env with the same names
272 (define (substitute-env cs env) ; env is an assoc list of (name . type) pairs, as built by env-insert
273 (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
275 ; composes constraints a onto b and merges, i.e. applies a to b
276 ; a should be the "more important" constraints
277 (define (constraint-merge a b)
278 (define (f cs constraint)
279 (cons (car constraint)
280 (substitute cs (cdr constraint))))
282 (define (most-concrete a b)
286 [(and (abs? a) (abs? b))
287 `(abs ,(most-concrete (cadr a) (cadr b))
288 ,(most-concrete (caddr a) (caddr b)))]
293 ; for any two constraints that clash, e.g. t1 ~ abs t2 t3
294 ; and t1 ~ abs int t3
295 ; prepend the most concrete version of the type to the
296 ; list of constraints
299 (if (assoc (car x) a)
300 (cons (cons (car x) (most-concrete (cdr (assoc (car x) a))
304 (fold-left gen '() b))
307 (append (filter (lambda (x) (not (assoc (car x) p)))
310 (append (clashes) (union a (map (lambda (z) (f a z)) b))))
313 ;; ; a1 -> a2 ~ a3 -> a4;
314 ;; ; a1 -> a2 !~ bool -> bool
315 ;; ; basically can the tvars be renamed
316 (define (types-equal? x y)
317 (let ([cs (unify? x y)])
320 ([test (lambda (acc c)
322 (tvar? (car c)) ; the only substitutions allowed are tvar -> tvar
324 (fold-left test #t cs)))))
326 ; input: a list of binds ((x . y) (y . 3))
327 ; returns: pair of verts and edges ((x y) . ((x . y)))
329 (define (go bs orig-bs)
330 (define (find-refs prog)
334 ; only count a reference if it names one of the bindings
335 ['var (if (assoc x orig-bs) (list x) '())]
340 (let* [(bind (car bs))
343 (refs (find-refs (cdr bind)))
344 (edges (map (lambda (x) (cons vert x))
347 (rest (if (null? (cdr bs))
349 (go (cdr bs) orig-bs)))
350 (total-verts (cons vert (car rest)))
351 (total-edges (append edges (cdr rest)))]
352 (cons total-verts total-edges))))
355 (define (successors graph v)
359 (if (eqv? v (caar E))
360 (cons (cdar E) (go v (cdr E)))
364 ; takes in a graph (pair of vertices, edges)
365 ; returns a list of strongly connected components
367 ; ((x y w) . ((x . y) (x . w) (w . x)))
377 ; this uses Tarjan's algorithm, which yields the components in reverse
378 ; topological order for free
381 (let* ([indices (make-hash-table)]
382 [lowlinks (make-hash-table)]
383 [on-stack (make-hash-table)]
389 (get-hash-table indices v #f))
391 (get-hash-table lowlinks v #f))
397 (put-hash-table! indices v current)
398 (put-hash-table! lowlinks v current)
399 (set! current (+ current 1))
401 (put-hash-table! on-stack v #t)
405 (if (not (hashtable-contains? indices w))
406 ; successor w has not been visited, recurse
409 (put-hash-table! lowlinks
411 (min (lowlink v) (lowlink w))))
412 ; successor w has been visited
413 (when (get-hash-table on-stack w #f)
414 (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
415 (successors graph v))
417 (when (= (index v) (lowlink v))
420 (let ([w (pop! stack)])
421 (put-hash-table! on-stack w #f)
424 (cons w (new-scc)))))])
425 (set! result (cons scc result))))))])
428 (when (not (hashtable-contains? indices v)) ; v.index == -1