Fix total pattern match verification
[scheme.git] / typecheck.scm
index ebfd816ad62a9538839c808187d55cbeaad593db..b2b001f1d27220211e5f67e651d641c9895f5b5c 100644 (file)
@@ -1,4 +1,5 @@
 (load "ast.scm")
+(load "stdlib.scm")
 
 (define (abs? t)
   (and (list? t) (eq? (car t) 'abs)))
 (define (check-case dls env x)
 
   (define (check-match switch-type x)
+    
+    (define (get-bindings product-types pattern)
+      (define (go product-type product)
+       (case (ast-type product)
+         ['var (list (cons product product-type))]
+                                       ; an inner pattern match
+         ['app (let* ([inner-sum (car product)]
+                      [inner-sums (cdr (assoc product-type dls))]
+                      [inner-product-types (cdr (assoc inner-sum inner-sums))])
+                 (get-bindings inner-product-types product))]
+         [else '()]))
+      (flat-map go product-types (cdr pattern)))
+
+    
     (let ([pattern (car x)]
          [expr (cadr x)])
-      (if (eqv? (ast-type pattern) 'app)
+      (case (ast-type pattern)
+       ['app
                                        ; a pattern match with bindings
          (let ([sum (assoc (car pattern) (cdr (assoc switch-type dls)))])
            (unless sum (error #f "can't pattern match ~a with ~a" switch-type pattern))
            (let* ([names (cdr pattern)]
-                  [types (cdr sum)]
-                  [new-env (fold-left env-insert env names types)])
-             (check dls new-env expr)))
+                  [product-types (cdr sum)]
+                  [new-env (append (get-bindings product-types pattern) env)])
+
+             (check dls new-env expr)))]
+                                       ; pattern match with binding and no constructor
+       ['var (check dls (env-insert env pattern switch-type) expr)]
                                        ; a pattern match without bindings
-         (check dls env expr))))
+       [else (check dls env expr)])))
   
   (let* ([switch-type-res (check dls env (case-switch x))]
         [switch-type (cadr switch-type-res)]
 
         [resolved-type (substitute case-expr-equality-cs (car case-expr-types))]
 
-        [annotated `((case (,(case-expr x) : ,switch-type)
-                       ,(map (lambda (c e et)
-                               `(,c (,e : ,et)))
+        [annotated `((case ,(caddr switch-type-res)
+                       ,@(map (lambda (c e et)
+                                `(,c ((,e : ,et))))
                               (map car (case-cases x))
                               (map cadr (case-cases x))
                               case-expr-types)) : ,resolved-type)]
       ((res
        (case (ast-type x)
          ('int-literal (make-result '() 'Int))
-         ('bool-literal (make-result '() 'Bool))
          ('string-literal (make-result '() 'String))
          ('builtin (make-result '() (builtin-type x)))
 
   (flat-map data-tors-type-env (program-data-layouts prog)))
 
                                        ; we typecheck the lambda calculus only (only single arg lambdas)
-(define (typecheck prog)
-  (let ([expanded (expand-pattern-matches prog)])
+(define (typecheck prog-without-stdlib)
+  (let* ([prog (append stdlib prog-without-stdlib)]
+        [expanded (expand-pattern-matches prog)])
     (cadr (check (program-data-layouts prog)
                 (init-adts-env expanded)
                 (normalize (program-body expanded))))))
        `((let ,(map (lambda (o n) (list (car o) (denormalize (cadr o) (cadr n))))
                     (let-bindings orig)
                     (let-bindings (ann-expr normed)))
-           ,@(map (lambda (o n) (denormalize o n))
+           ,@(map denormalize
                   (let-body orig)
                   (let-body (ann-expr normed)))) : ,(ann-type normed))]
     ['if `((if ,@(map denormalize (cdr orig) (cdr (ann-expr normed))))
-          : (ann-type normed))]
+          : ,(ann-type normed))]
+    ['case `((case ,(denormalize (case-switch orig) (case-switch (ann-expr normed)))
+              ,@(map (lambda (o n) (cons (car o) (denormalize (cadr o) (cadr n))))
+                     (case-cases orig) (case-cases (ann-expr normed))))
+            : ,(ann-type normed))]
     [else normed]))
 
 (define ann-expr car)
 (define (annotate-types prog)
   (denormalize
    (program-body prog)
-   (caddr (check (init-adts-env prog) (normalize (program-body prog))))))
+   (caddr (check (program-data-layouts prog)
+                (init-adts-env prog)
+                (normalize (program-body prog))))))
 
   
                                        ; returns a list of constraints