Fix total pattern match verification
[scheme.git] / typecheck.scm
index 98d1cf449fdaeed5f1b0f1eb22abf60a405d3a66..b2b001f1d27220211e5f67e651d641c9895f5b5c 100644 (file)
@@ -1,7 +1,44 @@
 (load "ast.scm")
+(load "stdlib.scm")
+
+(define (abs? t)
+  (and (list? t) (eq? (car t) 'abs)))
+
+(define (tvar? t)
+  (and (not (list? t))
+       (not (concrete? t))
+       (symbol? t)))
+
+(define (concrete? t)
+  (and (symbol? t)
+       (char-upper-case? (string-ref (symbol->string t) 0))))
+
+(define (pretty-type t)
+  (cond ((abs? t)
+        (string-append
+         (if (abs? (cadr t))
+             (string-append "(" (pretty-type (cadr t)) ")")
+             (pretty-type (cadr t)))
+         " -> "
+         (pretty-type (caddr t))))
+       (else (symbol->string t))))
+
+(define (pretty-constraints cs)
+  (string-append "{"
+                (fold-left string-append
+                           ""
+                           (map (lambda (c)
+                                  (string-append
+                                   (pretty-type (car c))
+                                   ": "
+                                   (pretty-type (cdr c))
+                                   ", "))
+                                cs))
+                "}"))
+
                                        ; ('a, ('b, 'a))
 (define (env-lookup env n)
-  (if (null? env) (error #f "empty env")                       ; it's a type equality
+  (if (null? env) (error #f "empty env" env n)                 ; it's a type equality
       (if (eq? (caar env) n)
          (cdar env)
          (env-lookup (cdr env) n))))
       (car xs)
       (last (cdr xs))))
 
-                                       
 (define (normalize prog) ; (+ a b) -> ((+ a) b)
-  (cond
+  (case (ast-type prog)
+    ('lambda 
                                        ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
-   ((lambda? prog)
        (if (> (length (lambda-args prog)) 1)       
            (list 'lambda (list (car (lambda-args prog)))
                  (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
            (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
-   ((app? prog)
+    ('app
      (if (null? (cddr prog))
-       (cons (normalize (car prog)) (normalize (cdr prog))) ; (f a)
-       (list (list (normalize (car prog)) (normalize (cadr prog))) (normalize (caddr prog))))) ; (f a b)
-   ((let? prog)
+        `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
+        (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
+                     ,@(cddr prog))))) ; (f a b)
+    ('let
        (append (list 'let
-                 (map (lambda (x) (cons (car x) (normalize (cdr x))))
+                     (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
                           (let-bindings prog)))
                (map normalize (let-body prog))))
-   (else prog)))
+    (else (ast-traverse normalize prog))))
 
+(define (builtin-type x)
+  (case x
+    ('+ '(abs Int (abs Int Int)))
+    ('- '(abs Int (abs Int Int)))
+    ('* '(abs Int (abs Int Int)))
+    ('! '(abs Bool Bool))
+    ('= '(abs Int (abs Int Bool)))
+    ('bool->int '(abs Bool Int))
+    ('print '(abs String Void))
+    (else (error #f "Couldn't find type for builtin" x))))
 
-; we typecheck the lambda calculus only (only single arg lambdas)
-(define (typecheck prog)
-  (define (check env x)
-    (display "check: ")
-    (display x)
-    (display "\n\t")
-    (display env)
-    (newline)
-    (let
-       ((res
-         (cond
-          ((integer? x) (list '() 'int))
-          ((boolean? x) (list '() 'bool))
-          ((eq? x 'inc) (list '() '(abs int int)))
-          ((eq? x '+)   (list '() '(abs int (abs int int))))
-          ((symbol? x)  (list '() (env-lookup env x)))
-
-          ((let? x)
-           (let ((new-env (fold-left
-                           (lambda (acc bind)
-                             (let ((t (check
-                                       (env-insert acc (car bind) (fresh-tvar))
-                                       (cadr bind))))
-                               (env-insert acc (car bind) (cadr t))))
-                           env (let-bindings x))))
-             (check new-env (last (let-body x)))))
-                 
-
-          ((lambda? x)
-           (let* ((new-env (env-insert env (lambda-arg x) (fresh-tvar)))
-                  (body-type-res (check new-env (lambda-body x)))
-                  (subd-env (substitute-env (car body-type-res) new-env)))
-             (display "lambda: ")
-             (display body-type-res)
-             (display "\n")
-             (display subd-env)
-             (display "\n")
-             (list (car body-type-res)
-                   (list 'abs
-                         (env-lookup subd-env (lambda-arg x))
-                         (cadr body-type-res)))))
-          
-          ((app? x) ; (f a)
-           (let* ((arg-type-res (check env (cadr x)))
-                  (arg-type (cadr arg-type-res))
-                  (func-type-res (check env (car x)))
-                  (func-type (cadr func-type-res))
+(define (check-let dls env x)
+
+  ; acc is a pair of (env . annotated bindings)
+  (define (process-component acc comps)
+    (let*
+                                       ; create a new env with tvars for each component
+                                       ; e.g. scc of (x y)
+                                       ; scc-env = ((x . t0) (y . t1))
+       ([scc-env
+         (fold-left
+          (lambda (acc c)
+            (env-insert acc c (fresh-tvar)))
+          (car acc) comps)]
+                                       ; typecheck each component
+        [type-results
+         (map
+          (lambda (c)
+            (let ([body (cadr (assoc c (let-bindings x)))])
+              (check dls scc-env body)))
+          comps)]
+                                       ; collect all the constraints in the scc
+        [cs
+         (fold-left
+          (lambda (acc res c)
+            (constraint-merge
+             (constraint-merge
+                                       ; unify with tvars from scc-env
+                                       ; result ~ tvar
+              (~ (env-lookup scc-env c) (cadr res))
+              (car res))                                 
+             acc))
+          '() type-results comps)]
+                                       ; substitute *only* the bindings in this scc
+        [new-env
+         (map (lambda (x)
+                (if (memv (car x) comps)
+                    (cons (car x) (substitute cs (cdr x)))
+                    x))
+              scc-env)]
+
+        [annotated-bindings (append (cdr acc) ; the previous annotated bindings
+                                    (map list
+                                         comps
+                                         (map caddr type-results)))])
+      (cons new-env annotated-bindings)))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+  (let* ([components (reverse (sccs (graph (let-bindings x))))]
+        [results (fold-left process-component (cons env '()) components)]
+        [new-env (car results)]
+        [annotated-bindings (cdr results)]
+
+        [body-results (map (lambda (body) (check dls new-env body)) (let-body x))]
+        [let-type (cadr (last body-results))]
+        [cs (fold-left (lambda (acc cs) (constraint-merge acc cs)) '() (map car body-results))]
+
+        [annotated `((let ,annotated-bindings ,@(map caddr body-results)) : ,let-type)])
+    (list cs let-type annotated)))
+
+(define (check-app dls env x)
+  (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                                       ; TODO: what about ((f a) f)????
+      (let* ([func-type (env-lookup env (car x))]
+            [return-type (fresh-tvar)]
+            [other-func-type `(abs ,func-type ,return-type)]
+            [cs (~ func-type other-func-type)]
+            [resolved-return-type (substitute cs return-type)]
+
+            [annotated `(((,(car x) : ,func-type)
+                          (,(cadr x) : ,func-type)) : ,resolved-return-type)])
+       (list cs resolved-return-type annotated)))
+
+                                       ; regular function
+  (let* ([arg-type-res (check dls env (cadr x))]
+        [arg-type (cadr arg-type-res)]
+        [func-type-res (check dls env (car x))]
+        [func-type (cadr func-type-res)]
         
                                        ; f ~ a -> t0
-                  (func-c (unify func-type
-                                 (list 'abs
-                                       arg-type
-                                       (fresh-tvar))))
-                  (cs (append func-c (car arg-type-res) (car func-type-res)))
-                  
-                  (resolved-func-type (substitute cs func-type))
-                  (resolved-return-type (caddr resolved-func-type)))
-             ;; (display "app:\n")
-             ;; (display cs)
-             ;; (display "\n")
-             ;; (display func-type)
-             ;; (display "\n")
-             ;; (display resolved-func-type)
-             ;; (display "\n")
-             ;; (display arg-type-res)
-             ;; (display "\n")
+        [func-c (~
+                 (substitute (car arg-type-res) func-type)
+                 `(abs ,arg-type ,(fresh-tvar)))]
+        [cs (constraint-merge
+             (constraint-merge func-c (car arg-type-res))
+             (car func-type-res))]
+        
+        [resolved-func-type (substitute cs func-type)]
+        [resolved-return-type (caddr resolved-func-type)]
+
+        [annotated `((,(caddr func-type-res)
+                      ,(caddr arg-type-res)) : ,resolved-return-type)])
+
     (if (abs? resolved-func-type)
        (let ((return-type (substitute cs (caddr resolved-func-type))))
-                   (list cs return-type))
-                 (error #f "not a function")))))))
-      (display "result of ")
-      (display x)
-      (display ":\n\t")
-      (display (cadr res))
-      (display "[")
-      (display (car res))
-      (display "]\n")
+         (list cs return-type annotated))
+       (error #f "not a function"))))
+
+(define (check-case dls env x)
+
+  (define (check-match switch-type x)
+    
+    (define (get-bindings product-types pattern)
+      (define (go product-type product)
+       (case (ast-type product)
+         ['var (list (cons product product-type))]
+                                       ; an inner pattern match
+         ['app (let* ([inner-sum (car product)]
+                      [inner-sums (cdr (assoc product-type dls))]
+                      [inner-product-types (cdr (assoc inner-sum inner-sums))])
+                 (get-bindings inner-product-types product))]
+         [else '()]))
+      (flat-map go product-types (cdr pattern)))
+
+    
+    (let ([pattern (car x)]
+         [expr (cadr x)])
+      (case (ast-type pattern)
+       ['app
+                                       ; a pattern match with bindings
+         (let ([sum (assoc (car pattern) (cdr (assoc switch-type dls)))])
+           (unless sum (error #f "can't pattern match ~a with ~a" switch-type pattern))
+           (let* ([names (cdr pattern)]
+                  [product-types (cdr sum)]
+                  [new-env (append (get-bindings product-types pattern) env)])
+
+             (check dls new-env expr)))]
+                                       ; pattern match with binding and no constructor
+       ['var (check dls (env-insert env pattern switch-type) expr)]
+                                       ; a pattern match without bindings
+       [else (check dls env expr)])))
+  
+  (let* ([switch-type-res (check dls env (case-switch x))]
+        [switch-type (cadr switch-type-res)]
+        
+        [case-expr-type-res (map (lambda (x) (check-match switch-type x)) (case-cases x))]
+        [case-expr-types (map cadr case-expr-type-res)]
+
+        [case-expr-equality-cs (fold-left constraint-merge '()
+                                          (map (lambda (t) (~ t (car case-expr-types)))
+                                               (cdr case-expr-types)))]
+
+        [resolved-type (substitute case-expr-equality-cs (car case-expr-types))]
+
+        [annotated `((case ,(caddr switch-type-res)
+                       ,@(map (lambda (c e et)
+                                `(,c ((,e : ,et))))
+                              (map car (case-cases x))
+                              (map cadr (case-cases x))
+                              case-expr-types)) : ,resolved-type)]
+        
+        [cs (fold-left constraint-merge '()
+                       (cons (car switch-type-res) case-expr-equality-cs))])
+    (list cs resolved-type annotated)))
+
+; returns a list (constraints type annotated)
+(define (check dls env x)
+  (define (make-result cs type)
+    (list cs type `(,x : ,type)))
+  ;; (display "check: ")
+  ;; (display x)
+  ;; (display "\n\t")
+  ;; (display env)
+  ;; (newline)
+  (let
+      ((res
+       (case (ast-type x)
+         ('int-literal (make-result '() 'Int))
+         ('string-literal (make-result '() 'String))
+         ('builtin (make-result '() (builtin-type x)))
+
+         ('if
+          (let* ((cond-type-res (check dls env (cadr x)))
+                 (then-type-res (check dls env (caddr x)))
+                 (else-type-res (check dls env (cadddr x)))
+                 (then-eq-else-cs (~ (cadr then-type-res)
+                                     (cadr else-type-res)))
+                 (cs (constraint-merge
+                      (car then-type-res)
+                      (constraint-merge (~ (cadr cond-type-res) 'Bool)
+                                        (constraint-merge (car else-type-res)
+                                                          then-eq-else-cs))))
+                 (return-type (substitute cs (cadr then-type-res)))          
+                 [annotated `((if ,(caddr cond-type-res)
+                                  ,(caddr then-type-res)
+                                  ,(caddr else-type-res)) : ,return-type)])
+            (list cs return-type annotated)))
+         
+         ('var (make-result '() (env-lookup env x)))
+         ('let (check-let dls env x))
+
+         
+         ('lambda
+             (let* ([new-env (env-insert env (lambda-arg x) (fresh-tvar))]
+
+                    [body-type-res (check dls new-env (lambda-body x))]
+                    [cs (car body-type-res)]
+                    [subd-env (substitute-env (car body-type-res) new-env)]
+                    [arg-type (env-lookup subd-env (lambda-arg x))]
+                    [resolved-arg-type (substitute cs arg-type)]
+
+                    [lambda-type `(abs ,resolved-arg-type ,(cadr body-type-res))]
+
+                    [annotated `((lambda (,(lambda-arg x)) ,(caddr body-type-res)) : ,lambda-type)])
+               
+               (list (car body-type-res) ; constraints
+                     lambda-type  ; type
+                     annotated)))
+
+         
+         ('app (check-app dls env x))
+         ['case (check-case dls env x)])))
+             
+               
+    ;; (display "result of ")
+    ;; (display x)
+    ;; (display ":\n\t")
+    ;; (display (pretty-type (cadr res)))
+    ;; (display "\n\t[")
+    ;; (display (pretty-constraints (car res)))
+    ;; (display "]\n")
     res))
-  (cadr (check '() (normalize prog))))
 
+(define (init-adts-env prog)
+  (flat-map data-tors-type-env (program-data-layouts prog)))
 
-(define (abs? t)
-  (and (list? t) (eq? (car t) 'abs)))
+                                       ; we typecheck the lambda calculus only (only single arg lambdas)
+(define (typecheck prog-without-stdlib)
+  (let* ([prog (append stdlib prog-without-stdlib)]
+        [expanded (expand-pattern-matches prog)])
+    (cadr (check (program-data-layouts prog)
+                (init-adts-env expanded)
+                (normalize (program-body expanded))))))
 
-(define (tvar? t)
-  (and (not (list? t)) (not (concrete? t)) (symbol? t)))
 
-(define (concrete? t)
-  (case t
-    ('int #t)
-    ('bool #t)
-    (else #f)))
-
-                                       ; returns a list of pairs of constraints
-(define (unify a b)
-  (cond ((eq? a b) '())
-       ((or (tvar? a) (tvar? b)) (~ a b))
-       ((and (abs? a) (abs? b))
-        (consolidate (unify (cadr a) (cadr b))
-                     (unify (caddr a) (caddr b))))
-       (else (error #f "could not unify"))))
-
-
-                                       ; TODO: what's the most appropriate substitution?
-                                       ; should all constraints just be limited to a pair?
+                                       ; before passing annotated types onto codegen
+                                       ; we need to restore the pre-normalization structure
+                                       ; (this is important for function arity etc)
+(define (denormalize orig normed)
+
+  (define (collapse-lambdas n x)
+    (case n
+      [0 x]
+      [else
+       (let* ([inner-lambda (lambda-body (ann-expr x))]
+             [arg (lambda-arg (ann-expr x))]
+             [inner-collapsed (ann-expr (collapse-lambdas (- n 1) inner-lambda))])
+        `((lambda ,(cons arg (lambda-args inner-collapsed))
+            ,(lambda-body inner-collapsed)) : ,(ann-type x)))]))
+
+  (define (collapse-apps n x)
+    (case n
+      [-1 (error #f "nullary functions not handled yet")]
+      [0 x]
+      [else
+       (let* ([inner-app (car (ann-expr x))]
+             [inner-collapsed (collapse-apps (- n 1) inner-app)])
+        `(,(append (ann-expr inner-collapsed) (cdr (ann-expr x))) : ,(ann-type x)))]))
+
+  (case (ast-type orig)
+    ['lambda
+       (let ([collapsed (collapse-lambdas (- (length (lambda-args orig)) 1) normed)])
+         `((lambda ,(lambda-args (ann-expr collapsed))
+             ,(denormalize (lambda-body orig)
+                           (lambda-body (ann-expr collapsed)))) : ,(ann-type collapsed)))]
+    ['app
+     (let ([collapsed (collapse-apps (- (length orig) 2) normed)])
+       `(,(map (lambda (o n) (denormalize o n)) orig (ann-expr collapsed))
+        : ,(ann-type collapsed)))]
+    ['let
+       `((let ,(map (lambda (o n) (list (car o) (denormalize (cadr o) (cadr n))))
+                    (let-bindings orig)
+                    (let-bindings (ann-expr normed)))
+           ,@(map denormalize
+                  (let-body orig)
+                  (let-body (ann-expr normed)))) : ,(ann-type normed))]
+    ['if `((if ,@(map denormalize (cdr orig) (cdr (ann-expr normed))))
+          : ,(ann-type normed))]
+    ['case `((case ,(denormalize (case-switch orig) (case-switch (ann-expr normed)))
+              ,@(map (lambda (o n) (cons (car o) (denormalize (cadr o) (cadr n))))
+                     (case-cases orig) (case-cases (ann-expr normed))))
+            : ,(ann-type normed))]
+    [else normed]))
+
+(define ann-expr car)
+(define ann-type caddr)
+
+                                       ; prerequisites: expand-pattern-matches
+(define (annotate-types prog)
+  (denormalize
+   (program-body prog)
+   (caddr (check (program-data-layouts prog)
+                (init-adts-env prog)
+                (normalize (program-body prog))))))
+
+  
+                                       ; returns a list of constraints
+(define (~ a b)
+  (let ([res (unify? a b)])
+    (if res
+       res
+       (error #f
+              (format "couldn't unify ~a ~~ ~a" a b)))))
+
+(define (unify? a b)
+  (cond [(eq? a b) '()]
+       [(tvar? a) (list (cons a b))]
+       [(tvar? b) (list (cons b a))]
+       [(and (abs? a) (abs? b))
+        (let* [(arg-cs (unify? (cadr a) (cadr b)))
+               (body-cs (unify? (substitute arg-cs (caddr a))
+                                (substitute arg-cs (caddr b))))]
+          (constraint-merge body-cs arg-cs))]
+       [else #f]))
+
 (define (substitute cs t)
-                                       ; gets the first concrete type
-                                       ; otherwise returns the last type variable
-  (define (get-concrete c)
-    (if (null? (cdr c))
-       (car c)
-       (if (not (tvar? (car c)))
-           (car c)
-           (get-concrete (cdr c)))))
-  (fold-left
-   (lambda (t c)
-     (if (member t c)
-        (get-concrete c)
-        t))
-   t cs))
+  (cond
+   [(tvar? t)
+    (if (assoc t cs)
+       (cdr (assoc t cs))
+       t)]
+   [(abs? t) `(abs ,(substitute cs (cadr t))
+                  ,(substitute cs (caddr t)))]
+   [else t]))
 
+                                       ; applies substitutions to all variables in environment
 (define (substitute-env cs env)
   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
 
-(define (~ a b)
-  (list (list a b)))
-
-(define (consolidate x y)
-  (define (merge a b)
-    (cond ((null? a) b)
-         ((null? b) a)
-         (else (if (member (car b) a)
-                   (merge a (cdr b))
-                   (cons (car b) (merge a (cdr b)))))))
-  (define (overlap? a b)
-    (if (or (null? a) (null? b))
-       #f
-       (if (fold-left (lambda (acc v)
-                        (or acc (eq? v (car a))))
-                      #f b)
-           #t
-           (overlap? (cdr a) b))))
-
-  (cond ((null? y) x)
-       ((null? x) y)
-       (else (let* ((a (car y))
-                    (merged (fold-left
-                             (lambda (acc b)
-                               (if acc
-                                   acc
-                                   (if (overlap? a b)
-                                       (cons (merge a b) b)
-                                       #f)))
-                             #f x))
-                    (removed (if merged
-                                 (filter (lambda (b) (not (eq? b (cdr merged)))) x)
-                                 x)))
-               (if merged
-                   (consolidate removed (cons (car merged) (cdr y)))
-                   (consolidate (cons a x) (cdr y)))))))
+                                       ; composes constraints a onto b and merges, i.e. applies a to b
+                                       ; a should be the "more important" constraints
+(define (constraint-merge a b)
+  (define (f cs constraint)
+    (cons (car constraint)
+         (substitute cs (cdr constraint))))
+  
+  (define (most-concrete a b)
+    (cond
+     [(tvar? a) b]
+     [(tvar? b) a]
+     [(and (abs? a) (abs? b))
+      `(abs ,(most-concrete (cadr a) (cadr b))
+           ,(most-concrete (caddr a) (caddr b)))]
+     [(abs? a) b]
+     [(abs? b) a]
+     [else a]))
+
+                                       ; for any two constraints that clash, e.g. t1 ~ abs t2 t3
+                                       ; and t1 ~ abs int t3
+                                       ; prepend the most concrete version of the type to the
+                                       ; list of constraints
+  (define (clashes)
+    (define (gen acc x)
+      (if (assoc (car x) a)
+         (cons (cons (car x) (most-concrete (cdr (assoc (car x) a))
+                                            (cdr x)))
+               acc)
+         acc))
+    (fold-left gen '() b))
+
+  (define (union p q)
+    (append (filter (lambda (x) (not (assoc (car x) p)))
+                   q)
+           p))
+  (append (clashes) (union a (map (lambda (z) (f a z)) b))))
+
+
+;;                                     ; a1 -> a2 ~ a3 -> a4;
+;;                                     ; a1 -> a2 !~ Bool -> Bool
+;;                                     ; basically can the tvars be renamed
+(define (types-equal? x y)
+  (let ([cs (unify? x y)])
+    (if (not cs) #f    
+       (let*
+           ([test (lambda (acc c)
+                    (and acc
+                         (tvar? (car c)) ; the only substitutions allowed are tvar -> tvar
+                         (tvar? (cdr c))))])
+         (fold-left test #t cs)))))
+
+                                       ; input: a list of binds ((x . y) (y . 3))
+                                       ; returns: pair of verts, edges ((x y) . (x . y))
+