Fix total pattern match verification
[scheme.git] / typecheck.scm
index 821035230affddf4d813074316f5aecaeaa2efc5..b2b001f1d27220211e5f67e651d641c9895f5b5c 100644 (file)
@@ -1,4 +1,5 @@
 (load "ast.scm")
+(load "stdlib.scm")
 
 (define (abs? t)
   (and (list? t) (eq? (car t) 'abs)))
     ('print '(abs String Void))
     (else (error #f "Couldn't find type for builtin" x))))
 
-(define (check-let env x)
-                                       ; takes in the current environment and a scc
-                                       ; returns new environment with scc's types added in
-  (let* ([components (reverse (sccs (graph (let-bindings x))))]
-        [process-component
-         (lambda (acc comps)
+(define (check-let dls env x)
+
+  ; acc is a pair of (env . annotated bindings)
+  (define (process-component acc comps)
     (let*
                                        ; create a new env with tvars for each component
                                        ; e.g. scc of (x y)
          (fold-left
           (lambda (acc c)
             (env-insert acc c (fresh-tvar)))
-                  acc comps)]
+          (car acc) comps)]
                                        ; typecheck each component
         [type-results
          (map
           (lambda (c)
             (let ([body (cadr (assoc c (let-bindings x)))])
-                      (check scc-env body)))
+              (check dls scc-env body)))
           comps)]
                                        ; collect all the constraints in the scc
         [cs
                 (if (memv (car x) comps)
                     (cons (car x) (substitute cs (cdr x)))
                     x))
-                      scc-env)])
-             new-env))]
-        [new-env (fold-left process-component env components)])
-    (check new-env (last (let-body x)))))
+              scc-env)]
+
+        [annotated-bindings (append (cdr acc) ; the previous annotated bindings
+                                    (map list
+                                         comps
+                                         (map caddr type-results)))])
+      (cons new-env annotated-bindings)))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+  (let* ([components (reverse (sccs (graph (let-bindings x))))]
+        [results (fold-left process-component (cons env '()) components)]
+        [new-env (car results)]
+        [annotated-bindings (cdr results)]
+
+        [body-results (map (lambda (body) (check dls new-env body)) (let-body x))]
+        [let-type (cadr (last body-results))]
+        [cs (fold-left (lambda (acc cs) (constraint-merge acc cs)) '() (map car body-results))]
+
+        [annotated `((let ,annotated-bindings ,@(map caddr body-results)) : ,let-type)])
+    (list cs let-type annotated)))
+
+(define (check-app dls env x)
+  (if (eqv? (car x) (cadr x))
+                                       ; recursive function (f f)
+                                       ; TODO: what about ((f a) f)????
+      (let* ([func-type (env-lookup env (car x))]
+            [return-type (fresh-tvar)]
+            [other-func-type `(abs ,func-type ,return-type)]
+            [cs (~ func-type other-func-type)]
+            [resolved-return-type (substitute cs return-type)]
+
+            [annotated `(((,(car x) : ,func-type)
+                          (,(cadr x) : ,func-type)) : ,resolved-return-type)])
+       (list cs resolved-return-type annotated)))
+
+                                       ; regular function
+  (let* ([arg-type-res (check dls env (cadr x))]
+        [arg-type (cadr arg-type-res)]
+        [func-type-res (check dls env (car x))]
+        [func-type (cadr func-type-res)]
+        
+                                       ; f ~ a -> t0
+        [func-c (~
+                 (substitute (car arg-type-res) func-type)
+                 `(abs ,arg-type ,(fresh-tvar)))]
+        [cs (constraint-merge
+             (constraint-merge func-c (car arg-type-res))
+             (car func-type-res))]
+        
+        [resolved-func-type (substitute cs func-type)]
+        [resolved-return-type (caddr resolved-func-type)]
 
-(define (check env x)
+        [annotated `((,(caddr func-type-res)
+                      ,(caddr arg-type-res)) : ,resolved-return-type)])
+
+    (if (abs? resolved-func-type)
+       (let ((return-type (substitute cs (caddr resolved-func-type))))
+         (list cs return-type annotated))
+       (error #f "not a function"))))
+
+(define (check-case dls env x)
+
+  (define (check-match switch-type x)
+    
+    (define (get-bindings product-types pattern)
+      (define (go product-type product)
+       (case (ast-type product)
+         ['var (list (cons product product-type))]
+                                       ; an inner pattern match
+         ['app (let* ([inner-sum (car product)]
+                      [inner-sums (cdr (assoc product-type dls))]
+                      [inner-product-types (cdr (assoc inner-sum inner-sums))])
+                 (get-bindings inner-product-types product))]
+         [else '()]))
+      (flat-map go product-types (cdr pattern)))
+
+    
+    (let ([pattern (car x)]
+         [expr (cadr x)])
+      (case (ast-type pattern)
+       ['app
+                                       ; a pattern match with bindings
+         (let ([sum (assoc (car pattern) (cdr (assoc switch-type dls)))])
+           (unless sum (error #f "can't pattern match ~a with ~a" switch-type pattern))
+           (let* ([names (cdr pattern)]
+                  [product-types (cdr sum)]
+                  [new-env (append (get-bindings product-types pattern) env)])
+
+             (check dls new-env expr)))]
+                                       ; pattern match with binding and no constructor
+       ['var (check dls (env-insert env pattern switch-type) expr)]
+                                       ; a pattern match without bindings
+       [else (check dls env expr)])))
+  
+  (let* ([switch-type-res (check dls env (case-switch x))]
+        [switch-type (cadr switch-type-res)]
+        
+        [case-expr-type-res (map (lambda (x) (check-match switch-type x)) (case-cases x))]
+        [case-expr-types (map cadr case-expr-type-res)]
+
+        [case-expr-equality-cs (fold-left constraint-merge '()
+                                          (map (lambda (t) (~ t (car case-expr-types)))
+                                               (cdr case-expr-types)))]
+
+        [resolved-type (substitute case-expr-equality-cs (car case-expr-types))]
+
+        [annotated `((case ,(caddr switch-type-res)
+                       ,@(map (lambda (c e et)
+                                `(,c ((,e : ,et))))
+                              (map car (case-cases x))
+                              (map cadr (case-cases x))
+                              case-expr-types)) : ,resolved-type)]
+        
+        [cs (fold-left constraint-merge '()
+                       (cons (car switch-type-res) case-expr-equality-cs))])
+    (list cs resolved-type annotated)))
+
+; returns a list (constraints type annotated)
+(define (check dls env x)
+  (define (make-result cs type)
+    (list cs type `(,x : ,type)))
   ;; (display "check: ")
   ;; (display x)
   ;; (display "\n\t")
   (let
       ((res
        (case (ast-type x)
-         ('int-literal (list '() 'Int))
-         ('bool-literal (list '() 'Bool))
-         ('string-literal (list '() 'String))
-         ('builtin (list '() (builtin-type x)))
+         ('int-literal (make-result '() 'Int))
+         ('string-literal (make-result '() 'String))
+         ('builtin (make-result '() (builtin-type x)))
 
          ('if
-          (let* ((cond-type-res (check env (cadr x)))
-                 (then-type-res (check env (caddr x)))
-                 (else-type-res (check env (cadddr x)))
+          (let* ((cond-type-res (check dls env (cadr x)))
+                 (then-type-res (check dls env (caddr x)))
+                 (else-type-res (check dls env (cadddr x)))
                  (then-eq-else-cs (~ (cadr then-type-res)
                                      (cadr else-type-res)))
                  (cs (constraint-merge
                       (constraint-merge (~ (cadr cond-type-res) 'Bool)
                                         (constraint-merge (car else-type-res)
                                                           then-eq-else-cs))))
-                 (return-type (substitute cs (cadr then-type-res))))
-            (list cs return-type)))
+                 (return-type (substitute cs (cadr then-type-res)))          
+                 [annotated `((if ,(caddr cond-type-res)
+                                  ,(caddr then-type-res)
+                                  ,(caddr else-type-res)) : ,return-type)])
+            (list cs return-type annotated)))
          
-         ('var (list '() (env-lookup env x)))
-         ('let (check-let env x))
+         ('var (make-result '() (env-lookup env x)))
+         ('let (check-let dls env x))
 
          
          ('lambda
-             (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
-
-                    (body-type-res (check new-env (lambda-body x)))
-                    (cs (car body-type-res))
-                    (subd-env (substitute-env (car body-type-res) new-env))
-                    (arg-type (env-lookup subd-env (lambda-arg x)))
-                    (resolved-arg-type (substitute cs arg-type))]
-               ;; (display "lambda:\n\t")
-               ;; (display prog)
-               ;; (display "\n\t")
-               ;; (display cs)
-               ;; (display "\n\t")
-               ;; (display (format "subd-env: ~a\n" subd-env))
-               ;; (display resolved-arg-type)
-               ;; (newline)
-               (list (car body-type-res)
-                     (list 'abs
-                           resolved-arg-type
-                           (cadr body-type-res)))))
+             (let* ([new-env (env-insert env (lambda-arg x) (fresh-tvar))]
 
-         ('app ; (f a)
-          (if (eqv? (car x) (cadr x))
-                                       ; recursive function (f f)
-              (let* [(func-type (env-lookup env (car x)))
-                     (return-type (fresh-tvar))
-                     (other-func-type `(abs ,func-type ,return-type))
-                     (cs (~ func-type other-func-type))
-                     (resolved-return-type (substitute cs return-type))]
-                (list cs resolved-return-type)))
+                    [body-type-res (check dls new-env (lambda-body x))]
+                    [cs (car body-type-res)]
+                    [subd-env (substitute-env (car body-type-res) new-env)]
+                    [arg-type (env-lookup subd-env (lambda-arg x))]
+                    [resolved-arg-type (substitute cs arg-type)]
+
+                    [lambda-type `(abs ,resolved-arg-type ,(cadr body-type-res))]
+
+                    [annotated `((lambda (,(lambda-arg x)) ,(caddr body-type-res)) : ,lambda-type)])
+               
+               (list (car body-type-res) ; constraints
+                     lambda-type  ; type
+                     annotated)))
+
+         
+         ('app (check-app dls env x))
+         ['case (check-case dls env x)])))
              
-                                       ; regular function
-              (let* ((arg-type-res (check env (cadr x)))
-                     (arg-type (cadr arg-type-res))
-                     (func-type-res (check env (car x)))
-                     (func-type (cadr func-type-res))
                
-                                       ; f ~ a -> t0
-                     (func-c (~
-                              (substitute (car arg-type-res) func-type)
-                              `(abs ,arg-type ,(fresh-tvar))))
-                     (cs (constraint-merge
-                          (constraint-merge func-c (car arg-type-res))
-                          (car func-type-res)))
-                     
-                     (resolved-func-type (substitute cs func-type))
-                     (resolved-return-type (caddr resolved-func-type)))
-                ;; (display "app:\n")
-                ;; (display cs)
-                ;; (display "\n")
-                ;; (display func-type)
-                ;; (display "\n")
-                ;; (display resolved-func-type)
-                ;; (display "\n")
-                ;; (display arg-type-res)
-                ;; (display "\n")
-                (if (abs? resolved-func-type)
-                    (let ((return-type (substitute cs (caddr resolved-func-type))))
-                      (list cs return-type))
-                    (error #f "not a function")))))))
     ;; (display "result of ")
     ;; (display x)
     ;; (display ":\n\t")
     ;; (display "]\n")
     res))
 
-                                       ; we typecheck the lambda calculus only (only single arg lambdas)
-(define (typecheck prog)
+(define (init-adts-env prog)
+  (flat-map data-tors-type-env (program-data-layouts prog)))
 
-  (let ([init-env (flat-map data-tors (program-datas prog))])
-    (display init-env)
-    (newline)
-    (cadr (check init-env (normalize (program-body prog))))))
+                                       ; we typecheck the lambda calculus only (only single arg lambdas)
+(define (typecheck prog-without-stdlib)
+  (let* ([prog (append stdlib prog-without-stdlib)]
+        [expanded (expand-pattern-matches prog)])
+    (cadr (check (program-data-layouts prog)
+                (init-adts-env expanded)
+                (normalize (program-body expanded))))))
+
+
+                                       ; before passing annotated types onto codegen
+                                       ; we need to restore the pre-normalization structure
+                                       ; (this is important for function arity etc)
+(define (denormalize orig normed)
+
+  (define (collapse-lambdas n x)
+    (case n
+      [0 x]
+      [else
+       (let* ([inner-lambda (lambda-body (ann-expr x))]
+             [arg (lambda-arg (ann-expr x))]
+             [inner-collapsed (ann-expr (collapse-lambdas (- n 1) inner-lambda))])
+        `((lambda ,(cons arg (lambda-args inner-collapsed))
+            ,(lambda-body inner-collapsed)) : ,(ann-type x)))]))
+
+  (define (collapse-apps n x)
+    (case n
+      [-1 (error #f "nullary functions not handled yet")]
+      [0 x]
+      [else
+       (let* ([inner-app (car (ann-expr x))]
+             [inner-collapsed (collapse-apps (- n 1) inner-app)])
+        `(,(append (ann-expr inner-collapsed) (cdr (ann-expr x))) : ,(ann-type x)))]))
+
+  (case (ast-type orig)
+    ['lambda
+       (let ([collapsed (collapse-lambdas (- (length (lambda-args orig)) 1) normed)])
+         `((lambda ,(lambda-args (ann-expr collapsed))
+             ,(denormalize (lambda-body orig)
+                           (lambda-body (ann-expr collapsed)))) : ,(ann-type collapsed)))]
+    ['app
+     (let ([collapsed (collapse-apps (- (length orig) 2) normed)])
+       `(,(map (lambda (o n) (denormalize o n)) orig (ann-expr collapsed))
+        : ,(ann-type collapsed)))]
+    ['let
+       `((let ,(map (lambda (o n) (list (car o) (denormalize (cadr o) (cadr n))))
+                    (let-bindings orig)
+                    (let-bindings (ann-expr normed)))
+           ,@(map denormalize
+                  (let-body orig)
+                  (let-body (ann-expr normed)))) : ,(ann-type normed))]
+    ['if `((if ,@(map denormalize (cdr orig) (cdr (ann-expr normed))))
+          : ,(ann-type normed))]
+    ['case `((case ,(denormalize (case-switch orig) (case-switch (ann-expr normed)))
+              ,@(map (lambda (o n) (cons (car o) (denormalize (cadr o) (cadr n))))
+                     (case-cases orig) (case-cases (ann-expr normed))))
+            : ,(ann-type normed))]
+    [else normed]))
+
+(define ann-expr car)
+(define ann-type caddr)
+
+                                       ; prerequisites: expand-pattern-matches
+(define (annotate-types prog)
+  (denormalize
+   (program-body prog)
+   (caddr (check (program-data-layouts prog)
+                (init-adts-env prog)
+                (normalize (program-body prog))))))
 
   
                                        ; returns a list of constraints
 
                                        ; input: a list of binds ((x . y) (y . 3))
                                        ; returns: pair of verts, edges ((x y) . (x . y))
-(define (graph bs)
-  (define (go bs orig-bs)
-    (define (find-refs prog)
-      (ast-collect
-       (lambda (x)
-        (case (ast-type x)
-                                       ; only count a reference if its a binding
-          ['var (if (assoc x orig-bs) (list x) '())]
-          [else '()]))
-       prog))
-    (if (null? bs)
-       '(() . ())
-       (let* [(bind (car bs))
-
-              (vert (car bind))
-              (refs (find-refs (cdr bind)))
-              (edges (map (lambda (x) (cons vert x))
-                          refs))
-
-              (rest (if (null? (cdr bs))
-                        (cons '() '())
-                        (go (cdr bs) orig-bs)))
-              (total-verts (cons vert (car rest)))
-              (total-edges (append edges (cdr rest)))]
-         (cons total-verts total-edges))))
-  (go bs bs))
-
-(define (successors graph v)
-  (define (go v E)
-    (if (null? E)
-       '()
-       (if (eqv? v (caar E))
-           (cons (cdar E) (go v (cdr E)))
-           (go v (cdr E)))))
-  (go v (cdr graph)))
-
-                                       ; takes in a graph (pair of vertices, edges)
-                                       ; returns a list of strongly connected components
-
-                                       ; ((x y w) . ((x . y) (x . w) (w . x))
-
-                                       ; =>
-                                       ; .->x->y
-                                       ; |  |
-                                       ; |  v
-                                       ; .--w
-
-                                       ; ((x w) (y))
-
-                                       ; this uses tarjan's algorithm, to get reverse
-                                       ; topological sorting for free
-(define (sccs graph)
-  
-  (let* ([indices (make-hash-table)]
-        [lowlinks (make-hash-table)]
-        [on-stack (make-hash-table)]
-        [current 0]
-        [stack '()]
-        [result '()])
-
-    (define (index v)
-      (get-hash-table indices v #f))
-    (define (lowlink v)
-      (get-hash-table lowlinks v #f))
-
-    (letrec
-       ([strong-connect
-         (lambda (v)
-           (begin
-             (put-hash-table! indices v current)
-             (put-hash-table! lowlinks v current)
-             (set! current (+ current 1))
-             (push! stack v)
-             (put-hash-table! on-stack v #t)
-
-             (for-each
-              (lambda (w)
-                (if (not (hashtable-contains? indices w))
-                                       ; successor w has not been visited, recurse
-                    (begin
-                      (strong-connect w)
-                      (put-hash-table! lowlinks
-                                       v
-                                       (min (lowlink v) (lowlink w))))
-                                       ; successor w has been visited
-                    (when (get-hash-table on-stack w #f)
-                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
-              (successors graph v))
-
-             (when (= (index v) (lowlink v))
-               (let ([scc
-                      (let new-scc ()
-                        (let ([w (pop! stack)])
-                          (put-hash-table! on-stack w #f)
-                          (if (eqv? w v)
-                              (list w)
-                              (cons w (new-scc)))))])
-                 (set! result (cons scc result))))))])
-      (for-each
-       (lambda (v)
-        (when (not (hashtable-contains? indices v)) ; v.index == -1
-          (strong-connect v)))
-       (car graph)))
-    result))