Add ast-traverse helper
[scheme.git] / typecheck.scm
index eaff75efbe6fdf313e0522216c5e07d2486c75ac..197798e104a7595f5637646cb102f2d9d99f44d5 100644 (file)
@@ -1,4 +1,27 @@
 (load "ast.scm")
+
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
+
+(define (concrete? t)
+  (case t
+    ('int #t)
+    ('bool #t)
+    (else #f)))
+
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
+
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
 
                                        
 (define (normalize prog) ; (+ a b) -> ((+ a) b)
-  (cond
+  (case (ast-type prog)
+    ('lambda 
                                        ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
-   ((lambda? prog)
        (if (> (length (lambda-args prog)) 1)
            (list 'lambda (list (car (lambda-args prog)))
                  (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
            (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
-   ((app? prog)
+    ('app
      (if (null? (cddr prog))
-       (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
-       (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
-   ((let? prog)
+        `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+        `(,(list (normalize (car prog)) (normalize (cadr prog)))
+          ,(normalize (caddr prog))))) ; (f a b)
+    ;; (list (list (normalize (car prog))
+    ;;             (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
+    ('let
        (append (list 'let
-                 (map (lambda (x) (cons (car x) (normalize (cdr x))))
+                     (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
                           (let-bindings prog)))
                (map normalize (let-body prog))))
+    ('if `(if ,(normalize (cadr prog))
+             ,(normalize (caddr prog))
+             ,(normalize (cadddr prog))))
     (else prog)))
 
+(define (builtin-type x)
+  (case x
+    ('+ '(abs int (abs int int)))
+    ('- '(abs int (abs int int)))
+    ('* '(abs int (abs int int)))
+    ('! '(abs bool bool))
+    ('= '(abs int (abs int bool)))
+    ('bool->int '(abs bool int))
+    (else #f)))
 
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
     ;; (newline)
     (let
        ((res
-         (cond
-          ((integer? x) (list '() 'int))
-          ((boolean? x) (list '() 'bool))
-          ((eq? x 'inc) (list '() '(abs int int)))
-          ((eq? x '+)   (list '() '(abs int (abs int int))))
-          ((symbol? x)  (list '() (env-lookup env x)))
+         (case (ast-type x)
+          ('int-literal (list '() 'int))
+          ('bool-literal (list '() 'bool))
+          ('builtin (list '() (builtin-type x)))
+
+          ('if
+           (let* ((cond-type-res (check env (cadr x)))
+                  (then-type-res (check env (caddr x)))
+                  (else-type-res (check env (cadddr x)))
+                  (then-eq-else-cs (unify (cadr then-type-res)
+                                          (cadr else-type-res)))
+                  (cs (consolidate
+                       (car then-type-res)
+                       (consolidate (car else-type-res)
+                                    then-eq-else-cs)))
+                  (return-type (substitute cs (cadr then-type-res))))
+             (when (not (eqv? (cadr cond-type-res) 'bool))
+               (error #f "if condition isn't bool"))
+             (list cs return-type)))
           
-          ((let? x)
+          ('var  (list '() (env-lookup env x)))
+          ('let
            (let ((new-env (fold-left
                            (lambda (acc bind)
                              (let ((t (check
              (check new-env (last (let-body x)))))
                  
 
-          ((lambda? x)
+          ('lambda
            (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
                   (body-type-res (check new-env (lambda-body x)))
                   (cs (car body-type-res))
                          resolved-arg-type
                          (cadr body-type-res)))))
           
-          ((app? x) ; (f a)
+          ('app ; (f a)
            (let* ((arg-type-res (check env (cadr x)))
                   (arg-type (cadr arg-type-res))
                   (func-type-res (check env (car x)))
       res))
   (cadr (check '() (normalize prog))))
 
-
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
-
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
-
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
-
                                        ; returns a list of pairs of constraints
 (define (unify a b)
   (cond ((eq? a b) '())