Add recursive let-bindings
[scheme.git] / typecheck.scm
index 98d1cf449fdaeed5f1b0f1eb22abf60a405d3a66..55c2fd8201f2467ff3e136189412b6dc351780fd 100644 (file)
@@ -1,4 +1,27 @@
 (load "ast.scm")
+
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
+
+(define (concrete? t)
+  (case t
+    ('int #t)
+    ('bool #t)
+    (else #f)))
+
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
+
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
   (if (null? env) (error #f "empty env")                       ; it's a type equality
        (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
    ((app? prog)
     (if (null? (cddr prog))
-       (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
-       (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
+       `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+       `(,(list (normalize (car prog)) (normalize (cadr prog)))
+         ,(normalize (caddr prog))))) ; (f a b)
+       ;; (list (list (normalize (car prog))
+       ;;          (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
    ((let? prog)
     (append (list 'let
-                 (map (lambda (x) (cons (car x) (normalize (cdr x))))
+                 (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
                       (let-bindings prog)))
            (map normalize (let-body prog))))
    (else prog)))
 
+(define (builtin-type x)
+  (case x
+    ('+ '(abs int (abs int int)))
+    ('- '(abs int (abs int int)))
+    ('* '(abs int (abs int int)))
+    ('! '(abs bool bool))
+    ('bool->int '(abs bool int))
+    (else #f)))
 
 ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
   (define (check env x)
-    (display "check: ")
-    (display x)
-    (display "\n\t")
-    (display env)
-    (newline)
+    ;; (display "check: ")
+    ;; (display x)
+    ;; (display "\n\t")
+    ;; (display env)
+    ;; (newline)
     (let
        ((res
          (cond
           ((integer? x) (list '() 'int))
           ((boolean? x) (list '() 'bool))
-          ((eq? x 'inc) (list '() '(abs int int)))
-          ((eq? x '+)   (list '() '(abs int (abs int int))))
+          ((builtin-type x) (list '() (builtin-type x)))
           ((symbol? x)  (list '() (env-lookup env x)))
-
           ((let? x)
            (let ((new-env (fold-left
                            (lambda (acc bind)
           ((lambda? x)
            (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
                   (body-type-res (check new-env (lambda-body x)))
-                  (subd-env (substitute-env (car body-type-res) new-env)))
-             (display "lambda: ")
-             (display body-type-res)
-             (display "\n")
-             (display subd-env)
-             (display "\n")
+                  (cs (car body-type-res))
+                  (subd-env (substitute-env (car body-type-res) new-env))
+                  (arg-type (env-lookup subd-env (lambda-arg x)))
+                  (resolved-arg-type (substitute cs arg-type)))
+             ;; (display "lambda:\n\t")
+             ;; (display prog)
+             ;; (display "\n\t")
+             ;; (display cs)
+             ;; (display "\n\t")
+             ;; (display resolved-arg-type)
+             ;; (newline)
              (list (car body-type-res)
                    (list 'abs
-                         (env-lookup subd-env (lambda-arg x))
+                         resolved-arg-type
                          (cadr body-type-res)))))
           
           ((app? x) ; (f a)
                                  (list 'abs
                                        arg-type
                                        (fresh-tvar))))
-                  (cs (append func-c (car arg-type-res) (car func-type-res)))
+                  (cs (consolidate
+                       (consolidate func-c (car arg-type-res))
+                       (car func-type-res)))
                   
                   (resolved-func-type (substitute cs func-type))
                   (resolved-return-type (caddr resolved-func-type)))
                  (let ((return-type (substitute cs (caddr resolved-func-type))))
                    (list cs return-type))
                  (error #f "not a function")))))))
-      (display "result of ")
-      (display x)
-      (display ":\n\t")
-      (display (cadr res))
-      (display "[")
-      (display (car res))
-      (display "]\n")
+      ;; (display "result of ")
+      ;; (display x)
+      ;; (display ":\n\t")
+      ;; (display (cadr res))
+      ;; (display "[")
+      ;; (display (car res))
+      ;; (display "]\n")
       res))
   (cadr (check '() (normalize prog))))
 
-
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
-
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
-
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
-
                                        ; returns a list of pairs of constraints
 (define (unify a b)
   (cond ((eq? a b) '())
                      (unify (caddr a) (caddr b))))
        (else (error #f "could not unify"))))
 
-
                                        ; TODO: what's the most appropriate substitution?
                                        ; should all constraints just be limited to a pair?
 (define (substitute cs t)
                                        ; gets the first concrete type
                                        ; otherwise returns the last type variable
+
   (define (get-concrete c)
-    (if (null? (cdr c))
-       (car c)
+    (let ((last (null? (cdr c))))
       (if (not (tvar? (car c)))
+         (if (abs? (car c))
+             (substitute cs (car c))
+             (car c))
+         (if last
              (car c)
-           (get-concrete (cdr c)))))
+             (get-concrete (cdr c))))))
+  (cond
+   ((abs? t) (list 'abs
+                  (substitute cs (cadr t))
+                  (substitute cs (caddr t))))
+   (else
     (fold-left
      (lambda (t c)
        (if (member t c)
           (get-concrete c)
           t))
-   t cs))
+     t cs))))
 
 (define (substitute-env cs env)
   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))