Break up lets into SCCs before typechecking
authorLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 22:56:37 +0000 (23:56 +0100)
committerLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 22:56:37 +0000 (23:56 +0100)
This means each binding gets the most general type possible.
Part of the work towards supporting typechecking recursive functions

ast.scm
codegen.scm
tests.scm
typecheck.scm

diff --git a/ast.scm b/ast.scm
index a0899192e21a1d44b0fe3296e380604133a3f1d2..a38a01257aea179e4638d376400fcc77e1c9318e 100644 (file)
--- a/ast.scm
+++ b/ast.scm
     ('if `(if ,@(map f (cdr x))))
     (else x)))
 
+(define (ast-collect f x)
+  (define (inner y) (ast-collect f y))
+  (case (ast-type x)
+    ['let (append (f x)
+                 (fold-map inner (let-bindings x))
+                 (fold-map inner (let-body x)))]
+    ['app (append (f x)
+                 (fold-map inner x))]
+    ['lambda (append (f x)
+                    (inner (lambda-body x)))]
+    ['if (append (f x)
+                (fold-map inner (cdr x)))]
+    [else (f x)]))
+
 (define let-bindings cadr)
 (define let-body cddr)
 
 ; for use elsewhere
 (define lambda-args cadr)
 (define lambda-body caddr)
+
+; utils
+(define (fold-map f x) (fold-left append '() (map f x)))
+(define (repeat x n) (if (<= n 0) '()
+                        (cons x (repeat x (- n 1)))))
+
+
+(define-syntax push!
+  (syntax-rules ()
+    ((_ s x) (set! s (cons x s)))))
+
+(define-syntax pop!
+  (syntax-rules ()
+    ((_ s) (let ([x (car s)])
+            (set! s (cdr s))
+            x))))
index 4e05bcf797b31d175a9629aac4b75868bbd9c72e..2d60c0aed64c22d3f1933544bf1a5df7101ce62c 100644 (file)
   (emit "not %rcx")      ; -%rcx = strlen + 1
   (emit "dec %rcx")
   
-  (case target
-    ('darwin
   (emit "movq %rbx, %rsi") ; string addr
   (emit "movq %rcx, %rdx") ; num bytes
   (emit "movq $1, %rdi")   ; file handle (stdout)
-     (emit "movq $0x2000004, %rax")) ; syscall 4 (write)
-    ('linux
-     (emit "mov %rbx, %rsi")  ; string addr
-     (emit "mov %rcx, %rdx")  ; num bytes
-     (emit "mov $1, %rax")    ; file handle (stdout)
-     (emit "mov $1, %rdi"))) ; syscall 1 (write)
+  (case target
+    ('darwin (emit "mov $0x2000004, %rax")) ; syscall 4 (write)
+    ('linux  (emit "mov $1, %rax"))) ; syscall 1 (write)
   (emit "syscall"))
 
 (define (range s n)
     (else (error #f "don't know how to codegen this"))))
 
 
-(define (fold-map f x) (fold-left append '() (map f x)))
 
 (define (free-vars prog)
   (define bound '())
index 4ff6d760315700a6b862b1ce0818dafb78956640..6b66bec1ae3b6de291a2636778fc2656a6c246a6 100644 (file)
--- a/tests.scm
+++ b/tests.scm
@@ -1,12 +1,15 @@
 (load "codegen.scm")
 (load "typecheck.scm")
 
-(define (test actual expected)
-  (when (not (equal? actual expected))
+(define (test-f pred actual expected)
+  (when (not (pred actual expected))
     (error #f
           (format "test failed:\nexpected: ~a\nactual:   ~a"
                   expected actual))))
 
+(define (test . xs) (apply test-f (cons equal? xs)))
+(define (test-types . xs) (apply test-f (cons types-unify? xs)))
+
 (define (read-file file)
   (call-with-input-file file
     (lambda (p)
   (let ((str (read-file "/tmp/test-output.txt")))
     (test str output)))
 
-(test (typecheck '(lambda (x) (+ ((lambda (y) (x y 3)) 5) 2)))
+(test-types (typecheck '(lambda (x) (+ ((lambda (y) (x y 3)) 5) 2)))
            '(abs (abs int (abs int int)) int))
 
+                                       ; recursive types
+
+(test-types (substitute '((t1 (abs t1 t10))) 't1) '(abs t1 t10))
+
+(test-types (typecheck '(let ([bar (lambda (y) y)]
+                             [foo (lambda (x) (foo (bar #t)))])
+                         foo))
+           '(abs bool t0))
+
+(test-types (typecheck '(let ([bar (lambda (y) y)]
+                       [foo (lambda (x) (foo (bar #t)))])
+                   bar))
+      '(abs t0 t0))
+
 (test-prog '(+ 1 2) 3)
 (test-prog '((lambda (x) ((lambda (y) (+ x y)) 42)) 100) 142)
 
@@ -59,3 +76,8 @@
 (test-prog-stdout '(let () ((lambda (f) (f "foo")) print) 0) "foo")
 (test-prog '((lambda (f) (f 3 3)) (lambda (x y) (bool->int (= x y)))) 1)
 (test-prog '(bool->int ((lambda (f) (! (f 2 3))) =)) 1)
+
+                                       ; recursion (hangs at typechecking)
+(test-prog '(let [(fac (lambda (f n x) (if (= n 0) x (f f (- n 1) (* x x)))))]
+             (fac fac 3 2))
+          8)
index 7eb4fa96606d786b5e58d17f2d2601f4f52c85a3..25a0e45685f5bbc87a85daf308bb541cb76a7960 100644 (file)
@@ -58,8 +58,8 @@
     ('app
      (if (null? (cddr prog))
         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
-        `(,(list (normalize (car prog)) (normalize (cadr prog)))
-          ,(normalize (caddr prog))))) ; (f a b)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
     ('let
        (append (list 'let
                      (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
-    ;; (display "check: ")
-    ;; (display x)
-    ;; (display "\n\t")
-    ;; (display env)
-    ;; (newline)
+    (display "check: ")
+    (display x)
+    (display "\n\t")
+    (display env)
+    (newline)
     (let
        ((res
          (case (ast-type x)
            
            ('var (list '() (env-lookup env x)))
            ('let
-           (let ((new-env (fold-left
-                           (lambda (acc bind)
-                             (let ((t (check
-                                       (env-insert acc (car bind) (fresh-tvar))
-                                       (cadr bind))))
-                               (env-insert acc (car bind) (cadr t))))
-                           env (let-bindings x))))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+             (let* ([components (reverse (sccs (graph (let-bindings x))))]
+                    [process-component
+                     (lambda (acc comps)
+                       (display comps)
+                       (newline)
+                       (let*
+                           ([scc-env
+                             (fold-left
+                              (lambda (acc c)
+                                (env-insert acc c (fresh-tvar)))
+                              acc comps)]
+                            [type-results
+                             (map
+                              (lambda (c)
+                                (begin (display scc-env) (newline)
+                                (let ([body (cadr (assoc c (let-bindings x)))])
+                                  (display body)(newline)(check scc-env body))))
+                              comps)]
+                            [cs
+                             (fold-left
+                              (lambda (acc res c)
+                                (consolidate
+                                 acc
+                                 (unify (cadr res) (env-lookup scc-env c))))
+                              '() type-results comps)])
+                         (display "process-component env:\n")
+                         (display (substitute-env cs scc-env))
+                         (newline)
+                         (substitute-env cs scc-env)))]
+                    [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
            
+           ;; (let ((new-env (fold-left
+           ;;          (lambda (acc bind)
+           ;;            (let* [(bind-tvar (fresh-tvar))
+           ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
+           ;;                   (bind-res (check env-with-tvar (cadr bind)))
+           ;;                   (bind-type (cadr bind-res))
+           ;;                   (cs (consolidate (car bind-res)
+           ;;                                    (unify bind-type bind-tvar)))]
+           ;;              (substitute-env cs env-with-tvar)))
+           ;;          env (let-bindings x))))
+           ;;   (display "sccs of graph\n")
+           ;;   (display (sccs (graph (let-bindings x))))
+           ;;   (newline)
+           ;;   (display "env when checking body:\n\t")
+           ;;   (display new-env)
+           ;;   (newline)
+           ;;   (check new-env (last (let-body x)))))
+           
 
            ('lambda
-           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+               (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
-                  (resolved-arg-type (substitute cs arg-type)))
+                      (resolved-arg-type (substitute cs arg-type))]
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                              (cadr body-type-res)))))
            
            ('app ; (f a)
+            (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                (let* [(func-type (env-lookup env (car x)))
+                       (return-type (fresh-tvar))
+                       (other-func-type `(abs ,func-type ,return-type))
+                       (cs (unify func-type other-func-type))]
+                  (list cs return-type))
+
+                                       ; regular function
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
-                 (error #f "not a function")))))))
-      ;; (display "result of ")
-      ;; (display x)
-      ;; (display ":\n\t")
-      ;; (display (cadr res))
-      ;; (display "[")
-      ;; (display (car res))
-      ;; (display "]\n")
+                      (error #f "not a function"))))))))
+      (display "result of ")
+      (display x)
+      (display ":\n\t")
+      (display (pretty-type (cadr res)))
+      (display "\n\t[")
+      (display (car res))
+      (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
+        (let* [(arg-cs (unify (cadr a) (cadr b)))
+               (body-cs (unify (substitute arg-cs (caddr a))
+                               (substitute arg-cs (caddr b))))]
+          (consolidate arg-cs body-cs)))
        (else (error #f "could not unify"))))
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
+  (define cs-without-t
+    (map (lambda (c)
+          (filter (lambda (x) (not (eqv? t x))) c))
+        cs))
+
   (define (get-concrete c)
-    (let ((last (null? (cdr c))))
+    (let [(last (null? (cdr c)))]
       (if (not (tvar? (car c)))
          (if (abs? (car c))
-             (substitute cs (car c))
+             (substitute cs-without-t (car c))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
+  
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
 
   (cond ((null? y) x)
        ((null? x) y)
-       (else (let* ((a (car y))
+       (else
+        (let* ((a (car y))
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
+
+                                       ; a1 -> a2 ~ a3 -> a4;
+                                       ; a1 -> a2 !~ bool -> bool
+                                       ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (error #f "todo"))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+(define (graph bs)
+  (define (find-refs prog)
+    (ast-collect
+     (lambda (x)
+       (case (ast-type x)
+                                       ; only count a reference if its a binding
+        ['var (if (assoc x bs) (list x) '())]
+        [else '()]))
+     prog))
+  (let* [(bind (car bs))
+
+        (vert (car bind))
+        (refs (find-refs (cdr bind)))
+        (edges (map (lambda (x) (cons vert x))
+                    refs))
+
+        (rest (if (null? (cdr bs))
+                  (cons '() '())
+                  (graph (cdr bs))))
+        (total-verts (cons vert (car rest)))
+        (total-edges (append edges (cdr rest)))]
+    (cons total-verts total-edges)))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+