Substitute only the variables in the scc
authorLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 23:40:30 +0000 (00:40 +0100)
committerLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 23:40:30 +0000 (00:40 +0100)
This preserves the most generic possible type

Also add a test for type equality (unifies, but only type variable
substitutions)
Fix graph when there's no bindings

tests.scm
typecheck.scm

index 6b66bec1ae3b6de291a2636778fc2656a6c246a6..65b99d9c5a3574e331228c21933528d1a3490e76 100644 (file)
--- a/tests.scm
+++ b/tests.scm
@@ -8,7 +8,7 @@
                   expected actual))))
 
 (define (test . xs) (apply test-f (cons equal? xs)))
-(define (test-types . xs) (apply test-f (cons types-unify? xs)))
+(define (test-types . xs) (apply test-f (cons types-equal? xs)))
 
 (define (read-file file)
   (call-with-input-file file
 (test-types (typecheck '(let ([bar (lambda (y) y)]
                              [foo (lambda (x) (foo (bar #t)))])
                          foo))
-           '(abs bool t0))
+           '(abs bool a))
 
 (test-types (typecheck '(let ([bar (lambda (y) y)]
                        [foo (lambda (x) (foo (bar #t)))])
                    bar))
-      '(abs t0 t0))
+      '(abs a a))
 
 (test-prog '(+ 1 2) 3)
 (test-prog '((lambda (x) ((lambda (y) (+ x y)) 42)) 100) 142)
index 25a0e45685f5bbc87a85daf308bb541cb76a7960..313db0e2a9d48fbdbc7fcaff2f89eab749f905b7 100644 (file)
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
-    (display "check: ")
-    (display x)
-    (display "\n\t")
-    (display env)
-    (newline)
+    ;; (display "check: ")
+    ;; (display x)
+    ;; (display "\n\t")
+    ;; (display env)
+    ;; (newline)
     (let
        ((res
          (case (ast-type x)
              (let* ([components (reverse (sccs (graph (let-bindings x))))]
                     [process-component
                      (lambda (acc comps)
-                       (display comps)
-                       (newline)
                        (let*
+                                       ; create a new env with tvars for each component
+                                       ; e.g. scc of (x y)
+                                       ; scc-env = ((x . t0) (y . t1))
                            ([scc-env
                              (fold-left
                               (lambda (acc c)
                                 (env-insert acc c (fresh-tvar)))
                               acc comps)]
+                                       ; typecheck each component
                             [type-results
                              (map
                               (lambda (c)
-                                (begin (display scc-env) (newline)
                                 (let ([body (cadr (assoc c (let-bindings x)))])
-                                  (display body)(newline)(check scc-env body))))
+                                  (check scc-env body)))
                               comps)]
+                                       ; collect all the constraints in the scc
                             [cs
                              (fold-left
                               (lambda (acc res c)
                                 (consolidate
                                  acc
-                                 (unify (cadr res) (env-lookup scc-env c))))
-                              '() type-results comps)])
-                         (display "process-component env:\n")
-                         (display (substitute-env cs scc-env))
-                         (newline)
-                         (substitute-env cs scc-env)))]
+                                 (consolidate (car res)
+                                       ; unify with tvars from scc-env
+                                       ; result ~ tvar
+                                              (unify (cadr res) (env-lookup scc-env c)))))
+                              '() type-results comps)]
+                                       ; substitute *only* the bindings in this scc
+                            [new-env
+                             (map (lambda (x)
+                                    (if (memv (car x) comps)
+                                        (cons (car x) (substitute cs (cdr x)))
+                                        x))
+                                  scc-env)])
+                         new-env))]
                     [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
            
-           ;; (let ((new-env (fold-left
-           ;;          (lambda (acc bind)
-           ;;            (let* [(bind-tvar (fresh-tvar))
-           ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
-           ;;                   (bind-res (check env-with-tvar (cadr bind)))
-           ;;                   (bind-type (cadr bind-res))
-           ;;                   (cs (consolidate (car bind-res)
-           ;;                                    (unify bind-type bind-tvar)))]
-           ;;              (substitute-env cs env-with-tvar)))
-           ;;          env (let-bindings x))))
-           ;;   (display "sccs of graph\n")
-           ;;   (display (sccs (graph (let-bindings x))))
-           ;;   (newline)
-           ;;   (display "env when checking body:\n\t")
-           ;;   (display new-env)
-           ;;   (newline)
-           ;;   (check new-env (last (let-body x)))))
-           
-
            ('lambda
                (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
 
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
                       (error #f "not a function"))))))))
-      (display "result of ")
-      (display x)
-      (display ":\n\t")
-      (display (pretty-type (cadr res)))
-      (display "\n\t[")
-      (display (car res))
-      (display "]\n")
+      ;; (display "result of ")
+      ;; (display x)
+      ;; (display ":\n\t")
+      ;; (display (pretty-type (cadr res)))
+      ;; (display "\n\t[")
+      ;; (display (car res))
+      ;; (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
                                        ; returns a list of pairs of constraints
 (define (unify a b)
+  (let ([res (unify? a b)])
+    (if res
+       res
+       (error #f
+              (format "couldn't unify ~a ~~ ~a" a b)))))
+
+(define (unify? a b)
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
                (body-cs (unify (substitute arg-cs (caddr a))
                                (substitute arg-cs (caddr b))))]
           (consolidate arg-cs body-cs)))
-       (else (error #f "could not unify"))))
+       (else #f)))
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; should all constraints just be limited to a pair?
                                        ; a1 -> a2 !~ bool -> bool
                                        ; basically can the tvars be renamed
 (define (types-equal? x y)
-  (error #f "todo"))
+  (let ([cs (unify? x y)])
+    (if (not cs) #f
+       (let*
+           ([test-kind
+             (lambda (acc c)
+               (if (tvar? c) acc #f))]
+            [test (lambda (acc c)
+                    (and acc (fold-left test-kind #t c)))])
+         (fold-left test #t cs)))))
 
                                        ; input: a list of binds ((x . y) (y . 3))
                                        ; returns: pair of verts, edges ((x y) . (x . y))
         ['var (if (assoc x bs) (list x) '())]
         [else '()]))
      prog))
+  (if (null? bs)
+      '(() . ())
       (let* [(bind (car bs))
 
             (vert (car bind))
                       (graph (cdr bs))))
             (total-verts (cons vert (car rest)))
             (total-edges (append edges (cdr rest)))]
-    (cons total-verts total-edges)))
+       (cons total-verts total-edges))))
 
 (define (successors graph v)
   (define (go v E)
                               (list w)
                               (cons w (new-scc)))))])
                  (set! result (cons scc result))))))])
-      
       (for-each
        (lambda (v)
         (when (not (hashtable-contains? indices v)) ; v.index == -1