Add ast-traverse helper
[scheme.git] / typecheck.scm
index a396a246ee67cf1c65d975d1021a0f91693f36b5..197798e104a7595f5637646cb102f2d9d99f44d5 100644 (file)
@@ -1,19 +1,36 @@
-(define (is-app? x)
-  (and (list? x) (not (eq? (car x) 'lambda))))
+(load "ast.scm")
 
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
 
-(define (is-lambda? x)
-  (and (list? x) (eq? (car x) 'lambda)))
+(define (concrete? t)
+  (case t
+    ('int #t)
+    ('bool #t)
+    (else #f)))
 
-(define lambda-arg cadr)
-(define lambda-body caddr)
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
 
                                        ; ('a, ('b, 'a))
-(define (env-lookup env x)
+(define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
-      (if (eq? (caar env) x)
+      (if (eq? (caar env) n)
          (cdar env)
-         (env-lookup (cdr env) x))))
+         (env-lookup (cdr env) n))))
+
+(define (env-insert env n t)
+  (cons (cons n t) env))
 
 (define abs-arg cadr)
 
     (string->symbol
      (string-append "t" (number->string (- cur-tvar 1))))))
 
-(define (typecheck env x)
-  (display "typechecking:\n\t")
-  (display x)
-  (display "\t")
-  (display env)
-  (display "\n")
+(define (last xs)
+  (if (null? (cdr xs))
+      (car xs)
+      (last (cdr xs))))
+
+                                       
+(define (normalize prog) ; (+ a b) -> ((+ a) b)
+  (case (ast-type prog)
+    ('lambda 
+                                       ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
+       (if (> (length (lambda-args prog)) 1)
+           (list 'lambda (list (car (lambda-args prog)))
+                 (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
+           (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
+    ('app
+     (if (null? (cddr prog))
+        `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+        `(,(list (normalize (car prog)) (normalize (cadr prog)))
+          ,(normalize (caddr prog))))) ; (f a b)
+    ;; (list (list (normalize (car prog))
+    ;;             (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
+    ('let
+       (append (list 'let
+                     (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
+                          (let-bindings prog)))
+               (map normalize (let-body prog))))
+    ('if `(if ,(normalize (cadr prog))
+             ,(normalize (caddr prog))
+             ,(normalize (cadddr prog))))
+    (else prog)))
+
+(define (builtin-type x)
+  (case x
+    ('+ '(abs int (abs int int)))
+    ('- '(abs int (abs int int)))
+    ('* '(abs int (abs int int)))
+    ('! '(abs bool bool))
+    ('= '(abs int (abs int bool)))
+    ('bool->int '(abs bool int))
+    (else #f)))
+
+; we typecheck the lambda calculus only (only single arg lambdas)
+(define (typecheck prog)
+  (define (check env x)
+    ;; (display "check: ")
+    ;; (display x)
+    ;; (display "\n\t")
+    ;; (display env)
+    ;; (newline)
     (let
        ((res
-       (cond
-        ((integer? x) (list '() 'int))
-        ((boolean? x) (list '() 'bool))
-        ((eq? x 'inc) (list '() '(abs int int)))
-        ((symbol? x) (list '() (env-lookup env x)))
-
-        ((is-lambda? x)
-         (let* ((new-env (cons (cons (lambda-arg x) (fresh-tvar)) env))
-                (body-type-res (typecheck new-env (lambda-body x)))
-                (subd-env (substitute (car body-type-res) new-env)))
-           (display "lambda: ")
-           (display body-type-res)
-           (display "\n")
+         (case (ast-type x)
+          ('int-literal (list '() 'int))
+          ('bool-literal (list '() 'bool))
+          ('builtin (list '() (builtin-type x)))
+
+          ('if
+           (let* ((cond-type-res (check env (cadr x)))
+                  (then-type-res (check env (caddr x)))
+                  (else-type-res (check env (cadddr x)))
+                  (then-eq-else-cs (unify (cadr then-type-res)
+                                          (cadr else-type-res)))
+                  (cs (consolidate
+                       (car then-type-res)
+                       (consolidate (car else-type-res)
+                                    then-eq-else-cs)))
+                  (return-type (substitute cs (cadr then-type-res))))
+             (when (not (eqv? (cadr cond-type-res) 'bool))
+               (error #f "if condition isn't bool"))
+             (list cs return-type)))
+          
+          ('var  (list '() (env-lookup env x)))
+          ('let
+           (let ((new-env (fold-left
+                           (lambda (acc bind)
+                             (let ((t (check
+                                       (env-insert acc (car bind) (fresh-tvar))
+                                       (cadr bind))))
+                               (env-insert acc (car bind) (cadr t))))
+                           env (let-bindings x))))
+             (check new-env (last (let-body x)))))
+                 
+
+          ('lambda
+           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
+                  (body-type-res (check new-env (lambda-body x)))
+                  (cs (car body-type-res))
+                  (subd-env (substitute-env (car body-type-res) new-env))
+                  (arg-type (env-lookup subd-env (lambda-arg x)))
+                  (resolved-arg-type (substitute cs arg-type)))
+             ;; (display "lambda:\n\t")
+             ;; (display prog)
+             ;; (display "\n\t")
+             ;; (display cs)
+             ;; (display "\n\t")
+             ;; (display resolved-arg-type)
+             ;; (newline)
              (list (car body-type-res)
                    (list 'abs
-                       (env-lookup subd-env (lambda-arg x))
+                         resolved-arg-type
                          (cadr body-type-res)))))
           
-        ((is-app? x) ; (f a)
-         (let* ((arg-type-res (typecheck env (cadr x)))
-                                       ; typecheck f with the knowledge that f : a -> x
-                (func-type-res (typecheck env (car x)))
+          ('app ; (f a)
+           (let* ((arg-type-res (check env (cadr x)))
+                  (arg-type (cadr arg-type-res))
+                  (func-type-res (check env (car x)))
                   (func-type (cadr func-type-res))
-                (c (unify func-type
+                  
+                                       ; f ~ a -> t0
+                  (func-c (unify func-type
                                  (list 'abs
-                                (cadr arg-type-res)
+                                       arg-type
                                        (fresh-tvar))))
-                (new-env (substitute c env))
-                (resolved-func-type (env-lookup new-env (car x))))
-           (display "is-app:\n")
-           (display c)
-           (display "\n")
-           (display new-env)
-           (display "\n")
-           (display resolved-func-type)
-           (display "\n")
-           (display arg-type-res)
-           (display "\n")
+                  (cs (consolidate
+                       (consolidate func-c (car arg-type-res))
+                       (car func-type-res)))
+                  
+                  (resolved-func-type (substitute cs func-type))
+                  (resolved-return-type (caddr resolved-func-type)))
+             ;; (display "app:\n")
+             ;; (display cs)
+             ;; (display "\n")
+             ;; (display func-type)
+             ;; (display "\n")
+             ;; (display resolved-func-type)
+             ;; (display "\n")
+             ;; (display arg-type-res)
+             ;; (display "\n")
              (if (abs? resolved-func-type)
-               (list (append c
-                             (unify (cadr arg-type-res)
-                                    (cadr resolved-func-type)))
-                     (caddr resolved-func-type))
-               (error #f "wah")))))))
-    (display "result of ")
-    (display x)
-    (display ":\n\t")
-    (display (cadr res))
-    (display "[")
-    (display (car res))
-    (display "]\n")
+                 (let ((return-type (substitute cs (caddr resolved-func-type))))
+                   (list cs return-type))
+                 (error #f "not a function")))))))
+      ;; (display "result of ")
+      ;; (display x)
+      ;; (display ":\n\t")
+      ;; (display (cadr res))
+      ;; (display "[")
+      ;; (display (car res))
+      ;; (display "]\n")
       res))
-
-
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
-
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
-
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
+  (cadr (check '() (normalize prog))))
 
                                        ; returns a list of pairs of constraints
 (define (unify a b)
   (cond ((eq? a b) '())
-       ((or (tvar? a) (tvar? b)) (list (cons a b)))
+       ((or (tvar? a) (tvar? b)) (~ a b))
        ((and (abs? a) (abs? b))
-        (append (unify (cadr a) (cadr b))
+        (consolidate (unify (cadr a) (cadr b))
                      (unify (caddr a) (caddr b))))
        (else (error #f "could not unify"))))
 
-                                       ; takes a list of constraints and a type environment, and makes it work
-(define (substitute c env)
-  (let ((go (lambda (x) (let ((tv (cdr x))
-                             (n (car x)))
-                         ;; (display tv)
-                         ;; (display "\n")
-                         ;; (display n)
-                         (cons n (fold-left
-                                  (lambda (a y)
-                                    ;; (display y)
-                                    ;; (display ":")
-                                    ;; (display a)
-                                    (cond ((eq? a (car y)) (cdr y))
-                                          ((eq? a (cdr y)) (car y))
-                                          (else a)))
-                                  tv c))))))
-  (map go env)))
+                                       ; TODO: what's the most appropriate substitution?
+                                       ; should all constraints just be limited to a pair?
+(define (substitute cs t)
+                                       ; gets the first concrete type
+                                       ; otherwise returns the last type variable
+
+  (define (get-concrete c)
+    (let ((last (null? (cdr c))))
+      (if (not (tvar? (car c)))
+         (if (abs? (car c))
+             (substitute cs (car c))
+             (car c))
+         (if last
+             (car c)
+             (get-concrete (cdr c))))))
+  (cond
+   ((abs? t) (list 'abs
+                  (substitute cs (cadr t))
+                  (substitute cs (caddr t))))
+   (else
+    (fold-left
+     (lambda (t c)
+       (if (member t c)
+          (get-concrete c)
+          t))
+     t cs))))
+
+(define (substitute-env cs env)
+  (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
+
+(define (~ a b)
+  (list (list a b)))
+
+(define (consolidate x y)
+  (define (merge a b)
+    (cond ((null? a) b)
+         ((null? b) a)
+         (else (if (member (car b) a)
+                   (merge a (cdr b))
+                   (cons (car b) (merge a (cdr b)))))))
+  (define (overlap? a b)
+    (if (or (null? a) (null? b))
+       #f
+       (if (fold-left (lambda (acc v)
+                        (or acc (eq? v (car a))))
+                      #f b)
+           #t
+           (overlap? (cdr a) b))))
+
+  (cond ((null? y) x)
+       ((null? x) y)
+       (else (let* ((a (car y))
+                    (merged (fold-left
+                             (lambda (acc b)
+                               (if acc
+                                   acc
+                                   (if (overlap? a b)
+                                       (cons (merge a b) b)
+                                       #f)))
+                             #f x))
+                    (removed (if merged
+                                 (filter (lambda (b) (not (eq? b (cdr merged)))) x)
+                                 x)))
+               (if merged
+                   (consolidate removed (cons (car merged) (cdr y)))
+                   (consolidate (cons a x) (cdr y)))))))