Start thinking about heap allocation
[scheme.git] / typecheck.scm
index 59e652ad9da26f36906a80964c9956179ecdad9f..313db0e2a9d48fbdbc7fcaff2f89eab749f905b7 100644 (file)
@@ -1,4 +1,28 @@
 (load "ast.scm")
 (load "ast.scm")
+
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
+
+(define (concrete? t)
+  (case t
+    ('int #t)
+    ('bool #t)
+    ('void #t)
+    (else #f)))
+
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
+
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
       (car xs)
       (last (cdr xs))))
                                
       (car xs)
       (last (cdr xs))))
                                
-                                       
 (define (normalize prog) ; (+ a b) -> ((+ a) b)
 (define (normalize prog) ; (+ a b) -> ((+ a) b)
-  (cond
+  (case (ast-type prog)
+    ('lambda 
                                        ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
                                        ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
-   ((lambda? prog)
        (if (> (length (lambda-args prog)) 1)
            (list 'lambda (list (car (lambda-args prog)))
                  (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
            (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
        (if (> (length (lambda-args prog)) 1)
            (list 'lambda (list (car (lambda-args prog)))
                  (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
            (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
-   ((app? prog)
+    ('app
      (if (null? (cddr prog))
      (if (null? (cddr prog))
-       (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
-       (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
-   ((let? prog)
+        `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
+    ('let
        (append (list 'let
        (append (list 'let
-                 (map (lambda (x) (cons (car x) (normalize (cdr x))))
+                     (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
                           (let-bindings prog)))
                (map normalize (let-body prog))))
                           (let-bindings prog)))
                (map normalize (let-body prog))))
-   (else prog)))
+    (else (ast-traverse normalize prog))))
 
 (define (builtin-type x)
   (case x
 
 (define (builtin-type x)
   (case x
@@ -49,7 +73,9 @@
     ('- '(abs int (abs int int)))
     ('* '(abs int (abs int int)))
     ('! '(abs bool bool))
     ('- '(abs int (abs int int)))
     ('* '(abs int (abs int int)))
     ('! '(abs bool bool))
+    ('= '(abs int (abs int bool)))
     ('bool->int '(abs bool int))
     ('bool->int '(abs bool int))
+    ('print '(abs string void))
     (else #f)))
 
 ; we typecheck the lambda calculus only (only single arg lambdas)
     (else #f)))
 
 ; we typecheck the lambda calculus only (only single arg lambdas)
     ;; (newline)
     (let
        ((res
     ;; (newline)
     (let
        ((res
-         (cond
-          ((integer? x) (list '() 'int))
-          ((boolean? x) (list '() 'bool))
-          ((builtin-type x) (list '() (builtin-type x)))
-          ((symbol? x)  (list '() (env-lookup env x)))
-          ((let? x)
-           (let ((new-env (fold-left
-                           (lambda (acc bind)
-                             (let ((t (check
-                                       (env-insert acc (car bind) (fresh-tvar))
-                                       (cadr bind))))
-                               (env-insert acc (car bind) (cadr t))))
-                           env (let-bindings x))))
+         (case (ast-type x)
+           ('int-literal (list '() 'int))
+           ('bool-literal (list '() 'bool))
+           ('string-literal (list '() 'string))
+           ('builtin (list '() (builtin-type x)))
+
+           ('if
+            (let* ((cond-type-res (check env (cadr x)))
+                   (then-type-res (check env (caddr x)))
+                   (else-type-res (check env (cadddr x)))
+                   (then-eq-else-cs (unify (cadr then-type-res)
+                                           (cadr else-type-res)))
+                   (cs (consolidate
+                        (car then-type-res)
+                        (consolidate (car else-type-res)
+                                     then-eq-else-cs)))
+                   (return-type (substitute cs (cadr then-type-res))))
+              (when (not (eqv? (cadr cond-type-res) 'bool))
+                (error #f "if condition isn't bool"))
+              (list cs return-type)))
+           
+           ('var (list '() (env-lookup env x)))
+           ('let
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+             (let* ([components (reverse (sccs (graph (let-bindings x))))]
+                    [process-component
+                     (lambda (acc comps)
+                       (let*
+                                       ; create a new env with tvars for each component
+                                       ; e.g. scc of (x y)
+                                       ; scc-env = ((x . t0) (y . t1))
+                           ([scc-env
+                             (fold-left
+                              (lambda (acc c)
+                                (env-insert acc c (fresh-tvar)))
+                              acc comps)]
+                                       ; typecheck each component
+                            [type-results
+                             (map
+                              (lambda (c)
+                                (let ([body (cadr (assoc c (let-bindings x)))])
+                                  (check scc-env body)))
+                              comps)]
+                                       ; collect all the constraints in the scc
+                            [cs
+                             (fold-left
+                              (lambda (acc res c)
+                                (consolidate
+                                 acc
+                                 (consolidate (car res)
+                                       ; unify with tvars from scc-env
+                                       ; result ~ tvar
+                                              (unify (cadr res) (env-lookup scc-env c)))))
+                              '() type-results comps)]
+                                       ; substitute *only* the bindings in this scc
+                            [new-env
+                             (map (lambda (x)
+                                    (if (memv (car x) comps)
+                                        (cons (car x) (substitute cs (cdr x)))
+                                        x))
+                                  scc-env)])
+                         new-env))]
+                    [new-env (fold-left process-component env components)])
                (check new-env (last (let-body x)))))
            
                (check new-env (last (let-body x)))))
            
+           ('lambda
+               (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
 
 
-          ((lambda? x)
-           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
                       (body-type-res (check new-env (lambda-body x)))
                       (cs (car body-type-res))
                       (subd-env (substitute-env (car body-type-res) new-env))
                       (arg-type (env-lookup subd-env (lambda-arg x)))
-                  (resolved-arg-type (substitute cs arg-type)))
+                      (resolved-arg-type (substitute cs arg-type))]
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                  ;; (display "lambda:\n\t")
                  ;; (display prog)
                  ;; (display "\n\t")
                              resolved-arg-type
                              (cadr body-type-res)))))
            
                              resolved-arg-type
                              (cadr body-type-res)))))
            
-          ((app? x) ; (f a)
+           ('app ; (f a)
+            (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                (let* [(func-type (env-lookup env (car x)))
+                       (return-type (fresh-tvar))
+                       (other-func-type `(abs ,func-type ,return-type))
+                       (cs (unify func-type other-func-type))]
+                  (list cs return-type))
+
+                                       ; regular function
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                 (let* ((arg-type-res (check env (cadr x)))
                        (arg-type (cadr arg-type-res))
                        (func-type-res (check env (car x)))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
                   (if (abs? resolved-func-type)
                       (let ((return-type (substitute cs (caddr resolved-func-type))))
                         (list cs return-type))
-                 (error #f "not a function")))))))
+                      (error #f "not a function"))))))))
       ;; (display "result of ")
       ;; (display x)
       ;; (display ":\n\t")
       ;; (display "result of ")
       ;; (display x)
       ;; (display ":\n\t")
-      ;; (display (cadr res))
-      ;; (display "[")
+      ;; (display (pretty-type (cadr res)))
+      ;; (display "\n\t[")
       ;; (display (car res))
       ;; (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
       ;; (display (car res))
       ;; (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
-
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
-
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
-
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
-
                                        ; returns a list of pairs of constraints
 (define (unify a b)
                                        ; returns a list of pairs of constraints
 (define (unify a b)
+  (let ([res (unify? a b)])
+    (if res
+       res
+       (error #f
+              (format "couldn't unify ~a ~~ ~a" a b)))))
+
+(define (unify? a b)
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
   (cond ((eq? a b) '())
        ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
-       (else (error #f "could not unify"))))
+        (let* [(arg-cs (unify (cadr a) (cadr b)))
+               (body-cs (unify (substitute arg-cs (caddr a))
+                               (substitute arg-cs (caddr b))))]
+          (consolidate arg-cs body-cs)))
+       (else #f)))
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; should all constraints just be limited to a pair?
 
                                        ; TODO: what's the most appropriate substitution?
                                        ; should all constraints just be limited to a pair?
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
 
+  (define cs-without-t
+    (map (lambda (c)
+          (filter (lambda (x) (not (eqv? t x))) c))
+        cs))
+
   (define (get-concrete c)
   (define (get-concrete c)
-    (let ((last (null? (cdr c))))
+    (let [(last (null? (cdr c)))]
       (if (not (tvar? (car c)))
          (if (abs? (car c))
       (if (not (tvar? (car c)))
          (if (abs? (car c))
-             (substitute cs (car c))
+             (substitute cs-without-t (car c))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
              (car c))
          (if last
              (car c)
              (get-concrete (cdr c))))))
+  
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
   (cond
    ((abs? t) (list 'abs
                   (substitute cs (cadr t))
 
   (cond ((null? y) x)
        ((null? x) y)
 
   (cond ((null? y) x)
        ((null? x) y)
-       (else (let* ((a (car y))
+       (else
+        (let* ((a (car y))
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
                (merged (fold-left
                         (lambda (acc b)
                           (if acc
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
           (if merged
               (consolidate removed (cons (car merged) (cdr y)))
               (consolidate (cons a x) (cdr y)))))))
+
+                                       ; a1 -> a2 ~ a3 -> a4;
+                                       ; a1 -> a2 !~ bool -> bool
+                                       ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (let ([cs (unify? x y)])
+    (if (not cs) #f
+       (let*
+           ([test-kind
+             (lambda (acc c)
+               (if (tvar? c) acc #f))]
+            [test (lambda (acc c)
+                    (and acc (fold-left test-kind #t c)))])
+         (fold-left test #t cs)))))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+(define (graph bs)
+  (define (find-refs prog)
+    (ast-collect
+     (lambda (x)
+       (case (ast-type x)
+                                       ; only count a reference if its a binding
+        ['var (if (assoc x bs) (list x) '())]
+        [else '()]))
+     prog))
+  (if (null? bs)
+      '(() . ())
+      (let* [(bind (car bs))
+
+            (vert (car bind))
+            (refs (find-refs (cdr bind)))
+            (edges (map (lambda (x) (cons vert x))
+                        refs))
+
+            (rest (if (null? (cdr bs))
+                      (cons '() '())
+                      (graph (cdr bs))))
+            (total-verts (cons vert (car rest)))
+            (total-edges (append edges (cdr rest)))]
+       (cons total-verts total-edges))))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+