X-Git-Url: http://git.lukelau.me/?p=scheme.git;a=blobdiff_plain;f=typecheck.scm;h=d180e44e92e755f1e4fbb4172f3535e70f8ebb24;hp=4cc93495fb9dd6bc90b03b9b06271e316c65a8bb;hb=8e106ca13666680051f91ab3f49ce2bd7e19ead7;hpb=a64f7097fa246c19a4c69d0aad65e60378273887 diff --git a/typecheck.scm b/typecheck.scm index 4cc9349..d180e44 100644 --- a/typecheck.scm +++ b/typecheck.scm @@ -37,7 +37,7 @@ ; ('a, ('b, 'a)) (define (env-lookup env n) - (if (null? env) (error #f "empty env") ; it's a type equality + (if (null? env) (error #f "empty env" env n) ; it's a type equality (if (eq? (caar env) n) (cdar env) (env-lookup (cdr env) n)))) @@ -91,11 +91,9 @@ (else (error #f "Couldn't find type for builtin" x)))) (define (check-let env x) - ; takes in the current environment and a scc - ; returns new environment with scc's types added in - (let* ([components (reverse (sccs (graph (let-bindings x))))] - [process-component - (lambda (acc comps) + + ; acc is a pair of (env . annotated bindings) + (define (process-component acc comps) (let* ; create a new env with tvars for each component ; e.g. scc of (x y) @@ -104,7 +102,7 @@ (fold-left (lambda (acc c) (env-insert acc c (fresh-tvar))) - acc comps)] + (car acc) comps)] ; typecheck each component [type-results (map @@ -130,24 +128,82 @@ (if (memv (car x) comps) (cons (car x) (substitute cs (cdr x))) x)) - scc-env)]) - new-env))] - [new-env (fold-left process-component env components)]) - (check new-env (last (let-body x))))) + scc-env)] + + [annotated-bindings (append (cdr acc) ; the previous annotated bindings + (map list + comps + (map caddr type-results)))]) + (cons new-env annotated-bindings))) + ; takes in the current environment and a scc + ; returns new environment with scc's types added in + (let* ([components (reverse (sccs (graph (let-bindings x))))] + [results (fold-left process-component (cons env '()) components)] + [new-env (car results)] + [annotated-bindings (cdr results)] + + [body-results (map (lambda (body) (check new-env body)) (let-body x))] + [let-type (cadr (last body-results))] + [cs (fold-left (lambda (acc cs) (constraint-merge acc cs)) '() (map car body-results))] + + [annotated `((let ,annotated-bindings ,@(map caddr body-results)) : ,let-type)]) + (list cs let-type annotated))) + +(define (check-app env x) + (if (eqv? (car x) (cadr x)) + ; recursive function (f f) + ; TODO: what about ((f a) f)???? + (let* ([func-type (env-lookup env (car x))] + [return-type (fresh-tvar)] + [other-func-type `(abs ,func-type ,return-type)] + [cs (~ func-type other-func-type)] + [resolved-return-type (substitute cs return-type)] + + [annotated `(((,(car x) : ,func-type) + (,(cadr x) : ,func-type)) : ,resolved-return-type)]) + (list cs resolved-return-type annotated))) + + ; regular function + (let* ([arg-type-res (check env (cadr x))] + [arg-type (cadr arg-type-res)] + [func-type-res (check env (car x))] + [func-type (cadr func-type-res)] + + ; f ~ a -> t0 + [func-c (~ + (substitute (car arg-type-res) func-type) + `(abs ,arg-type ,(fresh-tvar)))] + [cs (constraint-merge + (constraint-merge func-c (car arg-type-res)) + (car func-type-res))] + + [resolved-func-type (substitute cs func-type)] + [resolved-return-type (caddr resolved-func-type)] + + [annotated `((,(caddr func-type-res) + ,(caddr arg-type-res)) : ,resolved-return-type)]) + + (if (abs? resolved-func-type) + (let ((return-type (substitute cs (caddr resolved-func-type)))) + (list cs return-type annotated)) + (error #f "not a function")))) +; returns a list (constraints type annotated) (define (check env x) - (display "check: ") - (display x) - (display "\n\t") - (display env) - (newline) + (define (make-result cs type) + (list cs type `(,x : ,type))) + ;; (display "check: ") + ;; (display x) + ;; (display "\n\t") + ;; (display env) + ;; (newline) (let ((res (case (ast-type x) - ('int-literal (list '() 'Int)) - ('bool-literal (list '() 'Bool)) - ('string-literal (list '() 'String)) - ('builtin (list '() (builtin-type x))) + ('int-literal (make-result '() 'Int)) + ('bool-literal (make-result '() 'Bool)) + ('string-literal (make-result '() 'String)) + ('builtin (make-result '() (builtin-type x))) ('if (let* ((cond-type-res (check env (cadr x))) @@ -160,97 +216,107 @@ (constraint-merge (~ (cadr cond-type-res) 'Bool) (constraint-merge (car else-type-res) then-eq-else-cs)))) - (return-type (substitute cs (cadr then-type-res)))) - (list cs return-type))) + (return-type (substitute cs (cadr then-type-res))) + [annotated `((if ,(caddr cond-type-res) + ,(caddr then-type-res) + ,(caddr else-type-res)) : ,return-type)]) + (list cs return-type annotated))) - ('var (list '() (env-lookup env x))) + ('var (make-result '() (env-lookup env x))) ('let (check-let env x)) ('lambda - (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar))) - - (body-type-res (check new-env (lambda-body x))) - (cs (car body-type-res)) - (subd-env (substitute-env (car body-type-res) new-env)) - (arg-type (env-lookup subd-env (lambda-arg x))) - (resolved-arg-type (substitute cs arg-type))] - ;; (display "lambda:\n\t") - ;; (display prog) - ;; (display "\n\t") - ;; (display cs) - ;; (display "\n\t") - ;; (display (format "subd-env: ~a\n" subd-env)) - ;; (display resolved-arg-type) - ;; (newline) - (list (car body-type-res) - (list 'abs - resolved-arg-type - (cadr body-type-res))))) + (let* ([new-env (env-insert env (lambda-arg x) (fresh-tvar))] - ('app ; (f a) - (if (eqv? (car x) (cadr x)) - ; recursive function (f f) - (let* [(func-type (env-lookup env (car x))) - (return-type (fresh-tvar)) - (other-func-type `(abs ,func-type ,return-type)) - (cs (~ func-type other-func-type)) - (resolved-return-type (substitute cs return-type))] - (list cs resolved-return-type))) + [body-type-res (check new-env (lambda-body x))] + [cs (car body-type-res)] + [subd-env (substitute-env (car body-type-res) new-env)] + [arg-type (env-lookup subd-env (lambda-arg x))] + [resolved-arg-type (substitute cs arg-type)] - ; regular function - (let* ((arg-type-res (check env (cadr x))) - (arg-type (cadr arg-type-res)) - (func-type-res (check env (car x))) - (func-type (cadr func-type-res)) + [lambda-type `(abs ,resolved-arg-type ,(cadr body-type-res))] - ; f ~ a -> t0 - (func-c (~ - (substitute (car arg-type-res) func-type) - `(abs ,arg-type ,(fresh-tvar)))) - (cs (constraint-merge - (constraint-merge func-c (car arg-type-res)) - (car func-type-res))) - - (resolved-func-type (substitute cs func-type)) - (resolved-return-type (caddr resolved-func-type))) - ;; (display "app:\n") - ;; (display cs) - ;; (display "\n") - ;; (display func-type) - ;; (display "\n") - ;; (display resolved-func-type) - ;; (display "\n") - ;; (display arg-type-res) - ;; (display "\n") - (if (abs? resolved-func-type) - (let ((return-type (substitute cs (caddr resolved-func-type)))) - (list cs return-type)) - (error #f "not a function"))))))) - (display "result of ") - (display x) - (display ":\n\t") - (display (pretty-type (cadr res))) - (display "\n\t[") - (display (pretty-constraints (car res))) - (display "]\n") + [annotated `((lambda (,(lambda-arg x)) ,(caddr body-type-res)) : ,lambda-type)]) + + (list (car body-type-res) ; constraints + lambda-type ; type + annotated))) + + + ('app (check-app env x))))) + ;; (display "result of ") + ;; (display x) + ;; (display ":\n\t") + ;; (display (pretty-type (cadr res))) + ;; (display "\n\t[") + ;; (display (pretty-constraints (car res))) + ;; (display "]\n") res)) +(define (init-adts-env prog) + (flat-map data-tors-type-env (program-data-layouts prog))) + ; we typecheck the lambda calculus only (only single arg lambdas) (define (typecheck prog) - (define (constructor-type t ctr) - (fold-left (lambda (acc x) `(abs ,x ,acc)) t (cdr ctr))) - (define (constructors data-def) - (let ([type-name (cadr data-def)] - [ctrs (cddr data-def)]) - (fold-left (lambda (acc ctr) - (cons (cons (car ctr) (constructor-type type-name ctr)) - acc)) - '() - ctrs))) - (let ([init-env (flat-map constructors (program-datas prog))]) - (display init-env) - (cadr (check init-env (normalize (program-body prog)))))) + (let ([expanded (expand-pattern-matches prog)]) + (cadr (check (init-adts-env expanded) (normalize (program-body expanded)))))) + + + ; before passing annotated types onto codegen + ; we need to restore the pre-normalization structure + ; (this is important for function arity etc) +(define (denormalize orig normed) + + (define (collapse-lambdas n x) + (case n + [0 x] + [else + (let* ([inner-lambda (lambda-body (ann-expr x))] + [arg (lambda-arg (ann-expr x))] + [inner-collapsed (ann-expr (collapse-lambdas (- n 1) inner-lambda))]) + `((lambda ,(cons arg (lambda-args inner-collapsed)) + ,(lambda-body inner-collapsed)) : ,(ann-type x)))])) + + (define (collapse-apps n x) + (case n + [-1 (error #f "nullary functions not handled yet")] + [0 x] + [else + (let* ([inner-app (car (ann-expr x))] + [inner-collapsed (collapse-apps (- n 1) inner-app)]) + `(,(append (ann-expr inner-collapsed) (cdr (ann-expr x))) : ,(ann-type x)))])) + + (case (ast-type orig) + ['lambda + (let ([collapsed (collapse-lambdas (- (length (lambda-args orig)) 1) normed)]) + `((lambda ,(lambda-args (ann-expr collapsed)) + ,(denormalize (lambda-body orig) + (lambda-body (ann-expr collapsed)))) : ,(ann-type collapsed)))] + ['app + (let ([collapsed (collapse-apps (- (length orig) 2) normed)]) + `(,(map (lambda (o n) (denormalize o n)) orig (ann-expr collapsed)) + : ,(ann-type collapsed)))] + ['let + `((let ,(map (lambda (o n) (list (car o) (denormalize (cadr o) (cadr n)))) + (let-bindings orig) + (let-bindings (ann-expr normed))) + ,@(map (lambda (o n) (denormalize o n)) + (let-body orig) + (let-body (ann-expr normed)))) : ,(ann-type normed))] + ['if `((if ,@(map denormalize (cdr orig) (cdr (ann-expr normed)))) + : (ann-type normed))] + [else normed])) + +(define ann-expr car) +(define ann-type caddr) + + ; prerequisites: expand-pattern-matches +(define (annotate-types prog) + (denormalize + (program-body prog) + (caddr (check (init-adts-env prog) (normalize (program-body prog)))))) + ; returns a list of constraints (define (~ a b) @@ -338,108 +404,4 @@ ; input: a list of binds ((x . y) (y . 3)) ; returns: pair of verts, edges ((x y) . (x . y)) -(define (graph bs) - (define (go bs orig-bs) - (define (find-refs prog) - (ast-collect - (lambda (x) - (case (ast-type x) - ; only count a reference if its a binding - ['var (if (assoc x orig-bs) (list x) '())] - [else '()])) - prog)) - (if (null? bs) - '(() . ()) - (let* [(bind (car bs)) - - (vert (car bind)) - (refs (find-refs (cdr bind))) - (edges (map (lambda (x) (cons vert x)) - refs)) - - (rest (if (null? (cdr bs)) - (cons '() '()) - (go (cdr bs) orig-bs))) - (total-verts (cons vert (car rest))) - (total-edges (append edges (cdr rest)))] - (cons total-verts total-edges)))) - (go bs bs)) - -(define (successors graph v) - (define (go v E) - (if (null? E) - '() - (if (eqv? v (caar E)) - (cons (cdar E) (go v (cdr E))) - (go v (cdr E))))) - (go v (cdr graph))) - - ; takes in a graph (pair of vertices, edges) - ; returns a list of strongly connected components - - ; ((x y w) . ((x . y) (x . w) (w . x)) - - ; => - ; .->x->y - ; | | - ; | v - ; .--w - - ; ((x w) (y)) - - ; this uses tarjan's algorithm, to get reverse - ; topological sorting for free -(define (sccs graph) - - (let* ([indices (make-hash-table)] - [lowlinks (make-hash-table)] - [on-stack (make-hash-table)] - [current 0] - [stack '()] - [result '()]) - - (define (index v) - (get-hash-table indices v #f)) - (define (lowlink v) - (get-hash-table lowlinks v #f)) - - (letrec - ([strong-connect - (lambda (v) - (begin - (put-hash-table! indices v current) - (put-hash-table! lowlinks v current) - (set! current (+ current 1)) - (push! stack v) - (put-hash-table! on-stack v #t) - - (for-each - (lambda (w) - (if (not (hashtable-contains? indices w)) - ; successor w has not been visited, recurse - (begin - (strong-connect w) - (put-hash-table! lowlinks - v - (min (lowlink v) (lowlink w)))) - ; successor w has been visited - (when (get-hash-table on-stack w #f) - (put-hash-table! lowlinks v (min (lowlink v) (index w)))))) - (successors graph v)) - - (when (= (index v) (lowlink v)) - (let ([scc - (let new-scc () - (let ([w (pop! stack)]) - (put-hash-table! on-stack w #f) - (if (eqv? w v) - (list w) - (cons w (new-scc)))))]) - (set! result (cons scc result))))))]) - (for-each - (lambda (v) - (when (not (hashtable-contains? indices v)) ; v.index == -1 - (strong-connect v))) - (car graph))) - result))