Break up lets into SCCs before typechecking
[scheme.git] / typecheck.scm
index 7eb4fa96606d786b5e58d17f2d2601f4f52c85a3..25a0e45685f5bbc87a85daf308bb541cb76a7960 100644 (file)
@@ -58,8 +58,8 @@
     ('app
      (if (null? (cddr prog))
         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
     ('app
      (if (null? (cddr prog))
         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
-        `(,(list (normalize (car prog)) (normalize (cadr prog)))
-          ,(normalize (caddr prog))))) ; (f a b)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
     ('let
        (append (list 'let
                      (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
     ('let
        (append (list 'let
                      (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
-    ;; (display "check: ")
-    ;; (display x)
-    ;; (display "\n\t")
-    ;; (display env)
-    ;; (newline)
+    (display "check: ")
+    (display x)
+    (display "\n\t")
+    (display env)
+    (newline)
     (let
        ((res
          (case (ast-type x)
     (let
        ((res
          (case (ast-type x)
            
            ('var (list '() (env-lookup env x)))
            ('let
            
            ('var (list '() (env-lookup env x)))
            ('let
-           (let ((new-env (fold-left
-                           (lambda (acc bind)
-                             (let ((t (check
-                                       (env-insert acc (car bind) (fresh-tvar))
-                                       (cadr bind))))
-                               (env-insert acc (car bind) (cadr t))))
-                           env (let-bindings x))))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+             (let* ([components (reverse (sccs (graph (let-bindings x))))]
+                    [process-component
+                     (lambda (acc comps)
+                       (display comps)
+                       (newline)
+                       (let*
+                           ([scc-env
+                             (fold-left
+                              (lambda (acc c)
+                                (env-insert acc c (fresh-tvar)))
+                              acc comps)]
+                            [type-results
+                             (map
+                              (lambda (c)
+                                (begin (display scc-env) (newline)
+                                (let ([body (cadr (assoc c (let-bindings x)))])
+                                  (display body)(newline)(check scc-env body))))
+                              comps)]
+                            [cs
+                             (fold-left
+                              (lambda (acc res c)
+                                (consolidate
+                                 acc
+                                 (unify (cadr res) (env-lookup scc-env c))))
+                              '() type-results comps)])
+                         (display "process-component env:\n")
+                         (display (substitute-env cs scc-env))
+                         (newline)
+                         (substitute-env cs scc-env)))]
+                    [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
            
                (check new-env (last (let-body x)))))
            
+           ;; (let ((new-env (fold-left
+           ;;          (lambda (acc bind)
+           ;;            (let* [(bind-tvar (fresh-tvar))
+           ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
+           ;;                   (bind-res (check env-with-tvar (cadr bind)))
+           ;;                   (bind-type (cadr bind-res))
+           ;;                   (cs (consolidate (car bind-res)
+           ;;                                    (unify bind-type bind-tvar)))]
+           ;;              (substitute-env cs env-with-tvar)))
+           ;;          env (let-bindings x))))
+           ;;   (display "sccs of graph\n")
+           ;;   (display (sccs (graph (let-bindings x))))
+           ;;   (newline)
+           ;;   (display "env when checking body:\n\t")
+           ;;   (display new-env)
+           ;;   (newline)
+           ;;   (check new-env (last (let-body x)))))
+           
 
            ('lambda
 
            ('lambda
-           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+               (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
-                  (resolved-arg-type (substitute cs arg-type)))
+                      (resolved-arg-type (substitute cs arg-type))]
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                              (cadr body-type-res)))))
            
            ('app ; (f a)
                              (cadr body-type-res)))))
            
            ('app ; (f a)
+            (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                (let* [(func-type (env-lookup env (car x)))
+                       (return-type (fresh-tvar))
+                       (other-func-type `(abs ,func-type ,return-type))
+                       (cs (unify func-type other-func-type))]
+                  (list cs return-type))
+
+                                       ; regular function
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
-                 (error #f "not a function")))))))
-      ;; (display "result of ")
-      ;; (display x)
-      ;; (display ":\n\t")
-      ;; (display (cadr res))
-      ;; (display "[")
-      ;; (display (car res))
-      ;; (display "]\n")
+                      (error #f "not a function"))))))))
+      (display "result of ")
+      (display x)
+      (display ":\n\t")
+      (display (pretty-type (cadr res)))
+      (display "\n\t[")
+      (display (car res))
+      (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
       res))
   (cadr (check '() (normalize prog))))
 
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
+        (let* [(arg-cs (unify (cadr a) (cadr b)))
+               (body-cs (unify (substitute arg-cs (caddr a))
+                               (substitute arg-cs (caddr b))))]
+          (consolidate arg-cs body-cs)))
        (else (error #f "could not unify"))))
 
                                        ; TODO: what's the most appropriate substitution?
        (else (error #f "could not unify"))))
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
+  (define cs-without-t
+    (map (lambda (c)
+          (filter (lambda (x) (not (eqv? t x))) c))
+        cs))
+
   (define (get-concrete c)
   (define (get-concrete c)
-    (let ((last (null? (cdr c))))
+    (let [(last (null? (cdr c)))]
       (if (not (tvar? (car c)))
          (if (abs? (car c))
       (if (not (tvar? (car c)))
          (if (abs? (car c))
-             (substitute cs (car c))
+             (substitute cs-without-t (car c))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
+  
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
 
   (cond ((null? y) x)
        ((null? x) y)
 
   (cond ((null? y) x)
        ((null? x) y)
-       (else (let* ((a (car y))
+       (else
+        (let* ((a (car y))
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
+
+                                       ; a1 -> a2 ~ a3 -> a4;
+                                       ; a1 -> a2 !~ bool -> bool
+                                       ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (error #f "todo"))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+(define (graph bs)
+  (define (find-refs prog)
+    (ast-collect
+     (lambda (x)
+       (case (ast-type x)
+                                       ; only count a reference if its a binding
+        ['var (if (assoc x bs) (list x) '())]
+        [else '()]))
+     prog))
+  (let* [(bind (car bs))
+
+        (vert (car bind))
+        (refs (find-refs (cdr bind)))
+        (edges (map (lambda (x) (cons vert x))
+                    refs))
+
+        (rest (if (null? (cdr bs))
+                  (cons '() '())
+                  (graph (cdr bs))))
+        (total-verts (cons vert (car rest)))
+        (total-edges (append edges (cdr rest)))]
+    (cons total-verts total-edges)))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+