Break up lets into SCCs before typechecking
authorLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 22:56:37 +0000 (23:56 +0100)
committerLuke Lau <luke_lau@icloud.com>
Sun, 28 Jul 2019 22:56:37 +0000 (23:56 +0100)
This means each binding gets the most general type possible.
Part of the work towards supporting typechecking recursive functions

ast.scm
codegen.scm
tests.scm
typecheck.scm

diff --git a/ast.scm b/ast.scm
index a0899192e21a1d44b0fe3296e380604133a3f1d2..a38a01257aea179e4638d376400fcc77e1c9318e 100644 (file)
--- a/ast.scm
+++ b/ast.scm
     ('if `(if ,@(map f (cdr x))))
     (else x)))
 
     ('if `(if ,@(map f (cdr x))))
     (else x)))
 
+(define (ast-collect f x)
+  (define (inner y) (ast-collect f y))
+  (case (ast-type x)
+    ['let (append (f x)
+                 (fold-map inner (let-bindings x))
+                 (fold-map inner (let-body x)))]
+    ['app (append (f x)
+                 (fold-map inner x))]
+    ['lambda (append (f x)
+                    (inner (lambda-body x)))]
+    ['if (append (f x)
+                (fold-map inner (cdr x)))]
+    [else (f x)]))
+
 (define let-bindings cadr)
 (define let-body cddr)
 
 (define let-bindings cadr)
 (define let-body cddr)
 
 ; for use elsewhere
 (define lambda-args cadr)
 (define lambda-body caddr)
 ; for use elsewhere
 (define lambda-args cadr)
 (define lambda-body caddr)
+
+; utils
+(define (fold-map f x) (fold-left append '() (map f x)))
+(define (repeat x n) (if (<= n 0) '()
+                        (cons x (repeat x (- n 1)))))
+
+
+(define-syntax push!
+  (syntax-rules ()
+    ((_ s x) (set! s (cons x s)))))
+
+(define-syntax pop!
+  (syntax-rules ()
+    ((_ s) (let ([x (car s)])
+            (set! s (cdr s))
+            x))))
index 4e05bcf797b31d175a9629aac4b75868bbd9c72e..2d60c0aed64c22d3f1933544bf1a5df7101ce62c 100644 (file)
   (emit "not %rcx")      ; -%rcx = strlen + 1
   (emit "dec %rcx")
   
   (emit "not %rcx")      ; -%rcx = strlen + 1
   (emit "dec %rcx")
   
-  (case target
-    ('darwin
   (emit "movq %rbx, %rsi") ; string addr
   (emit "movq %rcx, %rdx") ; num bytes
   (emit "movq $1, %rdi")   ; file handle (stdout)
   (emit "movq %rbx, %rsi") ; string addr
   (emit "movq %rcx, %rdx") ; num bytes
   (emit "movq $1, %rdi")   ; file handle (stdout)
-     (emit "movq $0x2000004, %rax")) ; syscall 4 (write)
-    ('linux
-     (emit "mov %rbx, %rsi")  ; string addr
-     (emit "mov %rcx, %rdx")  ; num bytes
-     (emit "mov $1, %rax")    ; file handle (stdout)
-     (emit "mov $1, %rdi"))) ; syscall 1 (write)
+  (case target
+    ('darwin (emit "mov $0x2000004, %rax")) ; syscall 4 (write)
+    ('linux  (emit "mov $1, %rax"))) ; syscall 1 (write)
   (emit "syscall"))
 
 (define (range s n)
   (emit "syscall"))
 
 (define (range s n)
     (else (error #f "don't know how to codegen this"))))
 
 
     (else (error #f "don't know how to codegen this"))))
 
 
-(define (fold-map f x) (fold-left append '() (map f x)))
 
 (define (free-vars prog)
   (define bound '())
 
 (define (free-vars prog)
   (define bound '())
index 4ff6d760315700a6b862b1ce0818dafb78956640..6b66bec1ae3b6de291a2636778fc2656a6c246a6 100644 (file)
--- a/tests.scm
+++ b/tests.scm
@@ -1,12 +1,15 @@
 (load "codegen.scm")
 (load "typecheck.scm")
 
 (load "codegen.scm")
 (load "typecheck.scm")
 
-(define (test actual expected)
-  (when (not (equal? actual expected))
+(define (test-f pred actual expected)
+  (when (not (pred actual expected))
     (error #f
           (format "test failed:\nexpected: ~a\nactual:   ~a"
                   expected actual))))
 
     (error #f
           (format "test failed:\nexpected: ~a\nactual:   ~a"
                   expected actual))))
 
+(define (test . xs) (apply test-f (cons equal? xs)))
+(define (test-types . xs) (apply test-f (cons types-unify? xs)))
+
 (define (read-file file)
   (call-with-input-file file
     (lambda (p)
 (define (read-file file)
   (call-with-input-file file
     (lambda (p)
   (let ((str (read-file "/tmp/test-output.txt")))
     (test str output)))
 
   (let ((str (read-file "/tmp/test-output.txt")))
     (test str output)))
 
-(test (typecheck '(lambda (x) (+ ((lambda (y) (x y 3)) 5) 2)))
+(test-types (typecheck '(lambda (x) (+ ((lambda (y) (x y 3)) 5) 2)))
            '(abs (abs int (abs int int)) int))
 
            '(abs (abs int (abs int int)) int))
 
+                                       ; recursive types
+
+(test-types (substitute '((t1 (abs t1 t10))) 't1) '(abs t1 t10))
+
+(test-types (typecheck '(let ([bar (lambda (y) y)]
+                             [foo (lambda (x) (foo (bar #t)))])
+                         foo))
+           '(abs bool t0))
+
+(test-types (typecheck '(let ([bar (lambda (y) y)]
+                       [foo (lambda (x) (foo (bar #t)))])
+                   bar))
+      '(abs t0 t0))
+
 (test-prog '(+ 1 2) 3)
 (test-prog '((lambda (x) ((lambda (y) (+ x y)) 42)) 100) 142)
 
 (test-prog '(+ 1 2) 3)
 (test-prog '((lambda (x) ((lambda (y) (+ x y)) 42)) 100) 142)
 
@@ -59,3 +76,8 @@
 (test-prog-stdout '(let () ((lambda (f) (f "foo")) print) 0) "foo")
 (test-prog '((lambda (f) (f 3 3)) (lambda (x y) (bool->int (= x y)))) 1)
 (test-prog '(bool->int ((lambda (f) (! (f 2 3))) =)) 1)
 (test-prog-stdout '(let () ((lambda (f) (f "foo")) print) 0) "foo")
 (test-prog '((lambda (f) (f 3 3)) (lambda (x y) (bool->int (= x y)))) 1)
 (test-prog '(bool->int ((lambda (f) (! (f 2 3))) =)) 1)
+
+                                       ; recursion (hangs at typechecking)
+(test-prog '(let [(fac (lambda (f n x) (if (= n 0) x (f f (- n 1) (* x x)))))]
+             (fac fac 3 2))
+          8)
index 7eb4fa96606d786b5e58d17f2d2601f4f52c85a3..25a0e45685f5bbc87a85daf308bb541cb76a7960 100644 (file)
@@ -58,8 +58,8 @@
     ('app
      (if (null? (cddr prog))
         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
     ('app
      (if (null? (cddr prog))
         `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
-        `(,(list (normalize (car prog)) (normalize (cadr prog)))
-          ,(normalize (caddr prog))))) ; (f a b)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
     ('let
        (append (list 'let
                      (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
     ('let
        (append (list 'let
                      (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
-    ;; (display "check: ")
-    ;; (display x)
-    ;; (display "\n\t")
-    ;; (display env)
-    ;; (newline)
+    (display "check: ")
+    (display x)
+    (display "\n\t")
+    (display env)
+    (newline)
     (let
        ((res
          (case (ast-type x)
     (let
        ((res
          (case (ast-type x)
            
            ('var (list '() (env-lookup env x)))
            ('let
            
            ('var (list '() (env-lookup env x)))
            ('let
-           (let ((new-env (fold-left
-                           (lambda (acc bind)
-                             (let ((t (check
-                                       (env-insert acc (car bind) (fresh-tvar))
-                                       (cadr bind))))
-                               (env-insert acc (car bind) (cadr t))))
-                           env (let-bindings x))))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+             (let* ([components (reverse (sccs (graph (let-bindings x))))]
+                    [process-component
+                     (lambda (acc comps)
+                       (display comps)
+                       (newline)
+                       (let*
+                           ([scc-env
+                             (fold-left
+                              (lambda (acc c)
+                                (env-insert acc c (fresh-tvar)))
+                              acc comps)]
+                            [type-results
+                             (map
+                              (lambda (c)
+                                (begin (display scc-env) (newline)
+                                (let ([body (cadr (assoc c (let-bindings x)))])
+                                  (display body)(newline)(check scc-env body))))
+                              comps)]
+                            [cs
+                             (fold-left
+                              (lambda (acc res c)
+                                (consolidate
+                                 acc
+                                 (unify (cadr res) (env-lookup scc-env c))))
+                              '() type-results comps)])
+                         (display "process-component env:\n")
+                         (display (substitute-env cs scc-env))
+                         (newline)
+                         (substitute-env cs scc-env)))]
+                    [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
            
                (check new-env (last (let-body x)))))
            
+           ;; (let ((new-env (fold-left
+           ;;          (lambda (acc bind)
+           ;;            (let* [(bind-tvar (fresh-tvar))
+           ;;                   (env-with-tvar (env-insert acc (car bind) bind-tvar))
+           ;;                   (bind-res (check env-with-tvar (cadr bind)))
+           ;;                   (bind-type (cadr bind-res))
+           ;;                   (cs (consolidate (car bind-res)
+           ;;                                    (unify bind-type bind-tvar)))]
+           ;;              (substitute-env cs env-with-tvar)))
+           ;;          env (let-bindings x))))
+           ;;   (display "sccs of graph\n")
+           ;;   (display (sccs (graph (let-bindings x))))
+           ;;   (newline)
+           ;;   (display "env when checking body:\n\t")
+           ;;   (display new-env)
+           ;;   (newline)
+           ;;   (check new-env (last (let-body x)))))
+           
 
            ('lambda
 
            ('lambda
-           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+               (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
-                  (resolved-arg-type (substitute cs arg-type)))
+                      (resolved-arg-type (substitute cs arg-type))]
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                              (cadr body-type-res)))))
            
            ('app ; (f a)
                              (cadr body-type-res)))))
            
            ('app ; (f a)
+            (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                (let* [(func-type (env-lookup env (car x)))
+                       (return-type (fresh-tvar))
+                       (other-func-type `(abs ,func-type ,return-type))
+                       (cs (unify func-type other-func-type))]
+                  (list cs return-type))
+
+                                       ; regular function
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
-                 (error #f "not a function")))))))
-      ;; (display "result of ")
-      ;; (display x)
-      ;; (display ":\n\t")
-      ;; (display (cadr res))
-      ;; (display "[")
-      ;; (display (car res))
-      ;; (display "]\n")
+                      (error #f "not a function"))))))))
+      (display "result of ")
+      (display x)
+      (display ":\n\t")
+      (display (pretty-type (cadr res)))
+      (display "\n\t[")
+      (display (car res))
+      (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
       res))
   (cadr (check '() (normalize prog))))
 
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
+        (let* [(arg-cs (unify (cadr a) (cadr b)))
+               (body-cs (unify (substitute arg-cs (caddr a))
+                               (substitute arg-cs (caddr b))))]
+          (consolidate arg-cs body-cs)))
        (else (error #f "could not unify"))))
 
                                        ; TODO: what's the most appropriate substitution?
        (else (error #f "could not unify"))))
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
+  (define cs-without-t
+    (map (lambda (c)
+          (filter (lambda (x) (not (eqv? t x))) c))
+        cs))
+
   (define (get-concrete c)
   (define (get-concrete c)
-    (let ((last (null? (cdr c))))
+    (let [(last (null? (cdr c)))]
       (if (not (tvar? (car c)))
          (if (abs? (car c))
       (if (not (tvar? (car c)))
          (if (abs? (car c))
-             (substitute cs (car c))
+             (substitute cs-without-t (car c))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
+  
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
 
   (cond ((null? y) x)
        ((null? x) y)
 
   (cond ((null? y) x)
        ((null? x) y)
-       (else (let* ((a (car y))
+       (else
+        (let* ((a (car y))
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
+
+                                       ; a1 -> a2 ~ a3 -> a4;
+                                       ; a1 -> a2 !~ bool -> bool
+                                       ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (error #f "todo"))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+(define (graph bs)
+  (define (find-refs prog)
+    (ast-collect
+     (lambda (x)
+       (case (ast-type x)
+                                       ; only count a reference if its a binding
+        ['var (if (assoc x bs) (list x) '())]
+        [else '()]))
+     prog))
+  (let* [(bind (car bs))
+
+        (vert (car bind))
+        (refs (find-refs (cdr bind)))
+        (edges (map (lambda (x) (cons vert x))
+                    refs))
+
+        (rest (if (null? (cdr bs))
+                  (cons '() '())
+                  (graph (cdr bs))))
+        (total-verts (cons vert (car rest)))
+        (total-edges (append edges (cdr rest)))]
+    (cons total-verts total-edges)))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+