X-Git-Url: http://git.lukelau.me/?a=blobdiff_plain;f=typecheck.scm;h=e96f6943767e1c5ab463398d11146d8f8ae450ce;hb=f605bff88ce12e5f4384ab308c036350bfa86cb5;hp=25a0e45685f5bbc87a85daf308bb541cb76a7960;hpb=d0e9f5296b7510fe057be4a2f9e2a31ed856652c;p=scheme.git diff --git a/typecheck.scm b/typecheck.scm index 25a0e45..e96f694 100644 --- a/typecheck.scm +++ b/typecheck.scm @@ -81,11 +81,11 @@ ; we typecheck the lambda calculus only (only single arg lambdas) (define (typecheck prog) (define (check env x) - (display "check: ") - (display x) - (display "\n\t") - (display env) - (newline) + ;; (display "check: ") + ;; (display x) + ;; (display "\n\t") + ;; (display env) + ;; (newline) (let ((res (case (ast-type x) @@ -98,7 +98,7 @@ (let* ((cond-type-res (check env (cadr x))) (then-type-res (check env (caddr x))) (else-type-res (check env (cadddr x))) - (then-eq-else-cs (unify (cadr then-type-res) + (then-eq-else-cs (~ (cadr then-type-res) (cadr else-type-res))) (cs (consolidate (car then-type-res) @@ -116,54 +116,44 @@ (let* ([components (reverse (sccs (graph (let-bindings x))))] [process-component (lambda (acc comps) - (display comps) - (newline) (let* + ; create a new env with tvars for each component + ; e.g. scc of (x y) + ; scc-env = ((x . t0) (y . t1)) ([scc-env (fold-left (lambda (acc c) (env-insert acc c (fresh-tvar))) acc comps)] + ; typecheck each component [type-results (map (lambda (c) - (begin (display scc-env) (newline) (let ([body (cadr (assoc c (let-bindings x)))]) - (display body)(newline)(check scc-env body)))) + (check scc-env body))) comps)] + ; collect all the constraints in the scc [cs (fold-left (lambda (acc res c) (consolidate acc - (unify (cadr res) (env-lookup scc-env c)))) - '() type-results comps)]) - (display "process-component env:\n") - (display (substitute-env cs scc-env)) - (newline) - (substitute-env cs scc-env)))] + (consolidate (car res) + ; unify with tvars from scc-env + ; result ~ tvar + (~ (cadr res) (env-lookup scc-env c))))) + '() type-results comps)] + ; substitute *only* the bindings in this scc + [new-env + (map (lambda (x) + (if (memv (car x) comps) + (cons (car x) (substitute cs (cdr x))) + x)) + scc-env)]) + new-env))] [new-env (fold-left process-component env components)]) (check new-env (last (let-body x))))) - ;; (let ((new-env (fold-left - ;; (lambda (acc bind) - ;; (let* [(bind-tvar (fresh-tvar)) - ;; (env-with-tvar (env-insert acc (car bind) bind-tvar)) - ;; (bind-res (check env-with-tvar (cadr bind))) - ;; (bind-type (cadr bind-res)) - ;; (cs (consolidate (car bind-res) - ;; (unify bind-type bind-tvar)))] - ;; (substitute-env cs env-with-tvar))) - ;; env (let-bindings x)))) - ;; (display "sccs of graph\n") - ;; (display (sccs (graph (let-bindings x)))) - ;; (newline) - ;; (display "env when checking body:\n\t") - ;; (display new-env) - ;; (newline) - ;; (check new-env (last (let-body x))))) - - ('lambda (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar))) @@ -190,8 +180,9 @@ (let* [(func-type (env-lookup env (car x))) (return-type (fresh-tvar)) (other-func-type `(abs ,func-type ,return-type)) - (cs (unify func-type other-func-type))] - (list cs return-type)) + (cs (~ func-type other-func-type)) + (resolved-return-type (substitute cs return-type))] + (list cs resolved-return-type)) ; regular function (let* ((arg-type-res (check env (cadr x))) @@ -200,7 +191,8 @@ (func-type (cadr func-type-res)) ; f ~ a -> t0 - (func-c (unify func-type + (func-c (~ + func-type (list 'abs arg-type (fresh-tvar)))) @@ -223,33 +215,46 @@ (let ((return-type (substitute cs (caddr resolved-func-type)))) (list cs return-type)) (error #f "not a function")))))))) - (display "result of ") - (display x) - (display ":\n\t") - (display (pretty-type (cadr res))) - (display "\n\t[") - (display (car res)) - (display "]\n") + ;; (display "result of ") + ;; (display x) + ;; (display ":\n\t") + ;; (display (pretty-type (cadr res))) + ;; (display "\n\t[") + ;; (display (car res)) + ;; (display "]\n") res)) (cadr (check '() (normalize prog)))) ; returns a list of pairs of constraints -(define (unify a b) - (cond ((eq? a b) '()) - ((or (tvar? a) (tvar? b)) (~ a b)) - ((and (abs? a) (abs? b)) - (let* [(arg-cs (unify (cadr a) (cadr b))) - (body-cs (unify (substitute arg-cs (caddr a)) +(define (~ a b) + (let ([res (unify? a b)]) + (if res + res + (error #f + (format "couldn't unify ~a ~~ ~a" a b))))) + +(define (unify? a b) + (cond [(eq? a b) '()] + [(or (tvar? a) (tvar? b)) (list (list a b))] + [(and (abs? a) (abs? b)) + (let* [(arg-cs (unify? (cadr a) (cadr b))) + (body-cs (unify? (substitute arg-cs (caddr a)) (substitute arg-cs (caddr b))))] - (consolidate arg-cs body-cs))) - (else (error #f "could not unify")))) + (consolidate arg-cs body-cs))] + [else #f])) ; TODO: what's the most appropriate substitution? ; should all constraints just be limited to a pair? + ; this is currently horrific and i don't know what im doing. + ; should probably use ast-find here or during consolidation + ; to detect substitutions more than one layer deep + ; e.g. (abs t1 int) ~ (abs bool int) + ; substituting these constraints with t1 should resolve t1 with bool (define (substitute cs t) ; gets the first concrete type ; otherwise returns the last type variable + ; removes t itself from cs, to prevent infinite recursion (define cs-without-t (map (lambda (c) (filter (lambda (x) (not (eqv? t x))) c)) @@ -280,9 +285,6 @@ (define (substitute-env cs env) (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env)) -(define (~ a b) - (list (list a b))) - (define (consolidate x y) (define (merge a b) (cond ((null? a) b) @@ -322,19 +324,32 @@ ; a1 -> a2 !~ bool -> bool ; basically can the tvars be renamed (define (types-equal? x y) - (error #f "todo")) + (let ([cs (unify? x y)]) + (if (not cs) #f + (let* + ([test-kind + (lambda (acc c) + (if (tvar? c) acc #f))] + [test (lambda (acc c) + (and acc + (fold-left test-kind #t c) ; check only tvar substitutions + (<= (length c) 2)))]) ; check maximum 2 subs per equality group + (fold-left test #t cs))))) ; input: a list of binds ((x . y) (y . 3)) ; returns: pair of verts, edges ((x y) . (x . y)) (define (graph bs) + (define (go bs orig-bs) (define (find-refs prog) (ast-collect (lambda (x) (case (ast-type x) ; only count a reference if its a binding - ['var (if (assoc x bs) (list x) '())] + ['var (if (assoc x orig-bs) (list x) '())] [else '()])) prog)) + (if (null? bs) + '(() . ()) (let* [(bind (car bs)) (vert (car bind)) @@ -344,10 +359,11 @@ (rest (if (null? (cdr bs)) (cons '() '()) - (graph (cdr bs)))) + (go (cdr bs) orig-bs))) (total-verts (cons vert (car rest))) (total-edges (append edges (cdr rest)))] - (cons total-verts total-edges))) + (cons total-verts total-edges)))) + (go bs bs)) (define (successors graph v) (define (go v E) @@ -420,7 +436,6 @@ (list w) (cons w (new-scc)))))]) (set! result (cons scc result))))))]) - (for-each (lambda (v) (when (not (hashtable-contains? indices v)) ; v.index == -1