WIP on typechecker refactor
[scheme.git] / typecheck.scm
index 25a0e45685f5bbc87a85daf308bb541cb76a7960..6a99869c0697302f361a9e28ac216013f5d018d6 100644 (file)
          (pretty-type (caddr t))))
        (else (symbol->string t))))
 
+(define (pretty-constraints cs)
+  (string-append "{"
+                (fold-left string-append
+                           ""
+                           (map (lambda (c)
+                                  (string-append
+                                   (pretty-type (car c))
+                                   ": "
+                                   (pretty-type (cdr c))
+                                   ", "))
+                                cs))
+                "}"))
+
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
@@ -78,8 +91,6 @@
     ('print '(abs string void))
     (else #f)))
 
-; we typecheck the lambda calculus only (only single arg lambdas)
-(define (typecheck prog)
 (define (check env x)
   (display "check: ")
   (display x)
           (let* ((cond-type-res (check env (cadr x)))
                  (then-type-res (check env (caddr x)))
                  (else-type-res (check env (cadddr x)))
-                   (then-eq-else-cs (unify (cadr then-type-res)
+                 (then-eq-else-cs (~ (cadr then-type-res)
                                      (cadr else-type-res)))
-                   (cs (consolidate
+                 (cs (constraint-merge
                       (car then-type-res)
-                        (consolidate (car else-type-res)
+                      (constraint-merge (car else-type-res)
                                         then-eq-else-cs)))
                  (return-type (substitute cs (cadr then-type-res))))
             (when (not (eqv? (cadr cond-type-res) 'bool))
              (let* ([components (reverse (sccs (graph (let-bindings x))))]
                     [process-component
                      (lambda (acc comps)
-                       (display comps)
-                       (newline)
                        (let*
+                                       ; create a new env with tvars for each component
+                                       ; e.g. scc of (x y)
+                                       ; scc-env = ((x . t0) (y . t1))
                            ([scc-env
                              (fold-left
                               (lambda (acc c)
                                 (env-insert acc c (fresh-tvar)))
                               acc comps)]
+                                       ; typecheck each component
                             [type-results
                              (map
                               (lambda (c)
-                                (begin (display scc-env) (newline)
                                 (let ([body (cadr (assoc c (let-bindings x)))])
-                                  (display body)(newline)(check scc-env body))))
+                                  (check scc-env body)))
                               comps)]
+                                       ; collect all the constraints in the scc
                             [cs
                              (fold-left
                               (lambda (acc res c)
-                                (consolidate
-                                 acc
-                                 (unify (cadr res) (env-lookup scc-env c))))
-                              '() type-results comps)])
-                         (display "process-component env:\n")
-                         (display (substitute-env cs scc-env))
+                                (constraint-merge
+                                 (constraint-merge
+                                       ; unify with tvars from scc-env
+                                       ; result ~ tvar
+                                  (~ (env-lookup scc-env c) (cadr res))
+                                  (car res))                             
+                                 acc))
+                              '() type-results comps)]
+                                       ; substitute *only* the bindings in this scc
+                            [new-env
+                             (map (lambda (x)
+                                    (if (memv (car x) comps)
+                                        (cons (car x) (substitute cs (cdr x)))
+                                        x))
+                                  scc-env)])
+                                       (display "cs:")
+               (display cs)
                (newline)
-                         (substitute-env cs scc-env)))]
+                         new-env))]
                     [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
          
-           ;; (let ((new-env (fold-left
-           ;;          (lambda (acc bind)
-           ;;            (let* [(bind-tvar (fresh-tvar))
-           ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
-           ;;                   (bind-res (check env-with-tvar (cadr bind)))
-           ;;                   (bind-type (cadr bind-res))
-           ;;                   (cs (consolidate (car bind-res)
-           ;;                                    (unify bind-type bind-tvar)))]
-           ;;              (substitute-env cs env-with-tvar)))
-           ;;          env (let-bindings x))))
-           ;;   (display "sccs of graph\n")
-           ;;   (display (sccs (graph (let-bindings x))))
-           ;;   (newline)
-           ;;   (display "env when checking body:\n\t")
-           ;;   (display new-env)
-           ;;   (newline)
-           ;;   (check new-env (last (let-body x)))))
-           
-
          ('lambda
              (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
 
                ;; (display "\n\t")
                ;; (display cs)
                ;; (display "\n\t")
+               ;; (display (format "subd-env: ~a\n" subd-env))
                ;; (display resolved-arg-type)
                ;; (newline)
                (list (car body-type-res)
               (let* [(func-type (env-lookup env (car x)))
                      (return-type (fresh-tvar))
                      (other-func-type `(abs ,func-type ,return-type))
-                       (cs (unify func-type other-func-type))]
-                  (list cs return-type))
+                     (cs (~ func-type other-func-type))
+                     (resolved-return-type (substitute cs return-type))]
+                (list cs resolved-return-type))
 
                                        ; regular function
               (let* ((arg-type-res (check env (cadr x)))
                      (func-type (cadr func-type-res))
                      
                                        ; f ~ a -> t0
-                       (func-c (unify func-type
-                                      (list 'abs
-                                            arg-type
-                                            (fresh-tvar))))
-                       (cs (consolidate
-                            (consolidate func-c (car arg-type-res))
+                     (func-c (~
+                              (substitute (car arg-type-res) func-type)
+                              `(abs ,arg-type ,(fresh-tvar))))
+                     (cs (constraint-merge
+                          (constraint-merge func-c (car arg-type-res))
                           (car func-type-res)))
                      
                      (resolved-func-type (substitute cs func-type))
     (display ":\n\t")
     (display (pretty-type (cadr res)))
     (display "\n\t[")
-      (display (car res))
+    (display (pretty-constraints (car res)))
     (display "]\n")
     res))
+
+                                       ; we typecheck the lambda calculus only (only single arg lambdas)
+(define (typecheck prog)
   (cadr (check '() (normalize prog))))
 
-                                       ; returns a list of pairs of constraints
-(define (unify a b)
-  (cond ((eq? a b) '())
-       ((or (tvar? a) (tvar? b)) (~ a b))
-       ((and (abs? a) (abs? b))
-        (let* [(arg-cs (unify (cadr a) (cadr b)))
-               (body-cs (unify (substitute arg-cs (caddr a))
+                                       ; returns a list of constraints
+(define (~ a b)
+  (let ([res (unify? a b)])
+    (if res
+       res
+       (error #f
+              (format "couldn't unify ~a ~~ ~a" a b)))))
+
+(define (unify? a b)
+  (cond [(eq? a b) '()]
+       [(tvar? a) (list (cons a b))]
+       [(tvar? b) (list (cons b a))]
+       [(and (abs? a) (abs? b))
+        (let* [(arg-cs (unify? (cadr a) (cadr b)))
+               (body-cs (unify? (substitute arg-cs (caddr a))
                                 (substitute arg-cs (caddr b))))]
-          (consolidate arg-cs body-cs)))
-       (else (error #f "could not unify"))))
+          (constraint-merge body-cs arg-cs))]
+       [else #f]))
 
-                                       ; TODO: what's the most appropriate substitution?
-                                       ; should all constraints just be limited to a pair?
 (define (substitute cs t)
-                                       ; gets the first concrete type
-                                       ; otherwise returns the last type variable
-
-  (define cs-without-t
-    (map (lambda (c)
-          (filter (lambda (x) (not (eqv? t x))) c))
-        cs))
-
-  (define (get-concrete c)
-    (let [(last (null? (cdr c)))]
-      (if (not (tvar? (car c)))
-         (if (abs? (car c))
-             (substitute cs-without-t (car c))
-             (car c))
-         (if last
-             (car c)
-             (get-concrete (cdr c))))))
-  
   (cond
-   ((abs? t) (list 'abs
-                  (substitute cs (cadr t))
-                  (substitute cs (caddr t))))
-   (else
-    (fold-left
-     (lambda (t c)
-       (if (member t c)
-          (get-concrete c)
-          t))
-     t cs))))
-
+   [(tvar? t)
+    (if (assoc t cs)
+       (cdr (assoc t cs))
+       t)]
+   [(abs? t) `(abs ,(substitute cs (cadr t))
+                  ,(substitute cs (caddr t)))]
+   [else t]))
+
+                                       ; applies substitutions to all variables in environment
 (define (substitute-env cs env)
   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
 
-(define (~ a b)
-  (list (list a b)))
-
-(define (consolidate x y)
-  (define (merge a b)
-    (cond ((null? a) b)
-         ((null? b) a)
-         (else (if (member (car b) a)
-                   (merge a (cdr b))
-                   (cons (car b) (merge a (cdr b)))))))
-  (define (overlap? a b)
-    (if (or (null? a) (null? b))
-       #f
-       (if (fold-left (lambda (acc v)
-                        (or acc (eq? v (car a))))
-                      #f b)
-           #t
-           (overlap? (cdr a) b))))
-
-  (cond ((null? y) x)
-       ((null? x) y)
-       (else
-        (let* ((a (car y))
-               (merged (fold-left
-                        (lambda (acc b)
-                          (if acc
-                              acc
-                              (if (overlap? a b)
-                                  (cons (merge a b) b)
-                                  #f)))
-                        #f x))
-               (removed (if merged
-                            (filter (lambda (b) (not (eq? b (cdr merged)))) x)
-                            x)))
-          (if merged
-              (consolidate removed (cons (car merged) (cdr y)))
-              (consolidate (cons a x) (cdr y)))))))
-
-                                       ; a1 -> a2 ~ a3 -> a4;
-                                       ; a1 -> a2 !~ bool -> bool
-                                       ; basically can the tvars be renamed
+                                       ; composes constraints a onto b and merges, i.e. applies a to b
+                                       ; a should be the "more important" constraints
+(define (constraint-merge a b)
+  (define (f constraint)
+    (cons (car constraint)
+         (substitute a (cdr constraint))))
+  
+  (define (most-concrete a b)
+    (cond
+     [(tvar? a) b]
+     [(tvar? b) a]
+     [(and (abs? a) (abs? b))
+      `(abs ,(most-concrete (cadr a) (cadr b))
+           ,(most-concrete (caddr a) (caddr b)))]
+     [(abs? a) b]
+     [(abs? b) a]
+     [else (error #f "impossible! most-concrete")]))
+
+  (define (union p q)
+    (cond
+     [(null? p) q]
+     [(null? q) p]
+     [else
+      (let ([x (car q)])
+       (if (assoc (car x) p)
+           (if (eqv? (most-concrete (cddr (assoc (car x) p))
+                                    (cdr x))
+                     (cdr x))
+               (cons x (union (filter (p) (not (eqv? 
+       
+  
+  (define (union p q)
+    (append (filter (lambda (x) (not (assoc (car x) p)))
+                   q)
+           p))
+  (union a (map f b)))
+
+
+;;                                     ; a1 -> a2 ~ a3 -> a4;
+;;                                     ; a1 -> a2 !~ bool -> bool
+;;                                     ; basically can the tvars be renamed
 (define (types-equal? x y)
-  (error #f "todo"))
+  (let ([cs (unify? x y)])
+    (if (not cs) #f    
+       (let*
+           ([test (lambda (acc c)
+                    (and acc
+                         (tvar? (car c)) ; the only substitutions allowed are tvar -> tvar
+                         (tvar? (cdr c))))])
+         (fold-left test #t cs)))))
 
                                        ; input: a list of binds ((x . y) (y . 3))
                                        ; returns: pair of verts, edges ((x y) . (x . y))
 (define (graph bs)
+  (define (go bs orig-bs)
     (define (find-refs prog)
       (ast-collect
        (lambda (x)
         (case (ast-type x)
                                        ; only count a reference if its a binding
-        ['var (if (assoc x bs) (list x) '())]
+          ['var (if (assoc x orig-bs) (list x) '())]
           [else '()]))
        prog))
+    (if (null? bs)
+       '(() . ())
        (let* [(bind (car bs))
 
               (vert (car bind))
 
               (rest (if (null? (cdr bs))
                         (cons '() '())
-                  (graph (cdr bs))))
+                        (go (cdr bs) orig-bs)))
               (total-verts (cons vert (car rest)))
               (total-edges (append edges (cdr rest)))]
-    (cons total-verts total-edges)))
+         (cons total-verts total-edges))))
+  (go bs bs))
 
 (define (successors graph v)
   (define (go v E)
                               (list w)
                               (cons w (new-scc)))))])
                  (set! result (cons scc result))))))])
-      
       (for-each
        (lambda (v)
         (when (not (hashtable-contains? indices v)) ; v.index == -1