WIP on typechecker refactor
[scheme.git] / typecheck.scm
index 0e4e265f2d2a47f81b8a13ea4ade24914c3b33a1..6a99869c0697302f361a9e28ac216013f5d018d6 100644 (file)
@@ -1,18 +1,50 @@
-(define (app? x)
-  (and (list? x) (>= (length x) 2) (not (eq? (car x) 'lambda))))
+(load "ast.scm")
 
-(define (lambda? x)
-  (and (list? x) (eq? (car x) 'lambda)))
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
+
+(define (concrete? t)
+  (case t
+    ('int #t)
+    ('bool #t)
+    ('void #t)
+    (else #f)))
+
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
 
-(define lambda-arg caadr)
-(define lambda-body caddr)
+(define (pretty-constraints cs)
+  (string-append "{"
+                (fold-left string-append
+                           ""
+                           (map (lambda (c)
+                                  (string-append
+                                   (pretty-type (car c))
+                                   ": "
+                                   (pretty-type (cdr c))
+                                   ", "))
+                                cs))
+                "}"))
 
                                        ; ('a, ('b, 'a))
-(define (env-lookup env x)
+(define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
-      (if (eq? (caar env) x)
+      (if (eq? (caar env) n)
          (cdar env)
-         (env-lookup (cdr env) x))))
+         (env-lookup (cdr env) n))))
+
+(define (env-insert env n t)
+  (cons (cons n t) env))
 
 (define abs-arg cadr)
 
     (string->symbol
      (string-append "t" (number->string (- cur-tvar 1))))))
 
+(define (last xs)
+  (if (null? (cdr xs))
+      (car xs)
+      (last (cdr xs))))
 
 (define (normalize prog) ; (+ a b) -> ((+ a) b)
-  (cond
-   ((lambda? prog) '(lambda (lambda-arg prog) (normalize (lambda-body prog))))
-   ((app? prog)
+  (case (ast-type prog)
+    ('lambda 
+                                       ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
+       (if (> (length (lambda-args prog)) 1)
+           (list 'lambda (list (car (lambda-args prog)))
+                 (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
+           (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
+    ('app
      (if (null? (cddr prog))
-       (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
-       (normalize (cons (cons (car prog) (list (cadr prog))) (cddr prog))))) ; (f a b)
-   (else prog)))
+        `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
+    ('let
+       (append (list 'let
+                     (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
+                          (let-bindings prog)))
+               (map normalize (let-body prog))))
+    (else (ast-traverse normalize prog))))
 
+(define (builtin-type x)
+  (case x
+    ('+ '(abs int (abs int int)))
+    ('- '(abs int (abs int int)))
+    ('* '(abs int (abs int int)))
+    ('! '(abs bool bool))
+    ('= '(abs int (abs int bool)))
+    ('bool->int '(abs bool int))
+    ('print '(abs string void))
+    (else #f)))
 
-(define (typecheck prog)
 (define (check env x)
+  (display "check: ")
+  (display x)
+  (display "\n\t")
+  (display env)
+  (newline)
   (let
       ((res
-         (cond
-          ((integer? x) (list '() 'int))
-          ((boolean? x) (list '() 'bool))
-          ((eq? x 'inc) (list '() '(abs int int)))
-          ((eq? x '+)   (list '() '(abs int (abs int int))))
-          ((symbol? x) (list '() (env-lookup env x)))
-
-          ((lambda? x)
-           (let* ((new-env (cons (cons (lambda-arg x) (fresh-tvar)) env))
+       (case (ast-type x)
+         ('int-literal (list '() 'int))
+         ('bool-literal (list '() 'bool))
+         ('string-literal (list '() 'string))
+         ('builtin (list '() (builtin-type x)))
+
+         ('if
+          (let* ((cond-type-res (check env (cadr x)))
+                 (then-type-res (check env (caddr x)))
+                 (else-type-res (check env (cadddr x)))
+                 (then-eq-else-cs (~ (cadr then-type-res)
+                                     (cadr else-type-res)))
+                 (cs (constraint-merge
+                      (car then-type-res)
+                      (constraint-merge (car else-type-res)
+                                        then-eq-else-cs)))
+                 (return-type (substitute cs (cadr then-type-res))))
+            (when (not (eqv? (cadr cond-type-res) 'bool))
+              (error #f "if condition isn't bool"))
+            (list cs return-type)))
+         
+         ('var (list '() (env-lookup env x)))
+         ('let
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+             (let* ([components (reverse (sccs (graph (let-bindings x))))]
+                    [process-component
+                     (lambda (acc comps)
+                       (let*
+                                       ; create a new env with tvars for each component
+                                       ; e.g. scc of (x y)
+                                       ; scc-env = ((x . t0) (y . t1))
+                           ([scc-env
+                             (fold-left
+                              (lambda (acc c)
+                                (env-insert acc c (fresh-tvar)))
+                              acc comps)]
+                                       ; typecheck each component
+                            [type-results
+                             (map
+                              (lambda (c)
+                                (let ([body (cadr (assoc c (let-bindings x)))])
+                                  (check scc-env body)))
+                              comps)]
+                                       ; collect all the constraints in the scc
+                            [cs
+                             (fold-left
+                              (lambda (acc res c)
+                                (constraint-merge
+                                 (constraint-merge
+                                       ; unify with tvars from scc-env
+                                       ; result ~ tvar
+                                  (~ (env-lookup scc-env c) (cadr res))
+                                  (car res))                             
+                                 acc))
+                              '() type-results comps)]
+                                       ; substitute *only* the bindings in this scc
+                            [new-env
+                             (map (lambda (x)
+                                    (if (memv (car x) comps)
+                                        (cons (car x) (substitute cs (cdr x)))
+                                        x))
+                                  scc-env)])
+                                       (display "cs:")
+               (display cs)
+               (newline)
+                         new-env))]
+                    [new-env (fold-left process-component env components)])
+               (check new-env (last (let-body x)))))
+         
+         ('lambda
+             (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+
                     (body-type-res (check new-env (lambda-body x)))
-                  (subd-env (substitute-env (car body-type-res) new-env)))
-             ;; (display "lambda: ")
-             ;; (display body-type-res)
-             ;; (display "\n")
-             ;; (display subd-env)
-             ;; (display "\n")
+                    (cs (car body-type-res))
+                    (subd-env (substitute-env (car body-type-res) new-env))
+                    (arg-type (env-lookup subd-env (lambda-arg x)))
+                    (resolved-arg-type (substitute cs arg-type))]
+               ;; (display "lambda:\n\t")
+               ;; (display prog)
+               ;; (display "\n\t")
+               ;; (display cs)
+               ;; (display "\n\t")
+               ;; (display (format "subd-env: ~a\n" subd-env))
+               ;; (display resolved-arg-type)
+               ;; (newline)
                (list (car body-type-res)
                      (list 'abs
-                         (env-lookup subd-env (lambda-arg x))
+                           resolved-arg-type
                            (cadr body-type-res)))))
          
-          ((app? x) ; (f a)
+         ('app ; (f a)
+          (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+              (let* [(func-type (env-lookup env (car x)))
+                     (return-type (fresh-tvar))
+                     (other-func-type `(abs ,func-type ,return-type))
+                     (cs (~ func-type other-func-type))
+                     (resolved-return-type (substitute cs return-type))]
+                (list cs resolved-return-type))
+
+                                       ; regular function
               (let* ((arg-type-res (check env (cadr x)))
                      (arg-type (cadr arg-type-res))
                      (func-type-res (check env (car x)))
                      (func-type (cadr func-type-res))
                      
                                        ; f ~ a -> t0
-                  (func-c (unify func-type
-                                 (list 'abs
-                                       arg-type
-                                       (fresh-tvar))))
-                  (cs (append func-c (car arg-type-res) (car func-type-res)))
+                     (func-c (~
+                              (substitute (car arg-type-res) func-type)
+                              `(abs ,arg-type ,(fresh-tvar))))
+                     (cs (constraint-merge
+                          (constraint-merge func-c (car arg-type-res))
+                          (car func-type-res)))
                      
                      (resolved-func-type (substitute cs func-type))
                      (resolved-return-type (caddr resolved-func-type)))
                 (if (abs? resolved-func-type)
                     (let ((return-type (substitute cs (caddr resolved-func-type))))
                       (list cs return-type))
-                 (error #f "wah")))))))
-      ;; (display "result of ")
-      ;; (display x)
-      ;; (display ":\n\t")
-      ;; (display (cadr res))
-      ;; (display "[")
-      ;; (display (car res))
-      ;; (display "]\n")
+                    (error #f "not a function"))))))))
+    (display "result of ")
+    (display x)
+    (display ":\n\t")
+    (display (pretty-type (cadr res)))
+    (display "\n\t[")
+    (display (pretty-constraints (car res)))
+    (display "]\n")
     res))
+
+                                       ; we typecheck the lambda calculus only (only single arg lambdas)
+(define (typecheck prog)
   (cadr (check '() (normalize prog))))
 
+                                       ; returns a list of constraints
+(define (~ a b)
+  (let ([res (unify? a b)])
+    (if res
+       res
+       (error #f
+              (format "couldn't unify ~a ~~ ~a" a b)))))
 
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
+(define (unify? a b)
+  (cond [(eq? a b) '()]
+       [(tvar? a) (list (cons a b))]
+       [(tvar? b) (list (cons b a))]
+       [(and (abs? a) (abs? b))
+        (let* [(arg-cs (unify? (cadr a) (cadr b)))
+               (body-cs (unify? (substitute arg-cs (caddr a))
+                                (substitute arg-cs (caddr b))))]
+          (constraint-merge body-cs arg-cs))]
+       [else #f]))
 
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
+(define (substitute cs t)
+  (cond
+   [(tvar? t)
+    (if (assoc t cs)
+       (cdr (assoc t cs))
+       t)]
+   [(abs? t) `(abs ,(substitute cs (cadr t))
+                  ,(substitute cs (caddr t)))]
+   [else t]))
 
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
+                                       ; applies substitutions to all variables in environment
+(define (substitute-env cs env)
+  (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
 
-                                       ; returns a list of pairs of constraints
-(define (unify a b)
-  (cond ((eq? a b) '())
-       ((or (tvar? a) (tvar? b)) (~ a b))
-       ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
-       (else (error #f "could not unify"))))
+                                       ; composes constraints a onto b and merges, i.e. applies a to b
+                                       ; a should be the "more important" constraints
+(define (constraint-merge a b)
+  (define (f constraint)
+    (cons (car constraint)
+         (substitute a (cdr constraint))))
   
+  (define (most-concrete a b)
+    (cond
+     [(tvar? a) b]
+     [(tvar? b) a]
+     [(and (abs? a) (abs? b))
+      `(abs ,(most-concrete (cadr a) (cadr b))
+           ,(most-concrete (caddr a) (caddr b)))]
+     [(abs? a) b]
+     [(abs? b) a]
+     [else (error #f "impossible! most-concrete")]))
 
-                                       ; TODO: what's the most appropriate substitution?
-                                       ; should all constraints just be limited to a pair?
-(define (substitute cs t)
-  (define (blah c)
-    (if (null? (cdr c))
-       (car c)
-       (if (not (tvar? (car c)))
-           (car c)
-           (blah (cdr c)))))
-  (fold-left
-   (lambda (t c)
-     (if (member t c)
-        (blah c)
-        t))
-   t cs))
+  (define (union p q)
+    (cond
+     [(null? p) q]
+     [(null? q) p]
+     [else
+      (let ([x (car q)])
+       (if (assoc (car x) p)
+           (if (eqv? (most-concrete (cddr (assoc (car x) p))
+                                    (cdr x))
+                     (cdr x))
+               (cons x (union (filter (p) (not (eqv? 
        
-(define (substitute-env cs env)
-  (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
   
-(define (~ a b)
-  (list (list a b)))
-
-(define (consolidate x y)
-  (define (merge a b)
-    (cond ((null? a) b)
-         ((null? b) a)
-         (else (if (member (car b) a)
-                   (merge a (cdr b))
-                   (cons (car b) (merge a (cdr b)))))))
-  (define (overlap? a b)
-    (if (or (null? a) (null? b))
-       #f
-       (if (fold-left (lambda (acc v)
-                        (or acc (eq? v (car a))))
-                      #f b)
-           #t
-           (overlap? (cdr a) b))))
-
-  (cond ((null? y) x)
-       ((null? x) y)
-       (else (let* ((a (car y))
-                    (merged (fold-left
-                             (lambda (acc b)
-                               (if acc
-                                   acc
-                                   (if (overlap? a b)
-                                       (cons (merge a b) b)
-                                       #f)))
-                             #f x))
-                    (removed (if merged
-                                 (filter (lambda (b) (not (eq? b (cdr merged)))) x)
-                                 x)))
-               (if merged
-                   (consolidate removed (cons (car merged) (cdr y)))
-                   (consolidate (cons a x) (cdr y)))))))
+  (define (union p q)
+    (append (filter (lambda (x) (not (assoc (car x) p)))
+                   q)
+           p))
+  (union a (map f b)))
+
+
+;;                                     ; a1 -> a2 ~ a3 -> a4;
+;;                                     ; a1 -> a2 !~ bool -> bool
+;;                                     ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (let ([cs (unify? x y)])
+    (if (not cs) #f    
+       (let*
+           ([test (lambda (acc c)
+                    (and acc
+                         (tvar? (car c)) ; the only substitutions allowed are tvar -> tvar
+                         (tvar? (cdr c))))])
+         (fold-left test #t cs)))))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+(define (graph bs)
+  (define (go bs orig-bs)
+    (define (find-refs prog)
+      (ast-collect
+       (lambda (x)
+        (case (ast-type x)
+                                       ; only count a reference if its a binding
+          ['var (if (assoc x orig-bs) (list x) '())]
+          [else '()]))
+       prog))
+    (if (null? bs)
+       '(() . ())
+       (let* [(bind (car bs))
+
+              (vert (car bind))
+              (refs (find-refs (cdr bind)))
+              (edges (map (lambda (x) (cons vert x))
+                          refs))
+
+              (rest (if (null? (cdr bs))
+                        (cons '() '())
+                        (go (cdr bs) orig-bs)))
+              (total-verts (cons vert (car rest)))
+              (total-edges (append edges (cdr rest)))]
+         (cons total-verts total-edges))))
+  (go bs bs))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+